|
|
@ -69,7 +69,7 @@ static void tanh(const Mat &src, Mat &dst) |
|
|
|
else if (src.type() == CV_64F) |
|
|
|
else if (src.type() == CV_64F) |
|
|
|
tanh<double>(src, dst); |
|
|
|
tanh<double>(src, dst); |
|
|
|
else |
|
|
|
else |
|
|
|
CV_Error(Error::StsUnsupportedFormat, "Functions supports only floating point types"); |
|
|
|
CV_Error(Error::StsUnsupportedFormat, "Function supports only floating point types"); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
static void sigmoid(const Mat &src, Mat &dst) |
|
|
|
static void sigmoid(const Mat &src, Mat &dst) |
|
|
@ -86,6 +86,10 @@ class LSTMLayerImpl : public LSTMLayer |
|
|
|
int dtype; |
|
|
|
int dtype; |
|
|
|
bool allocated; |
|
|
|
bool allocated; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
BlobShape outTailShape; //shape of single output sample
|
|
|
|
|
|
|
|
BlobShape outTsMatShape, outTsShape; //shape of N output samples
|
|
|
|
|
|
|
|
BlobShape outResShape; //shape of T timestamps and N output samples
|
|
|
|
|
|
|
|
|
|
|
|
bool useTimestampDim; |
|
|
|
bool useTimestampDim; |
|
|
|
bool produceCellOutput; |
|
|
|
bool produceCellOutput; |
|
|
|
|
|
|
|
|
|
|
@ -97,6 +101,7 @@ public: |
|
|
|
useTimestampDim = true; |
|
|
|
useTimestampDim = true; |
|
|
|
produceCellOutput = false; |
|
|
|
produceCellOutput = false; |
|
|
|
allocated = false; |
|
|
|
allocated = false; |
|
|
|
|
|
|
|
outTailShape = BlobShape::empty(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void setUseTimstampsDim(bool use) |
|
|
|
void setUseTimstampsDim(bool use) |
|
|
@ -113,14 +118,20 @@ public: |
|
|
|
|
|
|
|
|
|
|
|
void setC(const Blob &C) |
|
|
|
void setC(const Blob &C) |
|
|
|
{ |
|
|
|
{ |
|
|
|
CV_Assert(!allocated || C.total() == cInternal.total()); |
|
|
|
CV_Assert(cInternal.empty() || C.total() == cInternal.total()); |
|
|
|
C.matRefConst().copyTo(cInternal); |
|
|
|
if (!cInternal.empty()) |
|
|
|
|
|
|
|
C.reshaped(BlobShape::like(cInternal)).matRefConst().copyTo(cInternal); |
|
|
|
|
|
|
|
else |
|
|
|
|
|
|
|
C.matRefConst().copyTo(cInternal); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void setH(const Blob &H) |
|
|
|
void setH(const Blob &H) |
|
|
|
{ |
|
|
|
{ |
|
|
|
CV_Assert(!allocated || H.total() == hInternal.total()); |
|
|
|
CV_Assert(hInternal.empty() || H.total() == hInternal.total()); |
|
|
|
H.matRefConst().copyTo(hInternal); |
|
|
|
if (!hInternal.empty()) |
|
|
|
|
|
|
|
H.reshaped(BlobShape::like(hInternal)).matRefConst().copyTo(hInternal); |
|
|
|
|
|
|
|
else |
|
|
|
|
|
|
|
H.matRefConst().copyTo(hInternal); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Blob getC() const |
|
|
|
Blob getC() const |
|
|
@ -128,8 +139,8 @@ public: |
|
|
|
CV_Assert(!cInternal.empty()); |
|
|
|
CV_Assert(!cInternal.empty()); |
|
|
|
|
|
|
|
|
|
|
|
//TODO: add convinient Mat -> Blob constructor
|
|
|
|
//TODO: add convinient Mat -> Blob constructor
|
|
|
|
Blob res; |
|
|
|
Blob res(outTsShape, cInternal.type()); |
|
|
|
res.fill(BlobShape::like(cInternal), cInternal.type(), cInternal.data); |
|
|
|
res.fill(res.shape(), res.type(), cInternal.data); |
|
|
|
return res; |
|
|
|
return res; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
@ -137,11 +148,17 @@ public: |
|
|
|
{ |
|
|
|
{ |
|
|
|
CV_Assert(!hInternal.empty()); |
|
|
|
CV_Assert(!hInternal.empty()); |
|
|
|
|
|
|
|
|
|
|
|
Blob res; |
|
|
|
Blob res(outTsShape, hInternal.type()); |
|
|
|
res.fill(BlobShape::like(hInternal), hInternal.type(), hInternal.data); |
|
|
|
res.fill(res.shape(), res.type(), hInternal.data); |
|
|
|
return res; |
|
|
|
return res; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void setOutShape(const BlobShape &outTailShape_) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
CV_Assert(!allocated || outTailShape_.total() == outTailShape.total()); |
|
|
|
|
|
|
|
outTailShape = outTailShape_; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void setWeights(const Blob &Wh, const Blob &Wx, const Blob &bias) |
|
|
|
void setWeights(const Blob &Wh, const Blob &Wx, const Blob &bias) |
|
|
|
{ |
|
|
|
{ |
|
|
|
CV_Assert(Wh.dims() == 2 && Wx.dims() == 2); |
|
|
|
CV_Assert(Wh.dims() == 2 && Wx.dims() == 2); |
|
|
@ -160,31 +177,64 @@ public: |
|
|
|
void allocate(const std::vector<Blob*> &input, std::vector<Blob> &output) |
|
|
|
void allocate(const std::vector<Blob*> &input, std::vector<Blob> &output) |
|
|
|
{ |
|
|
|
{ |
|
|
|
CV_Assert(blobs.size() == 3); |
|
|
|
CV_Assert(blobs.size() == 3); |
|
|
|
Blob &Wh = blobs[0], &Wx = blobs[1]; |
|
|
|
CV_Assert(input.size() == 1); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Blob &Wh = blobs[0], &Wx = blobs[1]; |
|
|
|
numOut = Wh.size(1); |
|
|
|
numOut = Wh.size(1); |
|
|
|
numInp = Wx.size(1); |
|
|
|
numInp = Wx.size(1); |
|
|
|
|
|
|
|
|
|
|
|
CV_Assert(input.size() == 1); |
|
|
|
if (!outTailShape.isEmpty()) |
|
|
|
CV_Assert(input[0]->dims() > 2 && (int)input[0]->total(2) == numInp); |
|
|
|
CV_Assert(outTailShape.total() == numOut); |
|
|
|
|
|
|
|
else |
|
|
|
|
|
|
|
outTailShape = BlobShape(numOut); |
|
|
|
|
|
|
|
|
|
|
|
numTimeStamps = input[0]->size(0); |
|
|
|
if (useTimestampDim) |
|
|
|
numSamples = input[0]->size(1); |
|
|
|
{ |
|
|
|
dtype = input[0]->type(); |
|
|
|
CV_Assert(input[0]->dims() >= 2 && (int)input[0]->total(2) == numInp); |
|
|
|
|
|
|
|
numTimeStamps = input[0]->size(0); |
|
|
|
|
|
|
|
numSamples = input[0]->size(1); |
|
|
|
|
|
|
|
outResShape = BlobShape(numTimeStamps, numSamples) + outTailShape; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
else |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
CV_Assert(input[0]->dims() >= 1 && (int)input[0]->total(1) == numInp); |
|
|
|
|
|
|
|
numTimeStamps = 1; |
|
|
|
|
|
|
|
numSamples = input[0]->size(0); |
|
|
|
|
|
|
|
outResShape = BlobShape(numSamples) + outTailShape; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
outTsMatShape = BlobShape(numSamples, numOut); |
|
|
|
|
|
|
|
outTsShape = BlobShape(numSamples) + outTailShape; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dtype = input[0]->type(); |
|
|
|
CV_Assert(dtype == CV_32F || dtype == CV_64F); |
|
|
|
CV_Assert(dtype == CV_32F || dtype == CV_64F); |
|
|
|
CV_Assert(Wh.type() == dtype); |
|
|
|
CV_Assert(Wh.type() == dtype); |
|
|
|
|
|
|
|
|
|
|
|
BlobShape outShape(numTimeStamps, numSamples, numOut); |
|
|
|
output.resize( (produceCellOutput) ? 2 : 1 ); |
|
|
|
output.resize(2); |
|
|
|
output[0].create(outResShape, dtype); |
|
|
|
output[0].create(outShape, dtype); |
|
|
|
if (produceCellOutput) |
|
|
|
output[1].create(outShape, dtype); |
|
|
|
output[1].create(outResShape, dtype); |
|
|
|
|
|
|
|
|
|
|
|
hInternal.create(numSamples, numOut, dtype); |
|
|
|
if (hInternal.empty()) |
|
|
|
hInternal.setTo(0); |
|
|
|
{ |
|
|
|
|
|
|
|
hInternal.create(outTsMatShape.dims(), outTsMatShape.ptr(), dtype); |
|
|
|
|
|
|
|
hInternal.setTo(0); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
else |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
CV_Assert((int)hInternal.total() == numSamples*numOut); |
|
|
|
|
|
|
|
hInternal = hInternal.reshape(1, outTsMatShape.dims(), outTsMatShape.ptr()); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
cInternal.create(numSamples, numOut, dtype); |
|
|
|
if (cInternal.empty()) |
|
|
|
cInternal.setTo(0); |
|
|
|
{ |
|
|
|
|
|
|
|
cInternal.create(outTsMatShape.dims(), outTsMatShape.ptr(), dtype); |
|
|
|
|
|
|
|
cInternal.setTo(0); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
else |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
CV_Assert((int)cInternal.total() == numSamples*numOut); |
|
|
|
|
|
|
|
cInternal = cInternal.reshape(1, outTsMatShape.dims(), outTsMatShape.ptr()); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
gates.create(numSamples, 4*numOut, dtype); |
|
|
|
gates.create(numSamples, 4*numOut, dtype); |
|
|
|
|
|
|
|
|
|
|
@ -252,6 +302,22 @@ void LSTMLayer::forward(std::vector<Blob*>&, std::vector<Blob>&) |
|
|
|
CV_Error(Error::StsInternal, "This function should be unreached"); |
|
|
|
CV_Error(Error::StsInternal, "This function should be unreached"); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int LSTMLayer::inputNameToIndex(String inputName) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
if (inputName.toLowerCase() == "x") |
|
|
|
|
|
|
|
return 0; |
|
|
|
|
|
|
|
return -1; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int LSTMLayer::outputNameToIndex(String outputName) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
if (outputName.toLowerCase() == "h") |
|
|
|
|
|
|
|
return 0; |
|
|
|
|
|
|
|
else if (outputName.toLowerCase() == "c") |
|
|
|
|
|
|
|
return 1; |
|
|
|
|
|
|
|
return -1; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RNNLayerImpl : public RNNLayer |
|
|
|
class RNNLayerImpl : public RNNLayer |
|
|
|
{ |
|
|
|
{ |
|
|
|