Implemented draft Torch importer

pull/265/head
Vitaliy Lyudvichenko 9 years ago
parent db4ff2172a
commit 4da1046d68
  1. 8
      modules/dnn/CMakeLists.txt
  2. 2
      modules/dnn/include/opencv2/dnn/dnn.hpp
  3. 36
      modules/dnn/src/torch/COPYRIGHT.txt
  4. 348
      modules/dnn/src/torch/File.lua
  5. 609
      modules/dnn/src/torch/THDiskFile.cpp
  6. 17
      modules/dnn/src/torch/THDiskFile.h
  7. 161
      modules/dnn/src/torch/THFile.cpp
  8. 87
      modules/dnn/src/torch/THFile.h
  9. 43
      modules/dnn/src/torch/THFilePrivate.h
  10. 254
      modules/dnn/src/torch/THGeneral.cpp
  11. 89
      modules/dnn/src/torch/THGeneral.h
  12. 317
      modules/dnn/src/torch/torch_importer.cpp
  13. 35
      modules/dnn/test/test_torch_importer.cpp

@ -42,8 +42,12 @@ if(BUILD_TESTS AND ${the_module}_DOWNLOAD_CAFFE_MODELS)
add_custom_command( TARGET opencv_test_${name} POST_BUILD
COMMAND ${PYTHON2_EXECUTABLE} download_model.py test_models.json
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/scripts )
else()
add_definitions(-DDISABLE_CAFFE_MODEL_TESTS=1)
add_definitions(-DENABLE_CAFFE_MODEL_TESTS=1)
endif()
OCV_OPTION(${the_module}_BUILD_TORCH_IMPORTER "Build Torch model importer" OFF)
if(${the_module}_BUILD_TORCH_IMPORTER)
add_definitions(-DENABLE_TORCH_IMPORTER=1)
endif()
else()#build as standalone module (for development purposes)

@ -90,6 +90,8 @@ namespace dnn
CV_EXPORTS Ptr<Importer> createCaffeImporter(const String &prototxt, const String &caffeModel = String());
CV_EXPORTS Ptr<Importer> createTorchImporter(const String &filename, bool isBinary = true);
//Layer factory allows to create instances of registered layers.
class CV_EXPORTS LayerRegister
{

@ -0,0 +1,36 @@
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

@ -0,0 +1,348 @@
local File = torch.getmetatable('torch.File')
function File:writeBool(value)
if value then
self:writeInt(1)
else
self:writeInt(0)
end
end
function File:readBool()
return (self:readInt() == 1)
end
local TYPE_NIL = 0
local TYPE_NUMBER = 1
local TYPE_STRING = 2
local TYPE_TABLE = 3
local TYPE_TORCH = 4
local TYPE_BOOLEAN = 5
local TYPE_FUNCTION = 6
local TYPE_RECUR_FUNCTION = 8
local LEGACY_TYPE_RECUR_FUNCTION = 7
-- Lua 5.2 compatibility
local loadstring = loadstring or load
function File:isWritableObject(object)
local typename = type(object)
local typeidx
if type(object) ~= 'boolean' and not object then
typeidx = TYPE_NIL
elseif torch.typename(object) and torch.factory(torch.typename(object)) then
typeidx = TYPE_TORCH
elseif typename == 'table' then
typeidx = TYPE_TABLE
elseif typename == 'number' then
typeidx = TYPE_NUMBER
elseif typename == 'string' then
typeidx = TYPE_STRING
elseif typename == 'boolean' then
typeidx = TYPE_BOOLEAN
elseif typename == 'function' and pcall(string.dump, object) then
typeidx = TYPE_RECUR_FUNCTION
end
return typeidx
end
function File:referenced(ref)
-- we use an environment to keep a record of written objects
if not torch.getenv(self).writeObjects then
torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}})
end
local env = torch.getenv(self)
env.force = not ref
torch.setenv(self,env)
return self
end
function File:isReferenced()
-- if no environment, then no forcing setup yet
if not torch.getenv(self).writeObjects then
return true
end
local env = torch.getenv(self)
return not env.force
end
local function getmetamethod(obj, name)
local func
local status
-- check getmetatable(obj).__name or
-- check getmetatable(obj).name
status, func = pcall(
function()
-- note that sometimes the metatable is hidden
-- we get it for sure through the torch type system
local mt = torch.getmetatable(torch.typename(obj))
if mt then
return mt['__' .. name] or mt[name]
end
end
)
if status and type(func) == 'function' then
return func
end
end
function File:writeObject(object)
-- we use an environment to keep a record of written objects
if not torch.getenv(self).writeObjects then
torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}})
end
local force = torch.getenv(self).force
-- if nil object, only write the type and return
if type(object) ~= 'boolean' and not object then
self:writeInt(TYPE_NIL)
return
end
-- check the type we are dealing with
local typeidx = self:isWritableObject(object)
if not typeidx then
error(string.format('Unwritable object <%s>', type(object)))
end
self:writeInt(typeidx)
if typeidx == TYPE_NUMBER then
self:writeDouble(object)
elseif typeidx == TYPE_BOOLEAN then
self:writeBool(object)
elseif typeidx == TYPE_STRING then
local stringStorage = torch.CharStorage():string(object)
self:writeInt(#stringStorage)
self:writeChar(stringStorage)
elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE or typeidx == TYPE_RECUR_FUNCTION then
-- check it exists already (we look at the pointer!)
local objects = torch.getenv(self).writeObjects
local objectsRef = torch.getenv(self).writeObjectsRef
local index = objects[torch.pointer(object)]
if index and (not force) then
-- if already exists, write only its index
self:writeInt(index)
else
-- else write the object itself
index = objects.nWriteObject or 0
index = index + 1
objects[torch.pointer(object)] = index
if not force then
objectsRef[object] = index -- we make sure the object is not going to disappear
end
self:writeInt(index)
objects.nWriteObject = index
if typeidx == TYPE_RECUR_FUNCTION then
local upvalues = {}
local counter = 0
while true do
counter = counter + 1
local name,value = debug.getupvalue(object, counter)
if not name then break end
if name == '_ENV' then value = nil end
table.insert(upvalues, {name=name, value=value})
end
local dumped = string.dump(object)
local stringStorage = torch.CharStorage():string(dumped)
self:writeInt(#stringStorage)
self:writeChar(stringStorage)
self:writeObject(upvalues)
elseif typeidx == TYPE_TORCH then
local version = torch.CharStorage():string('V ' .. torch.version(object))
local className = torch.CharStorage():string(torch.typename(object))
self:writeInt(#version)
self:writeChar(version)
self:writeInt(#className)
self:writeChar(className)
local write = getmetamethod(object, 'write')
if write then
write(object, self)
elseif type(object) == 'table' then
local var = {}
for k,v in pairs(object) do
if self:isWritableObject(v) then
var[k] = v
else
print(string.format('$ Warning: cannot write object field <%s>', k))
end
end
self:writeObject(var)
else
error(string.format('<%s> is a non-serializable Torch object', torch.typename(object)))
end
else -- it is a table
local size = 0; for k,v in pairs(object) do size = size + 1 end
self:writeInt(size)
for k,v in pairs(object) do
self:writeObject(k)
self:writeObject(v)
end
end
end
else
error('Unwritable object')
end
end
function File:readObject()
-- we use an environment to keep a record of read objects
if not torch.getenv(self).writeObjects then
torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}})
end
local force = torch.getenv(self).force
-- read the typeidx
local typeidx = self:readInt()
-- is it nil?
if typeidx == TYPE_NIL then
return nil
end
if typeidx == TYPE_NUMBER then
return self:readDouble()
elseif typeidx == TYPE_BOOLEAN then
return self:readBool()
elseif typeidx == TYPE_STRING then
local size = self:readInt()
return self:readChar(size):string()
elseif typeidx == TYPE_FUNCTION then
local size = self:readInt()
local dumped = self:readChar(size):string()
local func = loadstring(dumped)
local upvalues = self:readObject()
for index,upvalue in ipairs(upvalues) do
debug.setupvalue(func, index, upvalue)
end
return func
elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH or typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then
-- read the index
local index = self:readInt()
-- check it is loaded already
local objects = torch.getenv(self).readObjects
if objects[index] and not force then
return objects[index]
end
-- otherwise read it
if typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then
local size = self:readInt()
local dumped = self:readChar(size):string()
local func = loadstring(dumped)
objects[index] = func
local upvalues = self:readObject()
for index,upvalue in ipairs(upvalues) do
if typeidx == LEGACY_TYPE_RECUR_FUNCTION then
debug.setupvalue(func, index, upvalue)
elseif upvalue.name == '_ENV' then
debug.setupvalue(func, index, _ENV)
else
debug.setupvalue(func, index, upvalue.value)
end
end
return func
elseif typeidx == TYPE_TORCH then
local version, className, versionNumber
version = self:readChar(self:readInt()):string()
versionNumber = tonumber(string.match(version, '^V (.*)$'))
if not versionNumber then
className = version
versionNumber = 0 -- file created before existence of versioning system
else
className = self:readChar(self:readInt()):string()
end
if not torch.factory(className) then
error(string.format('unknown Torch class <%s>', tostring(className)))
end
local object = torch.factory(className)(self)
objects[index] = object
local read = getmetamethod(object, 'read')
if read then
read(object, self, versionNumber)
elseif type(object) == 'table' then
local var = self:readObject()
for k,v in pairs(var) do
object[k] = v
end
else
error(string.format('Cannot load object class <%s>', tostring(className)))
end
return object
else -- it is a table
local size = self:readInt()
local object = {}
objects[index] = object
for i = 1,size do
local k = self:readObject()
local v = self:readObject()
object[k] = v
end
return object
end
else
error('unknown object')
end
end
-- simple helpers to save/load arbitrary objects/tables
function torch.save(filename, object, mode)
mode = mode or 'binary'
local file = torch.DiskFile(filename, 'w')
file[mode](file)
file:writeObject(object)
file:close()
end
function torch.load(filename, mode)
mode = mode or 'binary'
local file = torch.DiskFile(filename, 'r')
file[mode](file)
local object = file:readObject()
file:close()
return object
end
-- simple helpers to serialize/deserialize arbitrary objects/tables
function torch.serialize(object, mode)
local storage = torch.serializeToStorage(object, mode)
return storage:string()
end
-- Serialize to a CharStorage, not a lua string. This avoids
function torch.serializeToStorage(object, mode)
mode = mode or 'binary'
local f = torch.MemoryFile()
f = f[mode](f)
f:writeObject(object)
local storage = f:storage()
f:close()
return storage
end
function torch.deserializeFromStorage(storage, mode)
mode = mode or 'binary'
local tx = torch.CharTensor(storage)
local xp = torch.CharStorage(tx:size(1)+1)
local txp = torch.CharTensor(xp)
txp:narrow(1,1,tx:size(1)):copy(tx)
txp[tx:size(1)+1] = 0
local f = torch.MemoryFile(xp)
f = f[mode](f)
local object = f:readObject()
f:close()
return object
end
function torch.deserialize(str, mode)
local storage = torch.CharStorage():string(str)
return torch.deserializeFromStorage(storage, mode)
end
-- public API (saveobj/loadobj are safe for global import)
torch.saveobj = torch.save
torch.loadobj = torch.load

@ -0,0 +1,609 @@
#include "THGeneral.h"
#include "THDiskFile.h"
#include "THFilePrivate.h"
extern "C"
{
typedef struct THDiskFile__
{
THFile file;
FILE *handle;
char *name;
int isNativeEncoding;
} THDiskFile;
static int THDiskFile_isOpened(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)self;
return (dfself->handle != NULL);
}
const char *THDiskFile_name(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)self;
return dfself->name;
}
/* workaround mac osx lion ***insane*** fread bug */
#ifdef __APPLE__
size_t fread__(void *ptr, size_t size, size_t nitems, FILE *stream)
{
size_t nread = 0;
while(!feof(stream) && !ferror(stream) && (nread < nitems))
nread += fread((char*)ptr+nread*size, size, THMin(2147483648/size, nitems-nread), stream);
return nread;
}
#else
#define fread__ fread
#endif
#define READ_WRITE_METHODS(TYPE, TYPEC, ASCII_READ_ELEM, ASCII_WRITE_ELEM) \
static long THDiskFile_read##TYPEC(THFile *self, TYPE *data, long n) \
{ \
THDiskFile *dfself = (THDiskFile*)(self); \
long nread = 0L; \
\
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \
THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file"); \
\
if(dfself->file.isBinary) \
{ \
nread = fread__(data, sizeof(TYPE), n, dfself->handle); \
if(!dfself->isNativeEncoding && (sizeof(TYPE) > 1) && (nread > 0)) \
THDiskFile_reverseMemory(data, data, sizeof(TYPE), nread); \
} \
else \
{ \
long i; \
for(i = 0; i < n; i++) \
{ \
ASCII_READ_ELEM; /* increment here result and break if wrong */ \
} \
if(dfself->file.isAutoSpacing && (n > 0)) \
{ \
int c = fgetc(dfself->handle); \
if( (c != '\n') && (c != EOF) ) \
ungetc(c, dfself->handle); \
} \
} \
\
if(nread != n) \
{ \
dfself->file.hasError = 1; /* shouldn't we put hasError to 0 all the time ? */ \
if(!dfself->file.isQuiet) \
THError("read error: read %d blocks instead of %d", nread, n); \
} \
\
return nread; \
} \
\
static long THDiskFile_write##TYPEC(THFile *self, TYPE *data, long n) \
{ \
THDiskFile *dfself = (THDiskFile*)(self); \
long nwrite = 0L; \
\
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \
THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file"); \
\
if(dfself->file.isBinary) \
{ \
if(dfself->isNativeEncoding) \
{ \
nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle); \
} \
else \
{ \
if(sizeof(TYPE) > 1) \
{ \
char *buffer = (char*)THAlloc(sizeof(TYPE)*n); \
THDiskFile_reverseMemory(buffer, data, sizeof(TYPE), n); \
nwrite = fwrite(buffer, sizeof(TYPE), n, dfself->handle); \
THFree(buffer); \
} \
else \
nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle); \
} \
} \
else \
{ \
long i; \
for(i = 0; i < n; i++) \
{ \
ASCII_WRITE_ELEM; \
if( dfself->file.isAutoSpacing && (i < n-1) ) \
fprintf(dfself->handle, " "); \
} \
if(dfself->file.isAutoSpacing && (n > 0)) \
fprintf(dfself->handle, "\n"); \
} \
\
if(nwrite != n) \
{ \
dfself->file.hasError = 1; \
if(!dfself->file.isQuiet) \
THError("write error: wrote %d blocks instead of %d", nwrite, n); \
} \
\
return nwrite; \
}
static int THDiskFile_mode(const char *mode, int *isReadable, int *isWritable)
{
*isReadable = 0;
*isWritable = 0;
if(strlen(mode) == 1)
{
if(*mode == 'r')
{
*isReadable = 1;
return 1;
}
else if(*mode == 'w')
{
*isWritable = 1;
return 1;
}
}
else if(strlen(mode) == 2)
{
if(mode[0] == 'r' && mode[1] == 'w')
{
*isReadable = 1;
*isWritable = 1;
return 1;
}
}
return 0;
}
static void THDiskFile_synchronize(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)(self);
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
fflush(dfself->handle);
}
static void THDiskFile_seek(THFile *self, long position)
{
THDiskFile *dfself = (THDiskFile*)(self);
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
THArgCheck(position >= 0, 2, "position must be positive");
if(fseek(dfself->handle, position, SEEK_SET) < 0)
{
dfself->file.hasError = 1;
if(!dfself->file.isQuiet)
THError("unable to seek at position %d", position);
}
}
static void THDiskFile_seekEnd(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)(self);
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
if(fseek(dfself->handle, 0L, SEEK_END) < 0)
{
dfself->file.hasError = 1;
if(!dfself->file.isQuiet)
THError("unable to seek at end of file");
}
}
static long THDiskFile_position(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)(self);
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
return ftell(dfself->handle);
}
static void THDiskFile_close(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)(self);
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
fclose(dfself->handle);
dfself->handle = NULL;
}
/* Little and Big Endian */
static void THDiskFile_reverseMemory(void *dst, const void *src, long blockSize, long numBlocks)
{
if(blockSize != 1)
{
long halfBlockSize = blockSize/2;
char *charSrc = (char*)src;
char *charDst = (char*)dst;
long b, i;
for(b = 0; b < numBlocks; b++)
{
for(i = 0; i < halfBlockSize; i++)
{
char z = charSrc[i];
charDst[i] = charSrc[blockSize-1-i];
charDst[blockSize-1-i] = z;
}
charSrc += blockSize;
charDst += blockSize;
}
}
}
int THDiskFile_isLittleEndianCPU(void)
{
int x = 7;
char *ptr = (char *)&x;
if(ptr[0] == 0)
return 0;
else
return 1;
}
int THDiskFile_isBigEndianCPU(void)
{
return(!THDiskFile_isLittleEndianCPU());
}
void THDiskFile_nativeEndianEncoding(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)(self);
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
dfself->isNativeEncoding = 1;
}
void THDiskFile_littleEndianEncoding(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)(self);
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
dfself->isNativeEncoding = THDiskFile_isLittleEndianCPU();
}
void THDiskFile_bigEndianEncoding(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)(self);
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
dfself->isNativeEncoding = !THDiskFile_isLittleEndianCPU();
}
/* End of Little and Big Endian Stuff */
static void THDiskFile_free(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)(self);
if(dfself->handle)
fclose(dfself->handle);
THFree(dfself->name);
THFree(dfself);
}
/* READ_WRITE_METHODS(int, Bool, */
/* int value = 0; int ret = fscanf(file->handle, "%d", &value); array[i] = (value ? 1 : 0); if(ret <= 0) break; else result++, */
/* int value = (array[i] ? 1 : 0); nElemWritten = fprintf(file->handle, "%d", value), */
/* true) */
/* Note that we do a trick */
READ_WRITE_METHODS(unsigned char, Byte,
nread = fread(data, 1, n, dfself->handle); break,
nwrite = fwrite(data, 1, n, dfself->handle); break)
READ_WRITE_METHODS(char, Char,
nread = fread(data, 1, n, dfself->handle); break,
nwrite = fwrite(data, 1, n, dfself->handle); break)
READ_WRITE_METHODS(short, Short,
int ret = fscanf(dfself->handle, "%hd", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%hd", data[i]); if(ret <= 0) break; else nwrite++)
READ_WRITE_METHODS(int, Int,
int ret = fscanf(dfself->handle, "%d", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%d", data[i]); if(ret <= 0) break; else nwrite++)
READ_WRITE_METHODS(long, Long,
int ret = fscanf(dfself->handle, "%ld", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%ld", data[i]); if(ret <= 0) break; else nwrite++)
READ_WRITE_METHODS(float, Float,
int ret = fscanf(dfself->handle, "%g", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%.9g", data[i]); if(ret <= 0) break; else nwrite++)
READ_WRITE_METHODS(double, Double,
int ret = fscanf(dfself->handle, "%lg", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%.17g", data[i]); if(ret <= 0) break; else nwrite++)
static long THDiskFile_readString(THFile *self, const char *format, char **str_)
{
THDiskFile *dfself = (THDiskFile*)(self);
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file");
THArgCheck((strlen(format) >= 2 ? (format[0] == '*') && (format[1] == 'a' || format[1] == 'l') : 0), 2, "format must be '*a' or '*l'");
/* note: the string won't survive long, as it is copied into lua */
/* so 1024 is not that big... */
#define TBRS_BSZ 1024L
if(format[1] == 'a')
{
char *p = (char*)THAlloc(TBRS_BSZ);
long total = TBRS_BSZ;
long pos = 0L;
for (;;)
{
if(total-pos == 0) /* we need more space! */
{
total += TBRS_BSZ;
p = (char*)THRealloc(p, total);
}
pos += fread(p+pos, 1, total-pos, dfself->handle);
if (pos < total) /* eof? */
{
if(pos == 0L)
{
THFree(p);
dfself->file.hasError = 1;
if(!dfself->file.isQuiet)
THError("read error: read 0 blocks instead of 1");
*str_ = NULL;
return 0;
}
*str_ = p;
return pos;
}
}
}
else
{
char *p = (char*)THAlloc(TBRS_BSZ);
long total = TBRS_BSZ;
long pos = 0L;
long size;
for (;;)
{
if(total-pos <= 1) /* we can only write '\0' in there! */
{
total += TBRS_BSZ;
p = (char*)THRealloc(p, total);
}
if (fgets(p+pos, total-pos, dfself->handle) == NULL) /* eof? */
{
if(pos == 0L)
{
THFree(p);
dfself->file.hasError = 1;
if(!dfself->file.isQuiet)
THError("read error: read 0 blocks instead of 1");
*str_ = NULL;
return 0;
}
*str_ = p;
return pos;
}
size = strlen(p+pos);
if (size == 0L || (p+pos)[size-1] != '\n')
{
pos += size;
}
else
{
pos += size-1L; /* do not include `eol' */
*str_ = p;
return pos;
}
}
}
*str_ = NULL;
return 0;
}
static long THDiskFile_writeString(THFile *self, const char *str, long size)
{
THDiskFile *dfself = (THDiskFile*)(self);
long nwrite;
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file");
nwrite = fwrite(str, 1, size, dfself->handle);
if(nwrite != size)
{
dfself->file.hasError = 1;
if(!dfself->file.isQuiet)
THError("write error: wrote %ld blocks instead of %ld", nwrite, size);
}
return nwrite;
}
THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet)
{
static struct THFileVTable vtable = {
THDiskFile_isOpened,
THDiskFile_readByte,
THDiskFile_readChar,
THDiskFile_readShort,
THDiskFile_readInt,
THDiskFile_readLong,
THDiskFile_readFloat,
THDiskFile_readDouble,
THDiskFile_readString,
THDiskFile_writeByte,
THDiskFile_writeChar,
THDiskFile_writeShort,
THDiskFile_writeInt,
THDiskFile_writeLong,
THDiskFile_writeFloat,
THDiskFile_writeDouble,
THDiskFile_writeString,
THDiskFile_synchronize,
THDiskFile_seek,
THDiskFile_seekEnd,
THDiskFile_position,
THDiskFile_close,
THDiskFile_free
};
int isReadable;
int isWritable;
FILE *handle;
THDiskFile *self;
THArgCheck(THDiskFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'");
if( isReadable && isWritable )
{
handle = fopen(name, "r+b");
if(!handle)
{
handle = fopen(name, "wb");
if(handle)
{
fclose(handle);
handle = fopen(name, "r+b");
}
}
}
else
handle = fopen(name, (isReadable ? "rb" : "wb"));
if(!handle)
{
if(isQuiet)
return 0;
else
THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' '));
}
self = (THDiskFile*)THAlloc(sizeof(THDiskFile));
self->handle = handle;
self->name = (char*)THAlloc(strlen(name)+1);
strcpy(self->name, name);
self->isNativeEncoding = 1;
self->file.vtable = &vtable;
self->file.isQuiet = isQuiet;
self->file.isReadable = isReadable;
self->file.isWritable = isWritable;
self->file.isBinary = 0;
self->file.isAutoSpacing = 1;
self->file.hasError = 0;
return (THFile*)self;
}
/* PipeFile */
static int THPipeFile_mode(const char *mode, int *isReadable, int *isWritable)
{
*isReadable = 0;
*isWritable = 0;
if(strlen(mode) == 1)
{
if(*mode == 'r')
{
*isReadable = 1;
return 1;
}
else if(*mode == 'w')
{
*isWritable = 1;
return 1;
}
}
return 0;
}
static void THPipeFile_free(THFile *self)
{
THDiskFile *dfself = (THDiskFile*)(self);
if(dfself->handle)
pclose(dfself->handle);
THFree(dfself->name);
THFree(dfself);
}
THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet)
{
static struct THFileVTable vtable = {
THDiskFile_isOpened,
THDiskFile_readByte,
THDiskFile_readChar,
THDiskFile_readShort,
THDiskFile_readInt,
THDiskFile_readLong,
THDiskFile_readFloat,
THDiskFile_readDouble,
THDiskFile_readString,
THDiskFile_writeByte,
THDiskFile_writeChar,
THDiskFile_writeShort,
THDiskFile_writeInt,
THDiskFile_writeLong,
THDiskFile_writeFloat,
THDiskFile_writeDouble,
THDiskFile_writeString,
THDiskFile_synchronize,
THDiskFile_seek,
THDiskFile_seekEnd,
THDiskFile_position,
THDiskFile_close,
THPipeFile_free
};
int isReadable;
int isWritable;
FILE *handle;
THDiskFile *self;
THArgCheck(THPipeFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w'");
#ifdef _WIN32
handle = popen(name, (isReadable ? "rb" : "wb"));
#else
handle = popen(name, (isReadable ? "r" : "w"));
#endif
if(!handle)
{
if(isQuiet)
return 0;
else
THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' '));
}
self = (THDiskFile*)THAlloc(sizeof(THDiskFile));
self->handle = handle;
self->name = (char*)THAlloc(strlen(name)+1);
strcpy(self->name, name);
self->isNativeEncoding = 1;
self->file.vtable = &vtable;
self->file.isQuiet = isQuiet;
self->file.isReadable = isReadable;
self->file.isWritable = isWritable;
self->file.isBinary = 0;
self->file.isAutoSpacing = 1;
self->file.hasError = 0;
return (THFile*)self;
}
}

@ -0,0 +1,17 @@
#ifndef TH_DISK_FILE_INC
#define TH_DISK_FILE_INC
#include "THFile.h"
TH_API THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet);
TH_API THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet);
TH_API const char *THDiskFile_name(THFile *self);
TH_API int THDiskFile_isLittleEndianCPU(void);
TH_API int THDiskFile_isBigEndianCPU(void);
TH_API void THDiskFile_nativeEndianEncoding(THFile *self);
TH_API void THDiskFile_littleEndianEncoding(THFile *self);
TH_API void THDiskFile_bigEndianEncoding(THFile *self);
#endif

@ -0,0 +1,161 @@
#include "THFile.h"
#include "THFilePrivate.h"
extern "C"
{
#define IMPLEMENT_THFILE_RW(TYPEC, TYPE) \
long THFile_read##TYPEC##Raw(THFile *self, TYPE *data, long n) \
{ \
return (*self->vtable->read##TYPEC)(self, data, n); \
} \
\
long THFile_write##TYPEC##Raw(THFile *self, TYPE *data, long n) \
{ \
return (*self->vtable->write##TYPEC)(self, data, n); \
}
IMPLEMENT_THFILE_RW(Byte, unsigned char)
IMPLEMENT_THFILE_RW(Char, char)
IMPLEMENT_THFILE_RW(Short, short)
IMPLEMENT_THFILE_RW(Int, int)
IMPLEMENT_THFILE_RW(Long, long)
IMPLEMENT_THFILE_RW(Float, float)
IMPLEMENT_THFILE_RW(Double, double)
long THFile_readStringRaw(THFile *self, const char *format, char **str_)
{
return self->vtable->readString(self, format, str_);
}
long THFile_writeStringRaw(THFile *self, const char *str, long size)
{
return self->vtable->writeString(self, str, size);
}
void THFile_synchronize(THFile *self)
{
self->vtable->synchronize(self);
}
void THFile_seek(THFile *self, long position)
{
self->vtable->seek(self, position);
}
void THFile_seekEnd(THFile *self)
{
self->vtable->seekEnd(self);
}
long THFile_position(THFile *self)
{
return self->vtable->position(self);
}
void THFile_close(THFile *self)
{
self->vtable->close(self);
}
void THFile_free(THFile *self)
{
self->vtable->free(self);
}
int THFile_isOpened(THFile *self)
{
return self->vtable->isOpened(self);
}
#define IMPLEMENT_THFILE_FLAGS(FLAG) \
int THFile_##FLAG(THFile *self) \
{ \
return self->FLAG; \
}
IMPLEMENT_THFILE_FLAGS(isQuiet)
IMPLEMENT_THFILE_FLAGS(isReadable)
IMPLEMENT_THFILE_FLAGS(isWritable)
IMPLEMENT_THFILE_FLAGS(isBinary)
IMPLEMENT_THFILE_FLAGS(isAutoSpacing)
IMPLEMENT_THFILE_FLAGS(hasError)
void THFile_binary(THFile *self)
{
self->isBinary = 1;
}
void THFile_ascii(THFile *self)
{
self->isBinary = 0;
}
void THFile_autoSpacing(THFile *self)
{
self->isAutoSpacing = 1;
}
void THFile_noAutoSpacing(THFile *self)
{
self->isAutoSpacing = 0;
}
void THFile_quiet(THFile *self)
{
self->isQuiet = 1;
}
void THFile_pedantic(THFile *self)
{
self->isQuiet = 0;
}
void THFile_clearError(THFile *self)
{
self->hasError = 0;
}
#define IMPLEMENT_THFILE_SCALAR(TYPEC, TYPE) \
TYPE THFile_read##TYPEC##Scalar(THFile *self) \
{ \
TYPE scalar; \
THFile_read##TYPEC##Raw(self, &scalar, 1); \
return scalar; \
} \
\
void THFile_write##TYPEC##Scalar(THFile *self, TYPE scalar) \
{ \
THFile_write##TYPEC##Raw(self, &scalar, 1); \
}
IMPLEMENT_THFILE_SCALAR(Byte, unsigned char)
IMPLEMENT_THFILE_SCALAR(Char, char)
IMPLEMENT_THFILE_SCALAR(Short, short)
IMPLEMENT_THFILE_SCALAR(Int, int)
IMPLEMENT_THFILE_SCALAR(Long, long)
IMPLEMENT_THFILE_SCALAR(Float, float)
IMPLEMENT_THFILE_SCALAR(Double, double)
/*
#define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \
long THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \
{ \
return THFile_read##TYPEC##Raw(self, storage->data, storage->size); \
} \
\
long THFile_write##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \
{ \
return THFile_write##TYPEC##Raw(self, storage->data, storage->size); \
}
IMPLEMENT_THFILE_STORAGE(Byte, unsigned char)
IMPLEMENT_THFILE_STORAGE(Char, char)
IMPLEMENT_THFILE_STORAGE(Short, short)
IMPLEMENT_THFILE_STORAGE(Int, int)
IMPLEMENT_THFILE_STORAGE(Long, long)
IMPLEMENT_THFILE_STORAGE(Float, float)
IMPLEMENT_THFILE_STORAGE(Double, double)
*/
}

@ -0,0 +1,87 @@
#ifndef TH_FILE_INC
#define TH_FILE_INC
//#include "THStorage.h"
#include "THGeneral.h"
typedef struct THFile__ THFile;
TH_API int THFile_isOpened(THFile *self);
TH_API int THFile_isQuiet(THFile *self);
TH_API int THFile_isReadable(THFile *self);
TH_API int THFile_isWritable(THFile *self);
TH_API int THFile_isBinary(THFile *self);
TH_API int THFile_isAutoSpacing(THFile *self);
TH_API int THFile_hasError(THFile *self);
TH_API void THFile_binary(THFile *self);
TH_API void THFile_ascii(THFile *self);
TH_API void THFile_autoSpacing(THFile *self);
TH_API void THFile_noAutoSpacing(THFile *self);
TH_API void THFile_quiet(THFile *self);
TH_API void THFile_pedantic(THFile *self);
TH_API void THFile_clearError(THFile *self);
/* scalar */
TH_API unsigned char THFile_readByteScalar(THFile *self);
TH_API char THFile_readCharScalar(THFile *self);
TH_API short THFile_readShortScalar(THFile *self);
TH_API int THFile_readIntScalar(THFile *self);
TH_API long THFile_readLongScalar(THFile *self);
TH_API float THFile_readFloatScalar(THFile *self);
TH_API double THFile_readDoubleScalar(THFile *self);
TH_API void THFile_writeByteScalar(THFile *self, unsigned char scalar);
TH_API void THFile_writeCharScalar(THFile *self, char scalar);
TH_API void THFile_writeShortScalar(THFile *self, short scalar);
TH_API void THFile_writeIntScalar(THFile *self, int scalar);
TH_API void THFile_writeLongScalar(THFile *self, long scalar);
TH_API void THFile_writeFloatScalar(THFile *self, float scalar);
TH_API void THFile_writeDoubleScalar(THFile *self, double scalar);
/* storage */
/*
TH_API long THFile_readByte(THFile *self, THByteStorage *storage);
TH_API long THFile_readChar(THFile *self, THCharStorage *storage);
TH_API long THFile_readShort(THFile *self, THShortStorage *storage);
TH_API long THFile_readInt(THFile *self, THIntStorage *storage);
TH_API long THFile_readLong(THFile *self, THLongStorage *storage);
TH_API long THFile_readFloat(THFile *self, THFloatStorage *storage);
TH_API long THFile_readDouble(THFile *self, THDoubleStorage *storage);
TH_API long THFile_writeByte(THFile *self, THByteStorage *storage);
TH_API long THFile_writeChar(THFile *self, THCharStorage *storage);
TH_API long THFile_writeShort(THFile *self, THShortStorage *storage);
TH_API long THFile_writeInt(THFile *self, THIntStorage *storage);
TH_API long THFile_writeLong(THFile *self, THLongStorage *storage);
TH_API long THFile_writeFloat(THFile *self, THFloatStorage *storage);
TH_API long THFile_writeDouble(THFile *self, THDoubleStorage *storage);
*/
/* raw */
TH_API long THFile_readByteRaw(THFile *self, unsigned char *data, long n);
TH_API long THFile_readCharRaw(THFile *self, char *data, long n);
TH_API long THFile_readShortRaw(THFile *self, short *data, long n);
TH_API long THFile_readIntRaw(THFile *self, int *data, long n);
TH_API long THFile_readLongRaw(THFile *self, long *data, long n);
TH_API long THFile_readFloatRaw(THFile *self, float *data, long n);
TH_API long THFile_readDoubleRaw(THFile *self, double *data, long n);
TH_API long THFile_readStringRaw(THFile *self, const char *format, char **str_); /* you must deallocate str_ */
TH_API long THFile_writeByteRaw(THFile *self, unsigned char *data, long n);
TH_API long THFile_writeCharRaw(THFile *self, char *data, long n);
TH_API long THFile_writeShortRaw(THFile *self, short *data, long n);
TH_API long THFile_writeIntRaw(THFile *self, int *data, long n);
TH_API long THFile_writeLongRaw(THFile *self, long *data, long n);
TH_API long THFile_writeFloatRaw(THFile *self, float *data, long n);
TH_API long THFile_writeDoubleRaw(THFile *self, double *data, long n);
TH_API long THFile_writeStringRaw(THFile *self, const char *str, long size);
TH_API void THFile_synchronize(THFile *self);
TH_API void THFile_seek(THFile *self, long position);
TH_API void THFile_seekEnd(THFile *self);
TH_API long THFile_position(THFile *self);
TH_API void THFile_close(THFile *self);
TH_API void THFile_free(THFile *self);
#endif

@ -0,0 +1,43 @@
struct THFile__
{
struct THFileVTable *vtable;
int isQuiet;
int isReadable;
int isWritable;
int isBinary;
int isAutoSpacing;
int hasError;
};
/* virtual table definition */
struct THFileVTable
{
int (*isOpened)(THFile *self);
long (*readByte)(THFile *self, unsigned char *data, long n);
long (*readChar)(THFile *self, char *data, long n);
long (*readShort)(THFile *self, short *data, long n);
long (*readInt)(THFile *self, int *data, long n);
long (*readLong)(THFile *self, long *data, long n);
long (*readFloat)(THFile *self, float *data, long n);
long (*readDouble)(THFile *self, double *data, long n);
long (*readString)(THFile *self, const char *format, char **str_);
long (*writeByte)(THFile *self, unsigned char *data, long n);
long (*writeChar)(THFile *self, char *data, long n);
long (*writeShort)(THFile *self, short *data, long n);
long (*writeInt)(THFile *self, int *data, long n);
long (*writeLong)(THFile *self, long *data, long n);
long (*writeFloat)(THFile *self, float *data, long n);
long (*writeDouble)(THFile *self, double *data, long n);
long (*writeString)(THFile *self, const char *str, long size);
void (*synchronize)(THFile *self);
void (*seek)(THFile *self, long position);
void (*seekEnd)(THFile *self);
long (*position)(THFile *self);
void (*close)(THFile *self);
void (*free)(THFile *self);
};

@ -0,0 +1,254 @@
#include "THGeneral.h"
extern "C"
{
#ifndef TH_HAVE_THREAD
#define __thread
#endif
#if defined(TH_DISABLE_HEAP_TRACKING)
#elif (defined(__unix) || defined(_WIN32))
#include <malloc.h>
#elif defined(__APPLE__)
#include <malloc/malloc.h>
#endif
/* Torch Error Handling */
static void defaultTorchErrorHandlerFunction(const char *msg, void *data)
{
printf("$ Error: %s\n", msg);
exit(-1);
}
static __thread void (*torchErrorHandlerFunction)(const char *msg, void *data) = defaultTorchErrorHandlerFunction;
static __thread void *torchErrorHandlerData;
void _THError(const char *file, const int line, const char *fmt, ...)
{
char msg[2048];
va_list args;
/* vasprintf not standard */
/* vsnprintf: how to handle if does not exists? */
va_start(args, fmt);
int n = vsnprintf(msg, 2048, fmt, args);
va_end(args);
if(n < 2048) {
snprintf(msg + n, 2048 - n, " at %s:%d", file, line);
}
(*torchErrorHandlerFunction)(msg, torchErrorHandlerData);
}
void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...) {
char msg[1024];
va_list args;
va_start(args, fmt);
vsnprintf(msg, 1024, fmt, args);
va_end(args);
_THError(file, line, "Assertion `%s' failed. %s", exp, msg);
}
void THSetErrorHandler( void (*torchErrorHandlerFunction_)(const char *msg, void *data), void *data )
{
if(torchErrorHandlerFunction_)
torchErrorHandlerFunction = torchErrorHandlerFunction_;
else
torchErrorHandlerFunction = defaultTorchErrorHandlerFunction;
torchErrorHandlerData = data;
}
/* Torch Arg Checking Handling */
static void defaultTorchArgErrorHandlerFunction(int argNumber, const char *msg, void *data)
{
if(msg)
printf("$ Invalid argument %d: %s\n", argNumber, msg);
else
printf("$ Invalid argument %d\n", argNumber);
exit(-1);
}
static __thread void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data) = defaultTorchArgErrorHandlerFunction;
static __thread void *torchArgErrorHandlerData;
void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...)
{
if(!condition) {
char msg[2048];
va_list args;
/* vasprintf not standard */
/* vsnprintf: how to handle if does not exists? */
va_start(args, fmt);
int n = vsnprintf(msg, 2048, fmt, args);
va_end(args);
if(n < 2048) {
snprintf(msg + n, 2048 - n, " at %s:%d", file, line);
}
(*torchArgErrorHandlerFunction)(argNumber, msg, torchArgErrorHandlerData);
}
}
void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction_)(int argNumber, const char *msg, void *data), void *data )
{
if(torchArgErrorHandlerFunction_)
torchArgErrorHandlerFunction = torchArgErrorHandlerFunction_;
else
torchArgErrorHandlerFunction = defaultTorchArgErrorHandlerFunction;
torchArgErrorHandlerData = data;
}
static __thread void (*torchGCFunction)(void *data) = NULL;
static __thread void *torchGCData;
static __thread long torchHeapSize = 0;
static __thread long torchHeapSizeSoftMax = 300000000; // 300MB, adjusted upward dynamically
/* Optional hook for integrating with a garbage-collected frontend.
*
* If torch is running with a garbage-collected frontend (e.g. Lua),
* the GC isn't aware of TH-allocated memory so may not know when it
* needs to run. These hooks trigger the GC to run in two cases:
*
* (1) When a memory allocation (malloc, realloc, ...) fails
* (2) When the total TH-allocated memory hits a dynamically-adjusted
* soft maximum.
*/
void THSetGCHandler( void (*torchGCFunction_)(void *data), void *data )
{
torchGCFunction = torchGCFunction_;
torchGCData = data;
}
static long getAllocSize(void *ptr) {
#if defined(TH_DISABLE_HEAP_TRACKING)
return 0;
#elif defined(__unix)
return malloc_usable_size(ptr);
#elif defined(__APPLE__)
return malloc_size(ptr);
#elif defined(_WIN32)
return _msize(ptr);
#else
return 0;
#endif
}
/* (1) if the torch-allocated heap size exceeds the soft max, run GC
* (2) if post-GC heap size exceeds 80% of the soft max, increase the
* soft max by 40%
*/
static void maybeTriggerGC() {
if(torchGCFunction && torchHeapSize > torchHeapSizeSoftMax) {
torchGCFunction(torchGCData);
if(torchHeapSize > torchHeapSizeSoftMax * 0.8) {
torchHeapSizeSoftMax = torchHeapSizeSoftMax * 1.4;
}
}
}
// hooks into the TH heap tracking
void THHeapUpdate(long size) {
torchHeapSize += size;
if (size > 0)
maybeTriggerGC();
}
static void* THAllocInternal(long size)
{
void *ptr;
if (size > 5120)
{
#if (defined(__unix) || defined(__APPLE__)) && (!defined(DISABLE_POSIX_MEMALIGN))
if (posix_memalign(&ptr, 64, size) != 0)
ptr = NULL;
/*
#elif defined(_WIN32)
ptr = _aligned_malloc(size, 64);
*/
#else
ptr = malloc(size);
#endif
}
else
{
ptr = malloc(size);
}
THHeapUpdate(getAllocSize(ptr));
return ptr;
}
void* THAlloc(long size)
{
void *ptr;
if(size < 0)
THError("$ Torch: invalid memory size -- maybe an overflow?");
if(size == 0)
return NULL;
ptr = THAllocInternal(size);
if(!ptr && torchGCFunction) {
torchGCFunction(torchGCData);
ptr = THAllocInternal(size);
}
if(!ptr)
THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", size/1073741824);
return ptr;
}
void* THRealloc(void *ptr, long size)
{
if(!ptr)
return(THAlloc(size));
if(size == 0)
{
THFree(ptr);
return NULL;
}
if(size < 0)
THError("$ Torch: invalid memory size -- maybe an overflow?");
THHeapUpdate(-getAllocSize(ptr));
void *newptr = realloc(ptr, size);
if(!newptr && torchGCFunction) {
torchGCFunction(torchGCData);
newptr = realloc(ptr, size);
}
THHeapUpdate(getAllocSize(newptr ? newptr : ptr));
if(!newptr)
THError("$ Torch: not enough memory: you tried to reallocate %dGB. Buy new RAM!", size/1073741824);
return newptr;
}
void THFree(void *ptr)
{
THHeapUpdate(-getAllocSize(ptr));
free(ptr);
}
double THLog1p(const double x)
{
#if (defined(_MSC_VER) || defined(__MINGW32__))
volatile double y = 1 + x;
return log(y) - ((y-1)-x)/y ; /* cancels errors with IEEE arithmetic */
#else
return log1p(x);
#endif
}
}

@ -0,0 +1,89 @@
#ifndef TH_GENERAL_INC
#define TH_GENERAL_INC
#include <stdlib.h>
#include <stdio.h>
#include <stdarg.h>
#include <math.h>
#include <limits.h>
#include <float.h>
#include <time.h>
#include <string.h>
#ifdef __cplusplus
# define TH_EXTERNC extern "C"
#else
# define TH_EXTERNC extern
#endif
#define TH_API TH_EXTERNC
#define THInf DBL_MAX
//#define TH_INLINE @TH_INLINE@
#ifndef __cplusplus
//#define inline @TH_INLINE@
#endif
#ifndef M_PI
# define M_PI 3.14159265358979323846
#endif
TH_API double THLog1p(const double x);
TH_API void _THError(const char *file, const int line, const char *fmt, ...);
TH_API void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...);
TH_API void THSetErrorHandler( void (*torchErrorHandlerFunction)(const char *msg, void *data), void *data );
TH_API void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...);
TH_API void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data), void *data );
TH_API void* THAlloc(long size);
TH_API void* THRealloc(void *ptr, long size);
TH_API void THFree(void *ptr);
TH_API void THSetGCHandler( void (*torchGCHandlerFunction)(void *data), void *data );
// this hook should only be called by custom allocator functions
TH_API void THHeapUpdate(long size);
#define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__)
#define THArgCheck(...) _THArgCheck(__FILE__, __LINE__, __VA_ARGS__)
#define THAssert(exp) \
do { \
if (!(exp)) { \
_THAssertionFailed(__FILE__, __LINE__, #exp, ""); \
} \
} while(0)
#define THAssertMsg(exp, ...) \
do { \
if (!(exp)) { \
_THAssertionFailed(__FILE__, __LINE__, #exp, __VA_ARGS__); \
} \
} while(0)
#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y)
#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y
#define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z)
#define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z
#define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w)
#define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w
#define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y)
#define TH_CONCAT_2_EXPAND(x,y) x ## y
#define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z)
#define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z
#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w
#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w)
#define THMin(X, Y) ((X) < (Y) ? (X) : (Y))
#define THMax(X, Y) ((X) > (Y) ? (X) : (Y))
#if (defined(_MSC_VER) || defined(__MINGW32__))
# define log1p(x) THLog1p(x)
#define snprintf _snprintf
#define popen _popen
#define pclose _pclose
#endif
#endif

@ -0,0 +1,317 @@
#include "../precomp.hpp"
#include <limits>
#include <set>
#include <map>
#include <algorithm>
#include <iostream>
namespace cv {
namespace dnn {
#if ENABLE_TORCH_IMPORTER || 1
#include "THDiskFile.h"
enum LuaType
{
TYPE_NIL = 0,
TYPE_NUMBER = 1,
TYPE_STRING = 2,
TYPE_TABLE = 3,
TYPE_TORCH = 4,
TYPE_BOOLEAN = 5,
TYPE_FUNCTION = 6,
TYPE_RECUR_FUNCTION = 8,
LEGACY_TYPE_RECUR_FUNCTION = 7
};
template<typename T>
static String toString(const T &v)
{
std::ostringstream ss;
ss << v;
return ss.str();
}
static inline bool startsWith(const String &str, const char *substr)
{
return str.find(substr) == 0;
}
static inline bool endsWith(const String &str, const char *substr)
{
return str.rfind(substr) == str.length() - strlen(substr);
}
struct TorchImporter : public ::cv::dnn::Importer
{
THFile *file;
std::set<int> readedIndexes;
std::map<int, Mat> storages;
TorchImporter(String filename, bool isBinary)
{
file = THDiskFile_new(filename.c_str(), "r", 0);
CV_Assert(file && THFile_isOpened(file));
if (isBinary)
THFile_binary(file);
else
THFile_ascii(file);
}
/* Simple readers */
inline int readInt()
{
return THFile_readIntScalar(file);
}
inline long readLong()
{
return THFile_readLongScalar(file);
}
inline bool readBool()
{
return readInt();
}
inline double readDouble()
{
return THFile_readDoubleScalar(file);
}
inline String readString()
{
int size = THFile_readIntScalar(file);
String str(size, '\0');
THFile_readCharRaw(file, const_cast<char*>(str.c_str()), size);
return str;
}
inline String readTorchClassName()
{
String version = readString();
return startsWith(version, "V ") ? readString() : version;
}
inline void readFunction()
{
readString();
readObject(true);
}
void readTable()
{
std::cout << "Skipping table\n";
int index = readInt();
CV_Assert(readedIndexes.count(index) == 0);
readedIndexes.insert(index);
int size = readInt();
for (int i = 0; i < size; i++)
{
readObject(true); //key
readObject(true); //value
}
}
/* Special readers */
static inline int parseTorchType(const String &str, const char *suffix, const char *prefix = "torch.")
{
if (startsWith(str, prefix) && endsWith(str, suffix))
{
String typeStr = str.substr(strlen(prefix), str.length() - strlen(prefix) - strlen(suffix));
if (typeStr == "Double")
return CV_64F;
else if (typeStr == "Float")
return CV_32F;
else if (typeStr == "Byte")
return CV_8U;
else if (typeStr == "Char")
return CV_8S;
else if (typeStr == "Short")
return CV_16S;
else if (typeStr == "Int")
return CV_32S;
else
CV_Error(Error::StsNotImplemented, "Unknown type \"" + typeStr + "\" of torch class \"" + str + "\"");
}
return -1;
}
static int parseTensorType(const String &className)
{
return parseTorchType(className, "Tensor");
}
static int parseStorageType(const String &className)
{
return parseTorchType(className, "Storage");
}
void readTorchStorage(int index, int type = -1)
{
long size = readLong();
Mat storageMat(1, size, type);
THFile_readByteRaw(file, storageMat.data, size * CV_ELEM_SIZE(type));
storages.insert(std::make_pair(index, storageMat));
readedIndexes.insert(index);
}
Blob readTorchTensor(int typeTensor, bool skip = false)
{
int ndims = readInt();
AutoBuffer<long, 4> sizes(ndims);
AutoBuffer<long, 4> steps(ndims);
THFile_readLongRaw(file, sizes, ndims);
THFile_readLongRaw(file, sizes, ndims);
long offset = readLong() - 1;
//read Storage
int typeidx = readInt();
std::cout << "stograge typeidx of tensor: " << typeidx << "\n";
CV_Assert(typeidx == TYPE_TORCH || (typeidx == TYPE_NIL && ndims == 0));
if (typeidx == TYPE_NIL)
return Blob();
int index = readInt();
if (readedIndexes.count(index) == 0)
{
int typeStorage = parseStorageType(readTorchClassName());
CV_Assert(typeStorage >= 0 && typeTensor == typeStorage);
readTorchStorage(typeStorage, index);
}
//allocate Blob
AutoBuffer<int, 4> isizes(ndims);
AutoBuffer<size_t, 4> ssteps(ndims);
size_t stepExpected = 1;
for (int i = ndims - 1; i >= 0; i--)
{
isizes[i] = (int)sizes[i];
ssteps[i] = (size_t)steps[i] * CV_ELEM_SIZE(typeTensor);
stepExpected *= sizes[i];
}
if (skip)
return Blob();
Mat srcMat(ndims, (int*)isizes, typeTensor , storages[index].ptr(), (size_t*)ssteps);
int dstType = (typeTensor == CV_64F) ? CV_64F : CV_32F;
Blob blob;
blob.create(BlobShape(ndims, isizes), dstType);
srcMat.convertTo(blob.getMatRef(), dstType);
return blob;
}
void readTorchObject(int index, bool skip = false)
{
String className = readTorchClassName();
std::cout << "Class: " << className << std::endl;
int type;
if ( (type = parseTensorType(className)) >= 0 ) //is Tensor
{
readTorchTensor(type);
return;
}
else if ( (type = parseStorageType(className)) >= 0 ) //is Storage
{
readTorchStorage(index, type);
}
else if (className == "nn.Sequential")
{
readObject();
}
else if (className == "nn.Concat")
{
readObject();
}
else if (className == "nn.SpatialConvolution")
{
readObject();
}
else if (className == "nn.ReLU")
{
readObject();
}
else
{
CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className +"\"");
}
}
void readObject(bool skip = false)
{
int typeidx = readInt();
std::cout << "typeidx: " << typeidx << "\n";
if (typeidx == TYPE_TORCH)
{
int index = readInt();
if (readedIndexes.count(index) == 0)
{
readTorchObject(index, skip);
readedIndexes.insert(index);
}
else
{
//CV_Error(Error::StsNotImplemented, "");
//TBD
}
}
else if (typeidx == TYPE_NIL)
return;
else if (typeidx == TYPE_NUMBER)
readDouble();
else if (typeidx == TYPE_BOOLEAN)
readBool();
else if (typeidx == TYPE_STRING)
readString();
else if (typeidx == TYPE_TABLE)
readTable();
else
CV_Error(Error::StsNotImplemented, "Unsupported Lua type");
}
void populateNet(Net net)
{
THFile_seek(file, 0);
readedIndexes.clear();
readObject();
}
};
CV_EXPORTS Ptr<Importer> createTorchImporter(const String &filename, bool isBinary)
{
return Ptr<Importer>(new TorchImporter(filename, isBinary));
}
#else //ENABLE_TORCH_IMPORTER
CV_EXPORTS Ptr<Importer> createTorchImporter(const String&, bool)
{
CV_Error(Error::StsNotImplemented, "Module was build without Torch importer");
return Ptr<Importer>();
}
#endif //ENABLE_TORCH_IMPORTER
}
}

@ -0,0 +1,35 @@
#if 1 || defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER
#include "test_precomp.hpp"
namespace cvtest
{
using namespace std;
using namespace testing;
using namespace cv;
using namespace cv::dnn;
static std::string getOpenCVExtraDir()
{
return cvtest::TS::ptr()->get_data_path();
}
template<typename TStr>
static std::string getTestFile(TStr filename)
{
return (getOpenCVExtraDir() + "/dnn/") + filename;
}
TEST(Torch_Importer, simple_read)
{
Net net;
Ptr<Importer> importer;
ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) );
ASSERT_TRUE( importer != NULL );
importer->populateNet(net);
//ASSERT_NO_THROW( importer->populateNet(net) );
}
}
#endif
Loading…
Cancel
Save