parent
db4ff2172a
commit
4da1046d68
13 changed files with 2004 additions and 2 deletions
@ -0,0 +1,36 @@ |
||||
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) |
||||
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) |
||||
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) |
||||
Copyright (c) 2011-2013 NYU (Clement Farabet) |
||||
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) |
||||
Copyright (c) 2006 Idiap Research Institute (Samy Bengio) |
||||
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) |
||||
|
||||
All rights reserved. |
||||
|
||||
Redistribution and use in source and binary forms, with or without |
||||
modification, are permitted provided that the following conditions are met: |
||||
|
||||
1. Redistributions of source code must retain the above copyright |
||||
notice, this list of conditions and the following disclaimer. |
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright |
||||
notice, this list of conditions and the following disclaimer in the |
||||
documentation and/or other materials provided with the distribution. |
||||
|
||||
3. Neither the names of Deepmind Technologies, NYU, NEC Laboratories America |
||||
and IDIAP Research Institute nor the names of its contributors may be |
||||
used to endorse or promote products derived from this software without |
||||
specific prior written permission. |
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE |
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE |
||||
POSSIBILITY OF SUCH DAMAGE. |
@ -0,0 +1,348 @@ |
||||
local File = torch.getmetatable('torch.File') |
||||
|
||||
function File:writeBool(value) |
||||
if value then |
||||
self:writeInt(1) |
||||
else |
||||
self:writeInt(0) |
||||
end |
||||
end |
||||
|
||||
function File:readBool() |
||||
return (self:readInt() == 1) |
||||
end |
||||
|
||||
local TYPE_NIL = 0 |
||||
local TYPE_NUMBER = 1 |
||||
local TYPE_STRING = 2 |
||||
local TYPE_TABLE = 3 |
||||
local TYPE_TORCH = 4 |
||||
local TYPE_BOOLEAN = 5 |
||||
local TYPE_FUNCTION = 6 |
||||
local TYPE_RECUR_FUNCTION = 8 |
||||
local LEGACY_TYPE_RECUR_FUNCTION = 7 |
||||
|
||||
-- Lua 5.2 compatibility |
||||
local loadstring = loadstring or load |
||||
|
||||
function File:isWritableObject(object) |
||||
local typename = type(object) |
||||
local typeidx |
||||
if type(object) ~= 'boolean' and not object then |
||||
typeidx = TYPE_NIL |
||||
elseif torch.typename(object) and torch.factory(torch.typename(object)) then |
||||
typeidx = TYPE_TORCH |
||||
elseif typename == 'table' then |
||||
typeidx = TYPE_TABLE |
||||
elseif typename == 'number' then |
||||
typeidx = TYPE_NUMBER |
||||
elseif typename == 'string' then |
||||
typeidx = TYPE_STRING |
||||
elseif typename == 'boolean' then |
||||
typeidx = TYPE_BOOLEAN |
||||
elseif typename == 'function' and pcall(string.dump, object) then |
||||
typeidx = TYPE_RECUR_FUNCTION |
||||
end |
||||
return typeidx |
||||
end |
||||
|
||||
function File:referenced(ref) |
||||
-- we use an environment to keep a record of written objects |
||||
if not torch.getenv(self).writeObjects then |
||||
torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) |
||||
end |
||||
local env = torch.getenv(self) |
||||
env.force = not ref |
||||
torch.setenv(self,env) |
||||
return self |
||||
end |
||||
|
||||
function File:isReferenced() |
||||
-- if no environment, then no forcing setup yet |
||||
if not torch.getenv(self).writeObjects then |
||||
return true |
||||
end |
||||
local env = torch.getenv(self) |
||||
return not env.force |
||||
end |
||||
|
||||
local function getmetamethod(obj, name) |
||||
local func |
||||
local status |
||||
|
||||
-- check getmetatable(obj).__name or |
||||
-- check getmetatable(obj).name |
||||
status, func = pcall( |
||||
function() |
||||
-- note that sometimes the metatable is hidden |
||||
-- we get it for sure through the torch type system |
||||
local mt = torch.getmetatable(torch.typename(obj)) |
||||
if mt then |
||||
return mt['__' .. name] or mt[name] |
||||
end |
||||
end |
||||
) |
||||
if status and type(func) == 'function' then |
||||
return func |
||||
end |
||||
end |
||||
|
||||
function File:writeObject(object) |
||||
-- we use an environment to keep a record of written objects |
||||
if not torch.getenv(self).writeObjects then |
||||
torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) |
||||
end |
||||
|
||||
local force = torch.getenv(self).force |
||||
|
||||
-- if nil object, only write the type and return |
||||
if type(object) ~= 'boolean' and not object then |
||||
self:writeInt(TYPE_NIL) |
||||
return |
||||
end |
||||
|
||||
-- check the type we are dealing with |
||||
local typeidx = self:isWritableObject(object) |
||||
if not typeidx then |
||||
error(string.format('Unwritable object <%s>', type(object))) |
||||
end |
||||
self:writeInt(typeidx) |
||||
|
||||
if typeidx == TYPE_NUMBER then |
||||
self:writeDouble(object) |
||||
elseif typeidx == TYPE_BOOLEAN then |
||||
self:writeBool(object) |
||||
elseif typeidx == TYPE_STRING then |
||||
local stringStorage = torch.CharStorage():string(object) |
||||
self:writeInt(#stringStorage) |
||||
self:writeChar(stringStorage) |
||||
elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE or typeidx == TYPE_RECUR_FUNCTION then |
||||
-- check it exists already (we look at the pointer!) |
||||
local objects = torch.getenv(self).writeObjects |
||||
local objectsRef = torch.getenv(self).writeObjectsRef |
||||
local index = objects[torch.pointer(object)] |
||||
|
||||
if index and (not force) then |
||||
-- if already exists, write only its index |
||||
self:writeInt(index) |
||||
else |
||||
-- else write the object itself |
||||
index = objects.nWriteObject or 0 |
||||
index = index + 1 |
||||
objects[torch.pointer(object)] = index |
||||
if not force then |
||||
objectsRef[object] = index -- we make sure the object is not going to disappear |
||||
end |
||||
self:writeInt(index) |
||||
objects.nWriteObject = index |
||||
if typeidx == TYPE_RECUR_FUNCTION then |
||||
local upvalues = {} |
||||
local counter = 0 |
||||
while true do |
||||
counter = counter + 1 |
||||
local name,value = debug.getupvalue(object, counter) |
||||
if not name then break end |
||||
if name == '_ENV' then value = nil end |
||||
table.insert(upvalues, {name=name, value=value}) |
||||
end |
||||
local dumped = string.dump(object) |
||||
local stringStorage = torch.CharStorage():string(dumped) |
||||
self:writeInt(#stringStorage) |
||||
self:writeChar(stringStorage) |
||||
self:writeObject(upvalues) |
||||
elseif typeidx == TYPE_TORCH then |
||||
local version = torch.CharStorage():string('V ' .. torch.version(object)) |
||||
local className = torch.CharStorage():string(torch.typename(object)) |
||||
self:writeInt(#version) |
||||
self:writeChar(version) |
||||
self:writeInt(#className) |
||||
self:writeChar(className) |
||||
local write = getmetamethod(object, 'write') |
||||
if write then |
||||
write(object, self) |
||||
elseif type(object) == 'table' then |
||||
local var = {} |
||||
for k,v in pairs(object) do |
||||
if self:isWritableObject(v) then |
||||
var[k] = v |
||||
else |
||||
print(string.format('$ Warning: cannot write object field <%s>', k)) |
||||
end |
||||
end |
||||
self:writeObject(var) |
||||
else |
||||
error(string.format('<%s> is a non-serializable Torch object', torch.typename(object))) |
||||
end |
||||
else -- it is a table |
||||
local size = 0; for k,v in pairs(object) do size = size + 1 end |
||||
self:writeInt(size) |
||||
for k,v in pairs(object) do |
||||
self:writeObject(k) |
||||
self:writeObject(v) |
||||
end |
||||
end |
||||
end |
||||
else |
||||
error('Unwritable object') |
||||
end |
||||
end |
||||
|
||||
function File:readObject() |
||||
-- we use an environment to keep a record of read objects |
||||
if not torch.getenv(self).writeObjects then |
||||
torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) |
||||
end |
||||
|
||||
local force = torch.getenv(self).force |
||||
|
||||
-- read the typeidx |
||||
local typeidx = self:readInt() |
||||
|
||||
-- is it nil? |
||||
if typeidx == TYPE_NIL then |
||||
return nil |
||||
end |
||||
|
||||
if typeidx == TYPE_NUMBER then |
||||
return self:readDouble() |
||||
elseif typeidx == TYPE_BOOLEAN then |
||||
return self:readBool() |
||||
elseif typeidx == TYPE_STRING then |
||||
local size = self:readInt() |
||||
return self:readChar(size):string() |
||||
elseif typeidx == TYPE_FUNCTION then |
||||
local size = self:readInt() |
||||
local dumped = self:readChar(size):string() |
||||
local func = loadstring(dumped) |
||||
local upvalues = self:readObject() |
||||
for index,upvalue in ipairs(upvalues) do |
||||
debug.setupvalue(func, index, upvalue) |
||||
end |
||||
return func |
||||
elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH or typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then |
||||
-- read the index |
||||
local index = self:readInt() |
||||
|
||||
-- check it is loaded already |
||||
local objects = torch.getenv(self).readObjects |
||||
if objects[index] and not force then |
||||
return objects[index] |
||||
end |
||||
|
||||
-- otherwise read it |
||||
if typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then |
||||
local size = self:readInt() |
||||
local dumped = self:readChar(size):string() |
||||
local func = loadstring(dumped) |
||||
objects[index] = func |
||||
local upvalues = self:readObject() |
||||
for index,upvalue in ipairs(upvalues) do |
||||
if typeidx == LEGACY_TYPE_RECUR_FUNCTION then |
||||
debug.setupvalue(func, index, upvalue) |
||||
elseif upvalue.name == '_ENV' then |
||||
debug.setupvalue(func, index, _ENV) |
||||
else |
||||
debug.setupvalue(func, index, upvalue.value) |
||||
end |
||||
end |
||||
return func |
||||
elseif typeidx == TYPE_TORCH then |
||||
local version, className, versionNumber |
||||
version = self:readChar(self:readInt()):string() |
||||
versionNumber = tonumber(string.match(version, '^V (.*)$')) |
||||
if not versionNumber then |
||||
className = version |
||||
versionNumber = 0 -- file created before existence of versioning system |
||||
else |
||||
className = self:readChar(self:readInt()):string() |
||||
end |
||||
if not torch.factory(className) then |
||||
error(string.format('unknown Torch class <%s>', tostring(className))) |
||||
end |
||||
local object = torch.factory(className)(self) |
||||
objects[index] = object |
||||
local read = getmetamethod(object, 'read') |
||||
if read then |
||||
read(object, self, versionNumber) |
||||
elseif type(object) == 'table' then |
||||
local var = self:readObject() |
||||
for k,v in pairs(var) do |
||||
object[k] = v |
||||
end |
||||
else |
||||
error(string.format('Cannot load object class <%s>', tostring(className))) |
||||
end |
||||
return object |
||||
else -- it is a table |
||||
local size = self:readInt() |
||||
local object = {} |
||||
objects[index] = object |
||||
for i = 1,size do |
||||
local k = self:readObject() |
||||
local v = self:readObject() |
||||
object[k] = v |
||||
end |
||||
return object |
||||
end |
||||
else |
||||
error('unknown object') |
||||
end |
||||
end |
||||
|
||||
-- simple helpers to save/load arbitrary objects/tables |
||||
function torch.save(filename, object, mode) |
||||
mode = mode or 'binary' |
||||
local file = torch.DiskFile(filename, 'w') |
||||
file[mode](file) |
||||
file:writeObject(object) |
||||
file:close() |
||||
end |
||||
|
||||
function torch.load(filename, mode) |
||||
mode = mode or 'binary' |
||||
local file = torch.DiskFile(filename, 'r') |
||||
file[mode](file) |
||||
local object = file:readObject() |
||||
file:close() |
||||
return object |
||||
end |
||||
|
||||
-- simple helpers to serialize/deserialize arbitrary objects/tables |
||||
function torch.serialize(object, mode) |
||||
local storage = torch.serializeToStorage(object, mode) |
||||
return storage:string() |
||||
end |
||||
|
||||
-- Serialize to a CharStorage, not a lua string. This avoids |
||||
function torch.serializeToStorage(object, mode) |
||||
mode = mode or 'binary' |
||||
local f = torch.MemoryFile() |
||||
f = f[mode](f) |
||||
f:writeObject(object) |
||||
local storage = f:storage() |
||||
f:close() |
||||
return storage |
||||
end |
||||
|
||||
function torch.deserializeFromStorage(storage, mode) |
||||
mode = mode or 'binary' |
||||
local tx = torch.CharTensor(storage) |
||||
local xp = torch.CharStorage(tx:size(1)+1) |
||||
local txp = torch.CharTensor(xp) |
||||
txp:narrow(1,1,tx:size(1)):copy(tx) |
||||
txp[tx:size(1)+1] = 0 |
||||
local f = torch.MemoryFile(xp) |
||||
f = f[mode](f) |
||||
local object = f:readObject() |
||||
f:close() |
||||
return object |
||||
end |
||||
|
||||
function torch.deserialize(str, mode) |
||||
local storage = torch.CharStorage():string(str) |
||||
return torch.deserializeFromStorage(storage, mode) |
||||
end |
||||
|
||||
-- public API (saveobj/loadobj are safe for global import) |
||||
torch.saveobj = torch.save |
||||
torch.loadobj = torch.load |
@ -0,0 +1,609 @@ |
||||
#include "THGeneral.h" |
||||
#include "THDiskFile.h" |
||||
#include "THFilePrivate.h" |
||||
|
||||
extern "C" |
||||
{ |
||||
|
||||
typedef struct THDiskFile__ |
||||
{ |
||||
THFile file; |
||||
|
||||
FILE *handle; |
||||
char *name; |
||||
int isNativeEncoding; |
||||
|
||||
} THDiskFile; |
||||
|
||||
static int THDiskFile_isOpened(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)self; |
||||
return (dfself->handle != NULL); |
||||
} |
||||
|
||||
const char *THDiskFile_name(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)self; |
||||
return dfself->name; |
||||
} |
||||
|
||||
/* workaround mac osx lion ***insane*** fread bug */ |
||||
#ifdef __APPLE__ |
||||
size_t fread__(void *ptr, size_t size, size_t nitems, FILE *stream) |
||||
{ |
||||
size_t nread = 0; |
||||
while(!feof(stream) && !ferror(stream) && (nread < nitems)) |
||||
nread += fread((char*)ptr+nread*size, size, THMin(2147483648/size, nitems-nread), stream); |
||||
return nread; |
||||
} |
||||
#else |
||||
#define fread__ fread |
||||
#endif |
||||
|
||||
#define READ_WRITE_METHODS(TYPE, TYPEC, ASCII_READ_ELEM, ASCII_WRITE_ELEM) \ |
||||
static long THDiskFile_read##TYPEC(THFile *self, TYPE *data, long n) \
|
||||
{ \
|
||||
THDiskFile *dfself = (THDiskFile*)(self); \
|
||||
long nread = 0L; \
|
||||
\
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \
|
||||
THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file"); \
|
||||
\
|
||||
if(dfself->file.isBinary) \
|
||||
{ \
|
||||
nread = fread__(data, sizeof(TYPE), n, dfself->handle); \
|
||||
if(!dfself->isNativeEncoding && (sizeof(TYPE) > 1) && (nread > 0)) \
|
||||
THDiskFile_reverseMemory(data, data, sizeof(TYPE), nread); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
long i; \
|
||||
for(i = 0; i < n; i++) \
|
||||
{ \
|
||||
ASCII_READ_ELEM; /* increment here result and break if wrong */ \
|
||||
} \
|
||||
if(dfself->file.isAutoSpacing && (n > 0)) \
|
||||
{ \
|
||||
int c = fgetc(dfself->handle); \
|
||||
if( (c != '\n') && (c != EOF) ) \
|
||||
ungetc(c, dfself->handle); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
if(nread != n) \
|
||||
{ \
|
||||
dfself->file.hasError = 1; /* shouldn't we put hasError to 0 all the time ? */ \
|
||||
if(!dfself->file.isQuiet) \
|
||||
THError("read error: read %d blocks instead of %d", nread, n); \
|
||||
} \
|
||||
\
|
||||
return nread; \
|
||||
} \
|
||||
\
|
||||
static long THDiskFile_write##TYPEC(THFile *self, TYPE *data, long n) \
|
||||
{ \
|
||||
THDiskFile *dfself = (THDiskFile*)(self); \
|
||||
long nwrite = 0L; \
|
||||
\
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \
|
||||
THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file"); \
|
||||
\
|
||||
if(dfself->file.isBinary) \
|
||||
{ \
|
||||
if(dfself->isNativeEncoding) \
|
||||
{ \
|
||||
nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(sizeof(TYPE) > 1) \
|
||||
{ \
|
||||
char *buffer = (char*)THAlloc(sizeof(TYPE)*n); \
|
||||
THDiskFile_reverseMemory(buffer, data, sizeof(TYPE), n); \
|
||||
nwrite = fwrite(buffer, sizeof(TYPE), n, dfself->handle); \
|
||||
THFree(buffer); \
|
||||
} \
|
||||
else \
|
||||
nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle); \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
long i; \
|
||||
for(i = 0; i < n; i++) \
|
||||
{ \
|
||||
ASCII_WRITE_ELEM; \
|
||||
if( dfself->file.isAutoSpacing && (i < n-1) ) \
|
||||
fprintf(dfself->handle, " "); \
|
||||
} \
|
||||
if(dfself->file.isAutoSpacing && (n > 0)) \
|
||||
fprintf(dfself->handle, "\n"); \
|
||||
} \
|
||||
\
|
||||
if(nwrite != n) \
|
||||
{ \
|
||||
dfself->file.hasError = 1; \
|
||||
if(!dfself->file.isQuiet) \
|
||||
THError("write error: wrote %d blocks instead of %d", nwrite, n); \
|
||||
} \
|
||||
\
|
||||
return nwrite; \
|
||||
} |
||||
|
||||
static int THDiskFile_mode(const char *mode, int *isReadable, int *isWritable) |
||||
{ |
||||
*isReadable = 0; |
||||
*isWritable = 0; |
||||
if(strlen(mode) == 1) |
||||
{ |
||||
if(*mode == 'r') |
||||
{ |
||||
*isReadable = 1; |
||||
return 1; |
||||
} |
||||
else if(*mode == 'w') |
||||
{ |
||||
*isWritable = 1; |
||||
return 1; |
||||
} |
||||
} |
||||
else if(strlen(mode) == 2) |
||||
{ |
||||
if(mode[0] == 'r' && mode[1] == 'w') |
||||
{ |
||||
*isReadable = 1; |
||||
*isWritable = 1; |
||||
return 1; |
||||
} |
||||
} |
||||
return 0; |
||||
} |
||||
|
||||
static void THDiskFile_synchronize(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); |
||||
fflush(dfself->handle); |
||||
} |
||||
|
||||
static void THDiskFile_seek(THFile *self, long position) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); |
||||
THArgCheck(position >= 0, 2, "position must be positive"); |
||||
|
||||
if(fseek(dfself->handle, position, SEEK_SET) < 0) |
||||
{ |
||||
dfself->file.hasError = 1; |
||||
if(!dfself->file.isQuiet) |
||||
THError("unable to seek at position %d", position); |
||||
} |
||||
} |
||||
|
||||
static void THDiskFile_seekEnd(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); |
||||
|
||||
if(fseek(dfself->handle, 0L, SEEK_END) < 0) |
||||
{ |
||||
dfself->file.hasError = 1; |
||||
if(!dfself->file.isQuiet) |
||||
THError("unable to seek at end of file"); |
||||
} |
||||
} |
||||
|
||||
static long THDiskFile_position(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); |
||||
return ftell(dfself->handle); |
||||
} |
||||
|
||||
static void THDiskFile_close(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); |
||||
fclose(dfself->handle); |
||||
dfself->handle = NULL; |
||||
} |
||||
|
||||
/* Little and Big Endian */ |
||||
|
||||
static void THDiskFile_reverseMemory(void *dst, const void *src, long blockSize, long numBlocks) |
||||
{ |
||||
if(blockSize != 1) |
||||
{ |
||||
long halfBlockSize = blockSize/2; |
||||
char *charSrc = (char*)src; |
||||
char *charDst = (char*)dst; |
||||
long b, i; |
||||
for(b = 0; b < numBlocks; b++) |
||||
{ |
||||
for(i = 0; i < halfBlockSize; i++) |
||||
{ |
||||
char z = charSrc[i]; |
||||
charDst[i] = charSrc[blockSize-1-i]; |
||||
charDst[blockSize-1-i] = z; |
||||
} |
||||
charSrc += blockSize; |
||||
charDst += blockSize; |
||||
} |
||||
} |
||||
} |
||||
|
||||
int THDiskFile_isLittleEndianCPU(void) |
||||
{ |
||||
int x = 7; |
||||
char *ptr = (char *)&x; |
||||
|
||||
if(ptr[0] == 0) |
||||
return 0; |
||||
else |
||||
return 1; |
||||
} |
||||
|
||||
int THDiskFile_isBigEndianCPU(void) |
||||
{ |
||||
return(!THDiskFile_isLittleEndianCPU()); |
||||
} |
||||
|
||||
void THDiskFile_nativeEndianEncoding(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); |
||||
dfself->isNativeEncoding = 1; |
||||
} |
||||
|
||||
void THDiskFile_littleEndianEncoding(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); |
||||
dfself->isNativeEncoding = THDiskFile_isLittleEndianCPU(); |
||||
} |
||||
|
||||
void THDiskFile_bigEndianEncoding(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); |
||||
dfself->isNativeEncoding = !THDiskFile_isLittleEndianCPU(); |
||||
} |
||||
|
||||
/* End of Little and Big Endian Stuff */ |
||||
|
||||
static void THDiskFile_free(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
if(dfself->handle) |
||||
fclose(dfself->handle); |
||||
THFree(dfself->name); |
||||
THFree(dfself); |
||||
} |
||||
|
||||
/* READ_WRITE_METHODS(int, Bool, */ |
||||
/* int value = 0; int ret = fscanf(file->handle, "%d", &value); array[i] = (value ? 1 : 0); if(ret <= 0) break; else result++, */ |
||||
/* int value = (array[i] ? 1 : 0); nElemWritten = fprintf(file->handle, "%d", value), */ |
||||
/* true) */ |
||||
|
||||
/* Note that we do a trick */ |
||||
READ_WRITE_METHODS(unsigned char, Byte, |
||||
nread = fread(data, 1, n, dfself->handle); break, |
||||
nwrite = fwrite(data, 1, n, dfself->handle); break) |
||||
|
||||
READ_WRITE_METHODS(char, Char, |
||||
nread = fread(data, 1, n, dfself->handle); break, |
||||
nwrite = fwrite(data, 1, n, dfself->handle); break) |
||||
|
||||
READ_WRITE_METHODS(short, Short, |
||||
int ret = fscanf(dfself->handle, "%hd", &data[i]); if(ret <= 0) break; else nread++, |
||||
int ret = fprintf(dfself->handle, "%hd", data[i]); if(ret <= 0) break; else nwrite++) |
||||
|
||||
READ_WRITE_METHODS(int, Int, |
||||
int ret = fscanf(dfself->handle, "%d", &data[i]); if(ret <= 0) break; else nread++, |
||||
int ret = fprintf(dfself->handle, "%d", data[i]); if(ret <= 0) break; else nwrite++) |
||||
|
||||
READ_WRITE_METHODS(long, Long, |
||||
int ret = fscanf(dfself->handle, "%ld", &data[i]); if(ret <= 0) break; else nread++, |
||||
int ret = fprintf(dfself->handle, "%ld", data[i]); if(ret <= 0) break; else nwrite++) |
||||
|
||||
READ_WRITE_METHODS(float, Float, |
||||
int ret = fscanf(dfself->handle, "%g", &data[i]); if(ret <= 0) break; else nread++, |
||||
int ret = fprintf(dfself->handle, "%.9g", data[i]); if(ret <= 0) break; else nwrite++) |
||||
|
||||
READ_WRITE_METHODS(double, Double, |
||||
int ret = fscanf(dfself->handle, "%lg", &data[i]); if(ret <= 0) break; else nread++, |
||||
int ret = fprintf(dfself->handle, "%.17g", data[i]); if(ret <= 0) break; else nwrite++) |
||||
|
||||
static long THDiskFile_readString(THFile *self, const char *format, char **str_) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); |
||||
THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file"); |
||||
THArgCheck((strlen(format) >= 2 ? (format[0] == '*') && (format[1] == 'a' || format[1] == 'l') : 0), 2, "format must be '*a' or '*l'"); |
||||
|
||||
/* note: the string won't survive long, as it is copied into lua */ |
||||
/* so 1024 is not that big... */ |
||||
#define TBRS_BSZ 1024L |
||||
|
||||
if(format[1] == 'a') |
||||
{ |
||||
char *p = (char*)THAlloc(TBRS_BSZ); |
||||
long total = TBRS_BSZ; |
||||
long pos = 0L; |
||||
|
||||
for (;;) |
||||
{ |
||||
if(total-pos == 0) /* we need more space! */ |
||||
{ |
||||
total += TBRS_BSZ; |
||||
p = (char*)THRealloc(p, total); |
||||
} |
||||
pos += fread(p+pos, 1, total-pos, dfself->handle); |
||||
if (pos < total) /* eof? */ |
||||
{ |
||||
if(pos == 0L) |
||||
{ |
||||
THFree(p); |
||||
dfself->file.hasError = 1; |
||||
if(!dfself->file.isQuiet) |
||||
THError("read error: read 0 blocks instead of 1"); |
||||
|
||||
*str_ = NULL; |
||||
return 0; |
||||
} |
||||
*str_ = p; |
||||
return pos; |
||||
} |
||||
}
|
||||
} |
||||
else |
||||
{ |
||||
char *p = (char*)THAlloc(TBRS_BSZ); |
||||
long total = TBRS_BSZ; |
||||
long pos = 0L; |
||||
long size; |
||||
|
||||
for (;;) |
||||
{ |
||||
if(total-pos <= 1) /* we can only write '\0' in there! */ |
||||
{ |
||||
total += TBRS_BSZ; |
||||
p = (char*)THRealloc(p, total); |
||||
} |
||||
if (fgets(p+pos, total-pos, dfself->handle) == NULL) /* eof? */ |
||||
{ |
||||
if(pos == 0L) |
||||
{ |
||||
THFree(p); |
||||
dfself->file.hasError = 1; |
||||
if(!dfself->file.isQuiet) |
||||
THError("read error: read 0 blocks instead of 1"); |
||||
|
||||
*str_ = NULL; |
||||
return 0; |
||||
} |
||||
*str_ = p; |
||||
return pos; |
||||
} |
||||
size = strlen(p+pos); |
||||
if (size == 0L || (p+pos)[size-1] != '\n') |
||||
{ |
||||
pos += size; |
||||
} |
||||
else |
||||
{ |
||||
pos += size-1L; /* do not include `eol' */ |
||||
*str_ = p; |
||||
return pos; |
||||
} |
||||
} |
||||
} |
||||
|
||||
*str_ = NULL; |
||||
return 0; |
||||
} |
||||
|
||||
|
||||
static long THDiskFile_writeString(THFile *self, const char *str, long size) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
long nwrite; |
||||
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); |
||||
THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file"); |
||||
|
||||
nwrite = fwrite(str, 1, size, dfself->handle); |
||||
if(nwrite != size) |
||||
{ |
||||
dfself->file.hasError = 1; |
||||
if(!dfself->file.isQuiet) |
||||
THError("write error: wrote %ld blocks instead of %ld", nwrite, size); |
||||
} |
||||
|
||||
return nwrite; |
||||
} |
||||
|
||||
THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet) |
||||
{ |
||||
static struct THFileVTable vtable = { |
||||
THDiskFile_isOpened, |
||||
|
||||
THDiskFile_readByte, |
||||
THDiskFile_readChar, |
||||
THDiskFile_readShort, |
||||
THDiskFile_readInt, |
||||
THDiskFile_readLong, |
||||
THDiskFile_readFloat, |
||||
THDiskFile_readDouble, |
||||
THDiskFile_readString, |
||||
|
||||
THDiskFile_writeByte, |
||||
THDiskFile_writeChar, |
||||
THDiskFile_writeShort, |
||||
THDiskFile_writeInt, |
||||
THDiskFile_writeLong, |
||||
THDiskFile_writeFloat, |
||||
THDiskFile_writeDouble, |
||||
THDiskFile_writeString, |
||||
|
||||
THDiskFile_synchronize, |
||||
THDiskFile_seek, |
||||
THDiskFile_seekEnd, |
||||
THDiskFile_position, |
||||
THDiskFile_close, |
||||
THDiskFile_free |
||||
}; |
||||
|
||||
int isReadable; |
||||
int isWritable; |
||||
FILE *handle; |
||||
THDiskFile *self; |
||||
|
||||
THArgCheck(THDiskFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'"); |
||||
|
||||
if( isReadable && isWritable ) |
||||
{ |
||||
handle = fopen(name, "r+b"); |
||||
if(!handle) |
||||
{ |
||||
handle = fopen(name, "wb"); |
||||
if(handle) |
||||
{ |
||||
fclose(handle); |
||||
handle = fopen(name, "r+b"); |
||||
} |
||||
} |
||||
} |
||||
else |
||||
handle = fopen(name, (isReadable ? "rb" : "wb")); |
||||
|
||||
if(!handle) |
||||
{ |
||||
if(isQuiet) |
||||
return 0; |
||||
else |
||||
THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' ')); |
||||
} |
||||
|
||||
self = (THDiskFile*)THAlloc(sizeof(THDiskFile)); |
||||
|
||||
self->handle = handle; |
||||
self->name = (char*)THAlloc(strlen(name)+1); |
||||
strcpy(self->name, name); |
||||
self->isNativeEncoding = 1; |
||||
|
||||
self->file.vtable = &vtable; |
||||
self->file.isQuiet = isQuiet; |
||||
self->file.isReadable = isReadable; |
||||
self->file.isWritable = isWritable; |
||||
self->file.isBinary = 0; |
||||
self->file.isAutoSpacing = 1; |
||||
self->file.hasError = 0; |
||||
|
||||
return (THFile*)self; |
||||
} |
||||
|
||||
/* PipeFile */ |
||||
|
||||
static int THPipeFile_mode(const char *mode, int *isReadable, int *isWritable) |
||||
{ |
||||
*isReadable = 0; |
||||
*isWritable = 0; |
||||
if(strlen(mode) == 1) |
||||
{ |
||||
if(*mode == 'r') |
||||
{ |
||||
*isReadable = 1; |
||||
return 1; |
||||
} |
||||
else if(*mode == 'w') |
||||
{ |
||||
*isWritable = 1; |
||||
return 1; |
||||
} |
||||
} |
||||
return 0; |
||||
} |
||||
|
||||
static void THPipeFile_free(THFile *self) |
||||
{ |
||||
THDiskFile *dfself = (THDiskFile*)(self); |
||||
if(dfself->handle) |
||||
pclose(dfself->handle); |
||||
THFree(dfself->name); |
||||
THFree(dfself); |
||||
} |
||||
|
||||
THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet) |
||||
{ |
||||
static struct THFileVTable vtable = { |
||||
THDiskFile_isOpened, |
||||
|
||||
THDiskFile_readByte, |
||||
THDiskFile_readChar, |
||||
THDiskFile_readShort, |
||||
THDiskFile_readInt, |
||||
THDiskFile_readLong, |
||||
THDiskFile_readFloat, |
||||
THDiskFile_readDouble, |
||||
THDiskFile_readString, |
||||
|
||||
THDiskFile_writeByte, |
||||
THDiskFile_writeChar, |
||||
THDiskFile_writeShort, |
||||
THDiskFile_writeInt, |
||||
THDiskFile_writeLong, |
||||
THDiskFile_writeFloat, |
||||
THDiskFile_writeDouble, |
||||
THDiskFile_writeString, |
||||
|
||||
THDiskFile_synchronize, |
||||
THDiskFile_seek, |
||||
THDiskFile_seekEnd, |
||||
THDiskFile_position, |
||||
THDiskFile_close, |
||||
THPipeFile_free |
||||
}; |
||||
|
||||
int isReadable; |
||||
int isWritable; |
||||
FILE *handle; |
||||
THDiskFile *self; |
||||
|
||||
THArgCheck(THPipeFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w'"); |
||||
|
||||
#ifdef _WIN32 |
||||
handle = popen(name, (isReadable ? "rb" : "wb")); |
||||
#else |
||||
handle = popen(name, (isReadable ? "r" : "w")); |
||||
#endif |
||||
|
||||
if(!handle) |
||||
{ |
||||
if(isQuiet) |
||||
return 0; |
||||
else |
||||
THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' ')); |
||||
} |
||||
|
||||
self = (THDiskFile*)THAlloc(sizeof(THDiskFile)); |
||||
|
||||
self->handle = handle; |
||||
self->name = (char*)THAlloc(strlen(name)+1); |
||||
strcpy(self->name, name); |
||||
self->isNativeEncoding = 1; |
||||
|
||||
self->file.vtable = &vtable; |
||||
self->file.isQuiet = isQuiet; |
||||
self->file.isReadable = isReadable; |
||||
self->file.isWritable = isWritable; |
||||
self->file.isBinary = 0; |
||||
self->file.isAutoSpacing = 1; |
||||
self->file.hasError = 0; |
||||
|
||||
return (THFile*)self; |
||||
} |
||||
|
||||
} |
@ -0,0 +1,17 @@ |
||||
#ifndef TH_DISK_FILE_INC |
||||
#define TH_DISK_FILE_INC |
||||
|
||||
#include "THFile.h" |
||||
|
||||
TH_API THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet); |
||||
TH_API THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet); |
||||
|
||||
TH_API const char *THDiskFile_name(THFile *self); |
||||
|
||||
TH_API int THDiskFile_isLittleEndianCPU(void); |
||||
TH_API int THDiskFile_isBigEndianCPU(void); |
||||
TH_API void THDiskFile_nativeEndianEncoding(THFile *self); |
||||
TH_API void THDiskFile_littleEndianEncoding(THFile *self); |
||||
TH_API void THDiskFile_bigEndianEncoding(THFile *self); |
||||
|
||||
#endif |
@ -0,0 +1,161 @@ |
||||
#include "THFile.h" |
||||
#include "THFilePrivate.h" |
||||
|
||||
extern "C" |
||||
{ |
||||
|
||||
#define IMPLEMENT_THFILE_RW(TYPEC, TYPE) \ |
||||
long THFile_read##TYPEC##Raw(THFile *self, TYPE *data, long n) \
|
||||
{ \
|
||||
return (*self->vtable->read##TYPEC)(self, data, n); \
|
||||
} \
|
||||
\
|
||||
long THFile_write##TYPEC##Raw(THFile *self, TYPE *data, long n) \
|
||||
{ \
|
||||
return (*self->vtable->write##TYPEC)(self, data, n); \
|
||||
} |
||||
|
||||
IMPLEMENT_THFILE_RW(Byte, unsigned char) |
||||
IMPLEMENT_THFILE_RW(Char, char) |
||||
IMPLEMENT_THFILE_RW(Short, short) |
||||
IMPLEMENT_THFILE_RW(Int, int) |
||||
IMPLEMENT_THFILE_RW(Long, long) |
||||
IMPLEMENT_THFILE_RW(Float, float) |
||||
IMPLEMENT_THFILE_RW(Double, double) |
||||
|
||||
long THFile_readStringRaw(THFile *self, const char *format, char **str_) |
||||
{ |
||||
return self->vtable->readString(self, format, str_); |
||||
} |
||||
|
||||
long THFile_writeStringRaw(THFile *self, const char *str, long size) |
||||
{ |
||||
return self->vtable->writeString(self, str, size); |
||||
} |
||||
|
||||
void THFile_synchronize(THFile *self) |
||||
{ |
||||
self->vtable->synchronize(self); |
||||
} |
||||
|
||||
void THFile_seek(THFile *self, long position) |
||||
{ |
||||
self->vtable->seek(self, position); |
||||
} |
||||
|
||||
void THFile_seekEnd(THFile *self) |
||||
{ |
||||
self->vtable->seekEnd(self); |
||||
} |
||||
|
||||
long THFile_position(THFile *self) |
||||
{ |
||||
return self->vtable->position(self); |
||||
} |
||||
|
||||
void THFile_close(THFile *self) |
||||
{ |
||||
self->vtable->close(self); |
||||
} |
||||
|
||||
void THFile_free(THFile *self) |
||||
{ |
||||
self->vtable->free(self); |
||||
} |
||||
|
||||
int THFile_isOpened(THFile *self) |
||||
{ |
||||
return self->vtable->isOpened(self); |
||||
} |
||||
|
||||
#define IMPLEMENT_THFILE_FLAGS(FLAG) \ |
||||
int THFile_##FLAG(THFile *self) \
|
||||
{ \
|
||||
return self->FLAG; \
|
||||
} |
||||
|
||||
IMPLEMENT_THFILE_FLAGS(isQuiet) |
||||
IMPLEMENT_THFILE_FLAGS(isReadable) |
||||
IMPLEMENT_THFILE_FLAGS(isWritable) |
||||
IMPLEMENT_THFILE_FLAGS(isBinary) |
||||
IMPLEMENT_THFILE_FLAGS(isAutoSpacing) |
||||
IMPLEMENT_THFILE_FLAGS(hasError) |
||||
|
||||
void THFile_binary(THFile *self) |
||||
{ |
||||
self->isBinary = 1; |
||||
} |
||||
|
||||
void THFile_ascii(THFile *self) |
||||
{ |
||||
self->isBinary = 0; |
||||
} |
||||
|
||||
void THFile_autoSpacing(THFile *self) |
||||
{ |
||||
self->isAutoSpacing = 1; |
||||
} |
||||
|
||||
void THFile_noAutoSpacing(THFile *self) |
||||
{ |
||||
self->isAutoSpacing = 0; |
||||
} |
||||
|
||||
void THFile_quiet(THFile *self) |
||||
{ |
||||
self->isQuiet = 1; |
||||
} |
||||
|
||||
void THFile_pedantic(THFile *self) |
||||
{ |
||||
self->isQuiet = 0; |
||||
} |
||||
|
||||
void THFile_clearError(THFile *self) |
||||
{ |
||||
self->hasError = 0; |
||||
} |
||||
|
||||
#define IMPLEMENT_THFILE_SCALAR(TYPEC, TYPE) \ |
||||
TYPE THFile_read##TYPEC##Scalar(THFile *self) \
|
||||
{ \
|
||||
TYPE scalar; \
|
||||
THFile_read##TYPEC##Raw(self, &scalar, 1); \
|
||||
return scalar; \
|
||||
} \
|
||||
\
|
||||
void THFile_write##TYPEC##Scalar(THFile *self, TYPE scalar) \
|
||||
{ \
|
||||
THFile_write##TYPEC##Raw(self, &scalar, 1); \
|
||||
} |
||||
|
||||
IMPLEMENT_THFILE_SCALAR(Byte, unsigned char) |
||||
IMPLEMENT_THFILE_SCALAR(Char, char) |
||||
IMPLEMENT_THFILE_SCALAR(Short, short) |
||||
IMPLEMENT_THFILE_SCALAR(Int, int) |
||||
IMPLEMENT_THFILE_SCALAR(Long, long) |
||||
IMPLEMENT_THFILE_SCALAR(Float, float) |
||||
IMPLEMENT_THFILE_SCALAR(Double, double) |
||||
|
||||
/*
|
||||
#define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \ |
||||
long THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \
|
||||
{ \
|
||||
return THFile_read##TYPEC##Raw(self, storage->data, storage->size); \
|
||||
} \
|
||||
\
|
||||
long THFile_write##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \
|
||||
{ \
|
||||
return THFile_write##TYPEC##Raw(self, storage->data, storage->size); \
|
||||
} |
||||
|
||||
IMPLEMENT_THFILE_STORAGE(Byte, unsigned char) |
||||
IMPLEMENT_THFILE_STORAGE(Char, char) |
||||
IMPLEMENT_THFILE_STORAGE(Short, short) |
||||
IMPLEMENT_THFILE_STORAGE(Int, int) |
||||
IMPLEMENT_THFILE_STORAGE(Long, long) |
||||
IMPLEMENT_THFILE_STORAGE(Float, float) |
||||
IMPLEMENT_THFILE_STORAGE(Double, double) |
||||
*/ |
||||
|
||||
} |
@ -0,0 +1,87 @@ |
||||
#ifndef TH_FILE_INC |
||||
#define TH_FILE_INC |
||||
|
||||
//#include "THStorage.h"
|
||||
#include "THGeneral.h" |
||||
|
||||
typedef struct THFile__ THFile; |
||||
|
||||
TH_API int THFile_isOpened(THFile *self); |
||||
TH_API int THFile_isQuiet(THFile *self); |
||||
TH_API int THFile_isReadable(THFile *self); |
||||
TH_API int THFile_isWritable(THFile *self); |
||||
TH_API int THFile_isBinary(THFile *self); |
||||
TH_API int THFile_isAutoSpacing(THFile *self); |
||||
TH_API int THFile_hasError(THFile *self); |
||||
|
||||
TH_API void THFile_binary(THFile *self); |
||||
TH_API void THFile_ascii(THFile *self); |
||||
TH_API void THFile_autoSpacing(THFile *self); |
||||
TH_API void THFile_noAutoSpacing(THFile *self); |
||||
TH_API void THFile_quiet(THFile *self); |
||||
TH_API void THFile_pedantic(THFile *self); |
||||
TH_API void THFile_clearError(THFile *self); |
||||
|
||||
/* scalar */ |
||||
TH_API unsigned char THFile_readByteScalar(THFile *self); |
||||
TH_API char THFile_readCharScalar(THFile *self); |
||||
TH_API short THFile_readShortScalar(THFile *self); |
||||
TH_API int THFile_readIntScalar(THFile *self); |
||||
TH_API long THFile_readLongScalar(THFile *self); |
||||
TH_API float THFile_readFloatScalar(THFile *self); |
||||
TH_API double THFile_readDoubleScalar(THFile *self); |
||||
|
||||
TH_API void THFile_writeByteScalar(THFile *self, unsigned char scalar); |
||||
TH_API void THFile_writeCharScalar(THFile *self, char scalar); |
||||
TH_API void THFile_writeShortScalar(THFile *self, short scalar); |
||||
TH_API void THFile_writeIntScalar(THFile *self, int scalar); |
||||
TH_API void THFile_writeLongScalar(THFile *self, long scalar); |
||||
TH_API void THFile_writeFloatScalar(THFile *self, float scalar); |
||||
TH_API void THFile_writeDoubleScalar(THFile *self, double scalar); |
||||
|
||||
/* storage */ |
||||
/*
|
||||
TH_API long THFile_readByte(THFile *self, THByteStorage *storage); |
||||
TH_API long THFile_readChar(THFile *self, THCharStorage *storage); |
||||
TH_API long THFile_readShort(THFile *self, THShortStorage *storage); |
||||
TH_API long THFile_readInt(THFile *self, THIntStorage *storage); |
||||
TH_API long THFile_readLong(THFile *self, THLongStorage *storage); |
||||
TH_API long THFile_readFloat(THFile *self, THFloatStorage *storage); |
||||
TH_API long THFile_readDouble(THFile *self, THDoubleStorage *storage); |
||||
|
||||
TH_API long THFile_writeByte(THFile *self, THByteStorage *storage); |
||||
TH_API long THFile_writeChar(THFile *self, THCharStorage *storage); |
||||
TH_API long THFile_writeShort(THFile *self, THShortStorage *storage); |
||||
TH_API long THFile_writeInt(THFile *self, THIntStorage *storage); |
||||
TH_API long THFile_writeLong(THFile *self, THLongStorage *storage); |
||||
TH_API long THFile_writeFloat(THFile *self, THFloatStorage *storage); |
||||
TH_API long THFile_writeDouble(THFile *self, THDoubleStorage *storage); |
||||
*/ |
||||
|
||||
/* raw */ |
||||
TH_API long THFile_readByteRaw(THFile *self, unsigned char *data, long n); |
||||
TH_API long THFile_readCharRaw(THFile *self, char *data, long n); |
||||
TH_API long THFile_readShortRaw(THFile *self, short *data, long n); |
||||
TH_API long THFile_readIntRaw(THFile *self, int *data, long n); |
||||
TH_API long THFile_readLongRaw(THFile *self, long *data, long n); |
||||
TH_API long THFile_readFloatRaw(THFile *self, float *data, long n); |
||||
TH_API long THFile_readDoubleRaw(THFile *self, double *data, long n); |
||||
TH_API long THFile_readStringRaw(THFile *self, const char *format, char **str_); /* you must deallocate str_ */ |
||||
|
||||
TH_API long THFile_writeByteRaw(THFile *self, unsigned char *data, long n); |
||||
TH_API long THFile_writeCharRaw(THFile *self, char *data, long n); |
||||
TH_API long THFile_writeShortRaw(THFile *self, short *data, long n); |
||||
TH_API long THFile_writeIntRaw(THFile *self, int *data, long n); |
||||
TH_API long THFile_writeLongRaw(THFile *self, long *data, long n); |
||||
TH_API long THFile_writeFloatRaw(THFile *self, float *data, long n); |
||||
TH_API long THFile_writeDoubleRaw(THFile *self, double *data, long n); |
||||
TH_API long THFile_writeStringRaw(THFile *self, const char *str, long size); |
||||
|
||||
TH_API void THFile_synchronize(THFile *self); |
||||
TH_API void THFile_seek(THFile *self, long position); |
||||
TH_API void THFile_seekEnd(THFile *self); |
||||
TH_API long THFile_position(THFile *self); |
||||
TH_API void THFile_close(THFile *self); |
||||
TH_API void THFile_free(THFile *self); |
||||
|
||||
#endif |
@ -0,0 +1,43 @@ |
||||
struct THFile__ |
||||
{ |
||||
struct THFileVTable *vtable; |
||||
|
||||
int isQuiet; |
||||
int isReadable; |
||||
int isWritable; |
||||
int isBinary; |
||||
int isAutoSpacing; |
||||
int hasError; |
||||
}; |
||||
|
||||
/* virtual table definition */ |
||||
|
||||
struct THFileVTable |
||||
{ |
||||
int (*isOpened)(THFile *self); |
||||
|
||||
long (*readByte)(THFile *self, unsigned char *data, long n); |
||||
long (*readChar)(THFile *self, char *data, long n); |
||||
long (*readShort)(THFile *self, short *data, long n); |
||||
long (*readInt)(THFile *self, int *data, long n); |
||||
long (*readLong)(THFile *self, long *data, long n); |
||||
long (*readFloat)(THFile *self, float *data, long n); |
||||
long (*readDouble)(THFile *self, double *data, long n); |
||||
long (*readString)(THFile *self, const char *format, char **str_); |
||||
|
||||
long (*writeByte)(THFile *self, unsigned char *data, long n); |
||||
long (*writeChar)(THFile *self, char *data, long n); |
||||
long (*writeShort)(THFile *self, short *data, long n); |
||||
long (*writeInt)(THFile *self, int *data, long n); |
||||
long (*writeLong)(THFile *self, long *data, long n); |
||||
long (*writeFloat)(THFile *self, float *data, long n); |
||||
long (*writeDouble)(THFile *self, double *data, long n); |
||||
long (*writeString)(THFile *self, const char *str, long size); |
||||
|
||||
void (*synchronize)(THFile *self); |
||||
void (*seek)(THFile *self, long position); |
||||
void (*seekEnd)(THFile *self); |
||||
long (*position)(THFile *self); |
||||
void (*close)(THFile *self); |
||||
void (*free)(THFile *self); |
||||
}; |
@ -0,0 +1,254 @@ |
||||
#include "THGeneral.h" |
||||
|
||||
extern "C" |
||||
{ |
||||
|
||||
#ifndef TH_HAVE_THREAD |
||||
#define __thread |
||||
#endif |
||||
|
||||
#if defined(TH_DISABLE_HEAP_TRACKING) |
||||
#elif (defined(__unix) || defined(_WIN32)) |
||||
#include <malloc.h> |
||||
#elif defined(__APPLE__) |
||||
#include <malloc/malloc.h> |
||||
#endif |
||||
|
||||
/* Torch Error Handling */ |
||||
static void defaultTorchErrorHandlerFunction(const char *msg, void *data) |
||||
{ |
||||
printf("$ Error: %s\n", msg); |
||||
exit(-1); |
||||
} |
||||
|
||||
static __thread void (*torchErrorHandlerFunction)(const char *msg, void *data) = defaultTorchErrorHandlerFunction; |
||||
static __thread void *torchErrorHandlerData; |
||||
|
||||
void _THError(const char *file, const int line, const char *fmt, ...) |
||||
{ |
||||
char msg[2048]; |
||||
va_list args; |
||||
|
||||
/* vasprintf not standard */ |
||||
/* vsnprintf: how to handle if does not exists? */ |
||||
va_start(args, fmt); |
||||
int n = vsnprintf(msg, 2048, fmt, args); |
||||
va_end(args); |
||||
|
||||
if(n < 2048) { |
||||
snprintf(msg + n, 2048 - n, " at %s:%d", file, line); |
||||
} |
||||
|
||||
(*torchErrorHandlerFunction)(msg, torchErrorHandlerData); |
||||
} |
||||
|
||||
void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...) { |
||||
char msg[1024]; |
||||
va_list args; |
||||
va_start(args, fmt); |
||||
vsnprintf(msg, 1024, fmt, args); |
||||
va_end(args); |
||||
_THError(file, line, "Assertion `%s' failed. %s", exp, msg); |
||||
} |
||||
|
||||
void THSetErrorHandler( void (*torchErrorHandlerFunction_)(const char *msg, void *data), void *data ) |
||||
{ |
||||
if(torchErrorHandlerFunction_) |
||||
torchErrorHandlerFunction = torchErrorHandlerFunction_; |
||||
else |
||||
torchErrorHandlerFunction = defaultTorchErrorHandlerFunction; |
||||
torchErrorHandlerData = data; |
||||
} |
||||
|
||||
/* Torch Arg Checking Handling */ |
||||
static void defaultTorchArgErrorHandlerFunction(int argNumber, const char *msg, void *data) |
||||
{ |
||||
if(msg) |
||||
printf("$ Invalid argument %d: %s\n", argNumber, msg); |
||||
else |
||||
printf("$ Invalid argument %d\n", argNumber); |
||||
exit(-1); |
||||
} |
||||
|
||||
static __thread void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data) = defaultTorchArgErrorHandlerFunction; |
||||
static __thread void *torchArgErrorHandlerData; |
||||
|
||||
void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...) |
||||
{ |
||||
if(!condition) { |
||||
char msg[2048]; |
||||
va_list args; |
||||
|
||||
/* vasprintf not standard */ |
||||
/* vsnprintf: how to handle if does not exists? */ |
||||
va_start(args, fmt); |
||||
int n = vsnprintf(msg, 2048, fmt, args); |
||||
va_end(args); |
||||
|
||||
if(n < 2048) { |
||||
snprintf(msg + n, 2048 - n, " at %s:%d", file, line); |
||||
} |
||||
|
||||
(*torchArgErrorHandlerFunction)(argNumber, msg, torchArgErrorHandlerData); |
||||
} |
||||
} |
||||
|
||||
void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction_)(int argNumber, const char *msg, void *data), void *data ) |
||||
{ |
||||
if(torchArgErrorHandlerFunction_) |
||||
torchArgErrorHandlerFunction = torchArgErrorHandlerFunction_; |
||||
else |
||||
torchArgErrorHandlerFunction = defaultTorchArgErrorHandlerFunction; |
||||
torchArgErrorHandlerData = data; |
||||
} |
||||
|
||||
static __thread void (*torchGCFunction)(void *data) = NULL; |
||||
static __thread void *torchGCData; |
||||
static __thread long torchHeapSize = 0; |
||||
static __thread long torchHeapSizeSoftMax = 300000000; // 300MB, adjusted upward dynamically
|
||||
|
||||
/* Optional hook for integrating with a garbage-collected frontend.
|
||||
* |
||||
* If torch is running with a garbage-collected frontend (e.g. Lua), |
||||
* the GC isn't aware of TH-allocated memory so may not know when it |
||||
* needs to run. These hooks trigger the GC to run in two cases: |
||||
* |
||||
* (1) When a memory allocation (malloc, realloc, ...) fails |
||||
* (2) When the total TH-allocated memory hits a dynamically-adjusted |
||||
* soft maximum. |
||||
*/ |
||||
void THSetGCHandler( void (*torchGCFunction_)(void *data), void *data ) |
||||
{ |
||||
torchGCFunction = torchGCFunction_; |
||||
torchGCData = data; |
||||
} |
||||
|
||||
static long getAllocSize(void *ptr) { |
||||
#if defined(TH_DISABLE_HEAP_TRACKING) |
||||
return 0; |
||||
#elif defined(__unix) |
||||
return malloc_usable_size(ptr); |
||||
#elif defined(__APPLE__) |
||||
return malloc_size(ptr); |
||||
#elif defined(_WIN32) |
||||
return _msize(ptr); |
||||
#else |
||||
return 0; |
||||
#endif |
||||
} |
||||
|
||||
/* (1) if the torch-allocated heap size exceeds the soft max, run GC
|
||||
* (2) if post-GC heap size exceeds 80% of the soft max, increase the |
||||
* soft max by 40% |
||||
*/ |
||||
static void maybeTriggerGC() { |
||||
if(torchGCFunction && torchHeapSize > torchHeapSizeSoftMax) { |
||||
torchGCFunction(torchGCData); |
||||
if(torchHeapSize > torchHeapSizeSoftMax * 0.8) { |
||||
torchHeapSizeSoftMax = torchHeapSizeSoftMax * 1.4; |
||||
} |
||||
} |
||||
} |
||||
|
||||
// hooks into the TH heap tracking
|
||||
void THHeapUpdate(long size) { |
||||
torchHeapSize += size; |
||||
if (size > 0) |
||||
maybeTriggerGC(); |
||||
} |
||||
|
||||
static void* THAllocInternal(long size) |
||||
{ |
||||
void *ptr; |
||||
|
||||
if (size > 5120) |
||||
{ |
||||
#if (defined(__unix) || defined(__APPLE__)) && (!defined(DISABLE_POSIX_MEMALIGN)) |
||||
if (posix_memalign(&ptr, 64, size) != 0) |
||||
ptr = NULL; |
||||
/*
|
||||
#elif defined(_WIN32) |
||||
ptr = _aligned_malloc(size, 64); |
||||
*/ |
||||
#else |
||||
ptr = malloc(size); |
||||
#endif |
||||
} |
||||
else |
||||
{ |
||||
ptr = malloc(size); |
||||
} |
||||
|
||||
THHeapUpdate(getAllocSize(ptr)); |
||||
return ptr; |
||||
} |
||||
|
||||
void* THAlloc(long size) |
||||
{ |
||||
void *ptr; |
||||
|
||||
if(size < 0) |
||||
THError("$ Torch: invalid memory size -- maybe an overflow?"); |
||||
|
||||
if(size == 0) |
||||
return NULL; |
||||
|
||||
ptr = THAllocInternal(size); |
||||
|
||||
if(!ptr && torchGCFunction) { |
||||
torchGCFunction(torchGCData); |
||||
ptr = THAllocInternal(size); |
||||
} |
||||
|
||||
if(!ptr) |
||||
THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", size/1073741824); |
||||
|
||||
return ptr; |
||||
} |
||||
|
||||
void* THRealloc(void *ptr, long size) |
||||
{ |
||||
if(!ptr) |
||||
return(THAlloc(size)); |
||||
|
||||
if(size == 0) |
||||
{ |
||||
THFree(ptr); |
||||
return NULL; |
||||
} |
||||
|
||||
if(size < 0) |
||||
THError("$ Torch: invalid memory size -- maybe an overflow?"); |
||||
|
||||
THHeapUpdate(-getAllocSize(ptr)); |
||||
void *newptr = realloc(ptr, size); |
||||
|
||||
if(!newptr && torchGCFunction) { |
||||
torchGCFunction(torchGCData); |
||||
newptr = realloc(ptr, size); |
||||
} |
||||
THHeapUpdate(getAllocSize(newptr ? newptr : ptr)); |
||||
|
||||
if(!newptr) |
||||
THError("$ Torch: not enough memory: you tried to reallocate %dGB. Buy new RAM!", size/1073741824); |
||||
|
||||
return newptr; |
||||
} |
||||
|
||||
void THFree(void *ptr) |
||||
{ |
||||
THHeapUpdate(-getAllocSize(ptr)); |
||||
free(ptr); |
||||
} |
||||
|
||||
double THLog1p(const double x) |
||||
{ |
||||
#if (defined(_MSC_VER) || defined(__MINGW32__)) |
||||
volatile double y = 1 + x; |
||||
return log(y) - ((y-1)-x)/y ; /* cancels errors with IEEE arithmetic */ |
||||
#else |
||||
return log1p(x); |
||||
#endif |
||||
} |
||||
|
||||
} |
@ -0,0 +1,89 @@ |
||||
#ifndef TH_GENERAL_INC |
||||
#define TH_GENERAL_INC |
||||
|
||||
#include <stdlib.h> |
||||
#include <stdio.h> |
||||
#include <stdarg.h> |
||||
#include <math.h> |
||||
#include <limits.h> |
||||
#include <float.h> |
||||
#include <time.h> |
||||
#include <string.h> |
||||
|
||||
#ifdef __cplusplus |
||||
# define TH_EXTERNC extern "C" |
||||
#else |
||||
# define TH_EXTERNC extern |
||||
#endif |
||||
|
||||
#define TH_API TH_EXTERNC |
||||
|
||||
#define THInf DBL_MAX |
||||
|
||||
//#define TH_INLINE @TH_INLINE@
|
||||
|
||||
#ifndef __cplusplus |
||||
//#define inline @TH_INLINE@
|
||||
#endif |
||||
|
||||
#ifndef M_PI |
||||
# define M_PI 3.14159265358979323846 |
||||
#endif |
||||
|
||||
TH_API double THLog1p(const double x); |
||||
TH_API void _THError(const char *file, const int line, const char *fmt, ...); |
||||
TH_API void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...); |
||||
TH_API void THSetErrorHandler( void (*torchErrorHandlerFunction)(const char *msg, void *data), void *data ); |
||||
TH_API void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...); |
||||
TH_API void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data), void *data ); |
||||
TH_API void* THAlloc(long size); |
||||
TH_API void* THRealloc(void *ptr, long size); |
||||
TH_API void THFree(void *ptr); |
||||
TH_API void THSetGCHandler( void (*torchGCHandlerFunction)(void *data), void *data ); |
||||
// this hook should only be called by custom allocator functions
|
||||
TH_API void THHeapUpdate(long size); |
||||
|
||||
#define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__) |
||||
#define THArgCheck(...) _THArgCheck(__FILE__, __LINE__, __VA_ARGS__) |
||||
#define THAssert(exp) \ |
||||
do { \
|
||||
if (!(exp)) { \
|
||||
_THAssertionFailed(__FILE__, __LINE__, #exp, ""); \
|
||||
} \
|
||||
} while(0) |
||||
#define THAssertMsg(exp, ...) \ |
||||
do { \
|
||||
if (!(exp)) { \
|
||||
_THAssertionFailed(__FILE__, __LINE__, #exp, __VA_ARGS__); \
|
||||
} \
|
||||
} while(0) |
||||
|
||||
#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y) |
||||
#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y |
||||
|
||||
#define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z) |
||||
#define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z |
||||
|
||||
#define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w) |
||||
#define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w |
||||
|
||||
#define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y) |
||||
#define TH_CONCAT_2_EXPAND(x,y) x ## y |
||||
|
||||
#define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z) |
||||
#define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z |
||||
|
||||
#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w |
||||
#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w) |
||||
|
||||
#define THMin(X, Y) ((X) < (Y) ? (X) : (Y)) |
||||
#define THMax(X, Y) ((X) > (Y) ? (X) : (Y)) |
||||
|
||||
#if (defined(_MSC_VER) || defined(__MINGW32__)) |
||||
# define log1p(x) THLog1p(x) |
||||
#define snprintf _snprintf |
||||
#define popen _popen |
||||
#define pclose _pclose |
||||
#endif |
||||
|
||||
#endif |
@ -0,0 +1,317 @@ |
||||
#include "../precomp.hpp" |
||||
#include <limits> |
||||
#include <set> |
||||
#include <map> |
||||
#include <algorithm> |
||||
#include <iostream> |
||||
|
||||
namespace cv { |
||||
namespace dnn { |
||||
|
||||
#if ENABLE_TORCH_IMPORTER || 1 |
||||
#include "THDiskFile.h" |
||||
|
||||
enum LuaType |
||||
{ |
||||
TYPE_NIL = 0, |
||||
TYPE_NUMBER = 1, |
||||
TYPE_STRING = 2, |
||||
TYPE_TABLE = 3, |
||||
TYPE_TORCH = 4, |
||||
TYPE_BOOLEAN = 5, |
||||
TYPE_FUNCTION = 6, |
||||
TYPE_RECUR_FUNCTION = 8, |
||||
LEGACY_TYPE_RECUR_FUNCTION = 7 |
||||
}; |
||||
|
||||
template<typename T> |
||||
static String toString(const T &v) |
||||
{ |
||||
std::ostringstream ss; |
||||
ss << v; |
||||
return ss.str(); |
||||
} |
||||
|
||||
static inline bool startsWith(const String &str, const char *substr) |
||||
{ |
||||
return str.find(substr) == 0; |
||||
} |
||||
|
||||
static inline bool endsWith(const String &str, const char *substr) |
||||
{ |
||||
return str.rfind(substr) == str.length() - strlen(substr); |
||||
} |
||||
|
||||
|
||||
struct TorchImporter : public ::cv::dnn::Importer |
||||
{ |
||||
THFile *file; |
||||
std::set<int> readedIndexes; |
||||
std::map<int, Mat> storages; |
||||
|
||||
TorchImporter(String filename, bool isBinary) |
||||
{ |
||||
file = THDiskFile_new(filename.c_str(), "r", 0); |
||||
CV_Assert(file && THFile_isOpened(file)); |
||||
|
||||
if (isBinary) |
||||
THFile_binary(file); |
||||
else |
||||
THFile_ascii(file); |
||||
} |
||||
|
||||
/* Simple readers */ |
||||
|
||||
inline int readInt() |
||||
{ |
||||
return THFile_readIntScalar(file); |
||||
} |
||||
|
||||
inline long readLong() |
||||
{ |
||||
return THFile_readLongScalar(file); |
||||
} |
||||
|
||||
inline bool readBool() |
||||
{ |
||||
return readInt(); |
||||
} |
||||
|
||||
inline double readDouble() |
||||
{ |
||||
return THFile_readDoubleScalar(file); |
||||
} |
||||
|
||||
inline String readString() |
||||
{ |
||||
int size = THFile_readIntScalar(file); |
||||
String str(size, '\0'); |
||||
THFile_readCharRaw(file, const_cast<char*>(str.c_str()), size); |
||||
return str; |
||||
} |
||||
|
||||
inline String readTorchClassName() |
||||
{ |
||||
String version = readString(); |
||||
return startsWith(version, "V ") ? readString() : version; |
||||
} |
||||
|
||||
inline void readFunction() |
||||
{ |
||||
readString(); |
||||
readObject(true); |
||||
} |
||||
|
||||
void readTable() |
||||
{ |
||||
std::cout << "Skipping table\n"; |
||||
|
||||
int index = readInt(); |
||||
CV_Assert(readedIndexes.count(index) == 0); |
||||
readedIndexes.insert(index); |
||||
|
||||
int size = readInt(); |
||||
for (int i = 0; i < size; i++) |
||||
{ |
||||
readObject(true); //key
|
||||
readObject(true); //value
|
||||
} |
||||
} |
||||
|
||||
/* Special readers */ |
||||
|
||||
static inline int parseTorchType(const String &str, const char *suffix, const char *prefix = "torch.") |
||||
{ |
||||
if (startsWith(str, prefix) && endsWith(str, suffix)) |
||||
{ |
||||
String typeStr = str.substr(strlen(prefix), str.length() - strlen(prefix) - strlen(suffix)); |
||||
|
||||
if (typeStr == "Double") |
||||
return CV_64F; |
||||
else if (typeStr == "Float") |
||||
return CV_32F; |
||||
else if (typeStr == "Byte") |
||||
return CV_8U; |
||||
else if (typeStr == "Char") |
||||
return CV_8S; |
||||
else if (typeStr == "Short") |
||||
return CV_16S; |
||||
else if (typeStr == "Int") |
||||
return CV_32S; |
||||
else |
||||
CV_Error(Error::StsNotImplemented, "Unknown type \"" + typeStr + "\" of torch class \"" + str + "\""); |
||||
} |
||||
|
||||
return -1; |
||||
} |
||||
|
||||
static int parseTensorType(const String &className) |
||||
{ |
||||
return parseTorchType(className, "Tensor"); |
||||
} |
||||
|
||||
static int parseStorageType(const String &className) |
||||
{ |
||||
return parseTorchType(className, "Storage"); |
||||
} |
||||
|
||||
void readTorchStorage(int index, int type = -1) |
||||
{ |
||||
long size = readLong(); |
||||
Mat storageMat(1, size, type); |
||||
|
||||
THFile_readByteRaw(file, storageMat.data, size * CV_ELEM_SIZE(type)); |
||||
|
||||
storages.insert(std::make_pair(index, storageMat)); |
||||
readedIndexes.insert(index); |
||||
} |
||||
|
||||
Blob readTorchTensor(int typeTensor, bool skip = false) |
||||
{ |
||||
int ndims = readInt(); |
||||
|
||||
AutoBuffer<long, 4> sizes(ndims); |
||||
AutoBuffer<long, 4> steps(ndims); |
||||
THFile_readLongRaw(file, sizes, ndims); |
||||
THFile_readLongRaw(file, sizes, ndims); |
||||
|
||||
long offset = readLong() - 1; |
||||
|
||||
//read Storage
|
||||
int typeidx = readInt(); |
||||
std::cout << "stograge typeidx of tensor: " << typeidx << "\n"; |
||||
CV_Assert(typeidx == TYPE_TORCH || (typeidx == TYPE_NIL && ndims == 0)); |
||||
|
||||
if (typeidx == TYPE_NIL) |
||||
return Blob(); |
||||
|
||||
int index = readInt(); |
||||
if (readedIndexes.count(index) == 0) |
||||
{ |
||||
int typeStorage = parseStorageType(readTorchClassName()); |
||||
CV_Assert(typeStorage >= 0 && typeTensor == typeStorage); |
||||
readTorchStorage(typeStorage, index); |
||||
} |
||||
|
||||
//allocate Blob
|
||||
AutoBuffer<int, 4> isizes(ndims); |
||||
AutoBuffer<size_t, 4> ssteps(ndims); |
||||
|
||||
size_t stepExpected = 1; |
||||
for (int i = ndims - 1; i >= 0; i--) |
||||
{ |
||||
isizes[i] = (int)sizes[i]; |
||||
ssteps[i] = (size_t)steps[i] * CV_ELEM_SIZE(typeTensor); |
||||
|
||||
stepExpected *= sizes[i]; |
||||
} |
||||
|
||||
if (skip) |
||||
return Blob(); |
||||
|
||||
Mat srcMat(ndims, (int*)isizes, typeTensor , storages[index].ptr(), (size_t*)ssteps); |
||||
int dstType = (typeTensor == CV_64F) ? CV_64F : CV_32F; |
||||
|
||||
Blob blob; |
||||
blob.create(BlobShape(ndims, isizes), dstType); |
||||
srcMat.convertTo(blob.getMatRef(), dstType); |
||||
|
||||
return blob; |
||||
} |
||||
|
||||
void readTorchObject(int index, bool skip = false) |
||||
{ |
||||
String className = readTorchClassName(); |
||||
std::cout << "Class: " << className << std::endl; |
||||
|
||||
int type; |
||||
if ( (type = parseTensorType(className)) >= 0 ) //is Tensor
|
||||
{ |
||||
readTorchTensor(type); |
||||
return; |
||||
} |
||||
else if ( (type = parseStorageType(className)) >= 0 ) //is Storage
|
||||
{ |
||||
readTorchStorage(index, type); |
||||
} |
||||
else if (className == "nn.Sequential") |
||||
{ |
||||
readObject(); |
||||
} |
||||
else if (className == "nn.Concat") |
||||
{ |
||||
readObject(); |
||||
} |
||||
else if (className == "nn.SpatialConvolution") |
||||
{ |
||||
readObject(); |
||||
} |
||||
else if (className == "nn.ReLU") |
||||
{ |
||||
readObject(); |
||||
} |
||||
else |
||||
{ |
||||
CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className +"\""); |
||||
} |
||||
} |
||||
|
||||
void readObject(bool skip = false) |
||||
{ |
||||
int typeidx = readInt(); |
||||
std::cout << "typeidx: " << typeidx << "\n"; |
||||
|
||||
if (typeidx == TYPE_TORCH) |
||||
{ |
||||
int index = readInt(); |
||||
|
||||
if (readedIndexes.count(index) == 0) |
||||
{ |
||||
readTorchObject(index, skip); |
||||
readedIndexes.insert(index); |
||||
} |
||||
else |
||||
{ |
||||
//CV_Error(Error::StsNotImplemented, "");
|
||||
//TBD
|
||||
} |
||||
} |
||||
else if (typeidx == TYPE_NIL) |
||||
return; |
||||
else if (typeidx == TYPE_NUMBER) |
||||
readDouble(); |
||||
else if (typeidx == TYPE_BOOLEAN) |
||||
readBool(); |
||||
else if (typeidx == TYPE_STRING) |
||||
readString(); |
||||
else if (typeidx == TYPE_TABLE) |
||||
readTable(); |
||||
else |
||||
CV_Error(Error::StsNotImplemented, "Unsupported Lua type"); |
||||
} |
||||
|
||||
void populateNet(Net net) |
||||
{ |
||||
THFile_seek(file, 0); |
||||
readedIndexes.clear(); |
||||
|
||||
readObject(); |
||||
} |
||||
}; |
||||
|
||||
CV_EXPORTS Ptr<Importer> createTorchImporter(const String &filename, bool isBinary) |
||||
{ |
||||
return Ptr<Importer>(new TorchImporter(filename, isBinary)); |
||||
} |
||||
|
||||
#else //ENABLE_TORCH_IMPORTER
|
||||
|
||||
CV_EXPORTS Ptr<Importer> createTorchImporter(const String&, bool) |
||||
{ |
||||
CV_Error(Error::StsNotImplemented, "Module was build without Torch importer"); |
||||
return Ptr<Importer>(); |
||||
} |
||||
|
||||
#endif //ENABLE_TORCH_IMPORTER
|
||||
} |
||||
} |
@ -0,0 +1,35 @@ |
||||
#if 1 || defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER |
||||
#include "test_precomp.hpp" |
||||
|
||||
namespace cvtest |
||||
{ |
||||
|
||||
using namespace std; |
||||
using namespace testing; |
||||
using namespace cv; |
||||
using namespace cv::dnn; |
||||
|
||||
static std::string getOpenCVExtraDir() |
||||
{ |
||||
return cvtest::TS::ptr()->get_data_path(); |
||||
} |
||||
|
||||
template<typename TStr> |
||||
static std::string getTestFile(TStr filename) |
||||
{ |
||||
return (getOpenCVExtraDir() + "/dnn/") + filename; |
||||
} |
||||
|
||||
TEST(Torch_Importer, simple_read) |
||||
{ |
||||
Net net; |
||||
Ptr<Importer> importer; |
||||
|
||||
ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) ); |
||||
ASSERT_TRUE( importer != NULL ); |
||||
importer->populateNet(net); |
||||
//ASSERT_NO_THROW( importer->populateNet(net) );
|
||||
} |
||||
|
||||
} |
||||
#endif |
Loading…
Reference in new issue