|
|
@ -93,6 +93,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer |
|
|
|
float forgetBias, cellClip; |
|
|
|
float forgetBias, cellClip; |
|
|
|
bool useCellClip, usePeephole; |
|
|
|
bool useCellClip, usePeephole; |
|
|
|
bool reverse; // If true, go in negative direction along the time axis
|
|
|
|
bool reverse; // If true, go in negative direction along the time axis
|
|
|
|
|
|
|
|
bool bidirectional; // If true, produces both forward and reversed directions along time axis
|
|
|
|
|
|
|
|
|
|
|
|
public: |
|
|
|
public: |
|
|
|
|
|
|
|
|
|
|
@ -101,6 +102,7 @@ public: |
|
|
|
{ |
|
|
|
{ |
|
|
|
setParamsFrom(params); |
|
|
|
setParamsFrom(params); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bidirectional = params.get<bool>("bidirectional", false); |
|
|
|
if (!blobs.empty()) |
|
|
|
if (!blobs.empty()) |
|
|
|
{ |
|
|
|
{ |
|
|
|
CV_Assert(blobs.size() >= 3); |
|
|
|
CV_Assert(blobs.size() >= 3); |
|
|
@ -113,7 +115,7 @@ public: |
|
|
|
CV_CheckEQ(Wh.dims, 2, ""); |
|
|
|
CV_CheckEQ(Wh.dims, 2, ""); |
|
|
|
CV_CheckEQ(Wx.dims, 2, ""); |
|
|
|
CV_CheckEQ(Wx.dims, 2, ""); |
|
|
|
CV_CheckEQ(Wh.rows, Wx.rows, ""); |
|
|
|
CV_CheckEQ(Wh.rows, Wx.rows, ""); |
|
|
|
CV_CheckEQ(Wh.rows, 4*Wh.cols, ""); |
|
|
|
CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, ""); |
|
|
|
CV_CheckEQ(Wh.rows, (int)bias.total(), ""); |
|
|
|
CV_CheckEQ(Wh.rows, (int)bias.total(), ""); |
|
|
|
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type()); |
|
|
|
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type()); |
|
|
|
|
|
|
|
|
|
|
@ -136,6 +138,7 @@ public: |
|
|
|
useCellClip = params.get<bool>("use_cell_clip", false); |
|
|
|
useCellClip = params.get<bool>("use_cell_clip", false); |
|
|
|
usePeephole = params.get<bool>("use_peephole", false); |
|
|
|
usePeephole = params.get<bool>("use_peephole", false); |
|
|
|
reverse = params.get<bool>("reverse", false); |
|
|
|
reverse = params.get<bool>("reverse", false); |
|
|
|
|
|
|
|
CV_Assert(!reverse || !bidirectional); |
|
|
|
|
|
|
|
|
|
|
|
allocated = false; |
|
|
|
allocated = false; |
|
|
|
outTailShape.clear(); |
|
|
|
outTailShape.clear(); |
|
|
@ -207,6 +210,7 @@ public: |
|
|
|
|
|
|
|
|
|
|
|
outResShape.push_back(_numSamples); |
|
|
|
outResShape.push_back(_numSamples); |
|
|
|
outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end()); |
|
|
|
outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end()); |
|
|
|
|
|
|
|
outResShape.back() *= (1 + static_cast<int>(bidirectional)); |
|
|
|
|
|
|
|
|
|
|
|
size_t noutputs = produceCellOutput ? 2 : 1; |
|
|
|
size_t noutputs = produceCellOutput ? 2 : 1; |
|
|
|
outputs.assign(noutputs, outResShape); |
|
|
|
outputs.assign(noutputs, outResShape); |
|
|
@ -253,6 +257,7 @@ public: |
|
|
|
outTsShape.clear(); |
|
|
|
outTsShape.clear(); |
|
|
|
outTsShape.push_back(numSamples); |
|
|
|
outTsShape.push_back(numSamples); |
|
|
|
outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end()); |
|
|
|
outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end()); |
|
|
|
|
|
|
|
outTsShape.back() *= (1 + static_cast<int>(bidirectional)); |
|
|
|
|
|
|
|
|
|
|
|
allocated = true; |
|
|
|
allocated = true; |
|
|
|
} |
|
|
|
} |
|
|
@ -273,91 +278,96 @@ public: |
|
|
|
outputs_arr.getMatVector(output); |
|
|
|
outputs_arr.getMatVector(output); |
|
|
|
internals_arr.getMatVector(internals); |
|
|
|
internals_arr.getMatVector(internals); |
|
|
|
|
|
|
|
|
|
|
|
const Mat &Wh = blobs[0]; |
|
|
|
const int numDirs = 1 + static_cast<int>(bidirectional); |
|
|
|
const Mat &Wx = blobs[1]; |
|
|
|
for (int i = 0; i < numDirs; ++i) |
|
|
|
const Mat &bias = blobs[2]; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int numOut = Wh.size[1]; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Mat hInternal = internals[0], cInternal = internals[1], |
|
|
|
|
|
|
|
dummyOnes = internals[2], gates = internals[3]; |
|
|
|
|
|
|
|
hInternal.setTo(0.); |
|
|
|
|
|
|
|
cInternal.setTo(0.); |
|
|
|
|
|
|
|
dummyOnes.setTo(1.); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int numSamplesTotal = numTimeStamps*numSamples; |
|
|
|
|
|
|
|
Mat xTs = input[0].reshape(1, numSamplesTotal); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Mat hOutTs = output[0].reshape(1, numSamplesTotal); |
|
|
|
|
|
|
|
Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int tsStart, tsEnd, tsInc; |
|
|
|
|
|
|
|
if (reverse) { |
|
|
|
|
|
|
|
tsStart = numTimeStamps - 1; |
|
|
|
|
|
|
|
tsEnd = -1; |
|
|
|
|
|
|
|
tsInc = -1; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
else { |
|
|
|
|
|
|
|
tsStart = 0; |
|
|
|
|
|
|
|
tsEnd = numTimeStamps; |
|
|
|
|
|
|
|
tsInc = 1; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
for (int ts = tsStart; ts != tsEnd; ts += tsInc) |
|
|
|
|
|
|
|
{ |
|
|
|
{ |
|
|
|
Range curRowRange(ts*numSamples, (ts + 1)*numSamples); |
|
|
|
const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs); |
|
|
|
Mat xCurr = xTs.rowRange(curRowRange); |
|
|
|
const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs); |
|
|
|
|
|
|
|
const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int numOut = Wh.size[1]; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Mat hInternal = internals[0], cInternal = internals[1], |
|
|
|
|
|
|
|
dummyOnes = internals[2], gates = internals[3]; |
|
|
|
|
|
|
|
hInternal.setTo(0.); |
|
|
|
|
|
|
|
cInternal.setTo(0.); |
|
|
|
|
|
|
|
dummyOnes.setTo(1.); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int numSamplesTotal = numTimeStamps*numSamples; |
|
|
|
|
|
|
|
Mat xTs = input[0].reshape(1, numSamplesTotal); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Mat hOutTs = output[0].reshape(1, numSamplesTotal); |
|
|
|
|
|
|
|
hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs); |
|
|
|
|
|
|
|
Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int tsStart, tsEnd, tsInc; |
|
|
|
|
|
|
|
if (reverse || i == 1) { |
|
|
|
|
|
|
|
tsStart = numTimeStamps - 1; |
|
|
|
|
|
|
|
tsEnd = -1; |
|
|
|
|
|
|
|
tsInc = -1; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
else { |
|
|
|
|
|
|
|
tsStart = 0; |
|
|
|
|
|
|
|
tsEnd = numTimeStamps; |
|
|
|
|
|
|
|
tsInc = 1; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
for (int ts = tsStart; ts != tsEnd; ts += tsInc) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
Range curRowRange(ts*numSamples, (ts + 1)*numSamples); |
|
|
|
|
|
|
|
Mat xCurr = xTs.rowRange(curRowRange); |
|
|
|
|
|
|
|
|
|
|
|
gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T); // Wx * x_t
|
|
|
|
gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T); // Wx * x_t
|
|
|
|
gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T); //+Wh * h_{t-1}
|
|
|
|
gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T); //+Wh * h_{t-1}
|
|
|
|
gemm(dummyOnes, bias, 1, gates, 1, gates); //+b
|
|
|
|
gemm(dummyOnes, bias, 1, gates, 1, gates); //+b
|
|
|
|
|
|
|
|
|
|
|
|
Mat gateI = gates.colRange(0*numOut, 1*numOut); |
|
|
|
Mat gateI = gates.colRange(0*numOut, 1*numOut); |
|
|
|
Mat gateF = gates.colRange(1*numOut, 2*numOut); |
|
|
|
Mat gateF = gates.colRange(1*numOut, 2*numOut); |
|
|
|
Mat gateO = gates.colRange(2*numOut, 3*numOut); |
|
|
|
Mat gateO = gates.colRange(2*numOut, 3*numOut); |
|
|
|
Mat gateG = gates.colRange(3*numOut, 4*numOut); |
|
|
|
Mat gateG = gates.colRange(3*numOut, 4*numOut); |
|
|
|
|
|
|
|
|
|
|
|
if (forgetBias) |
|
|
|
if (forgetBias) |
|
|
|
add(gateF, forgetBias, gateF); |
|
|
|
add(gateF, forgetBias, gateF); |
|
|
|
|
|
|
|
|
|
|
|
if (usePeephole) |
|
|
|
if (usePeephole) |
|
|
|
{ |
|
|
|
{ |
|
|
|
Mat gatesIF = gates.colRange(0, 2*numOut); |
|
|
|
Mat gatesIF = gates.colRange(0, 2*numOut); |
|
|
|
gemm(cInternal, blobs[3], 1, gateI, 1, gateI); |
|
|
|
gemm(cInternal, blobs[3], 1, gateI, 1, gateI); |
|
|
|
gemm(cInternal, blobs[4], 1, gateF, 1, gateF); |
|
|
|
gemm(cInternal, blobs[4], 1, gateF, 1, gateF); |
|
|
|
sigmoid(gatesIF, gatesIF); |
|
|
|
sigmoid(gatesIF, gatesIF); |
|
|
|
} |
|
|
|
} |
|
|
|
else |
|
|
|
else |
|
|
|
{ |
|
|
|
{ |
|
|
|
Mat gatesIFO = gates.colRange(0, 3*numOut); |
|
|
|
Mat gatesIFO = gates.colRange(0, 3*numOut); |
|
|
|
sigmoid(gatesIFO, gatesIFO); |
|
|
|
sigmoid(gatesIFO, gatesIFO); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
tanh(gateG, gateG); |
|
|
|
tanh(gateG, gateG); |
|
|
|
|
|
|
|
|
|
|
|
//compute c_t
|
|
|
|
//compute c_t
|
|
|
|
multiply(gateF, cInternal, gateF); // f_t (*) c_{t-1}
|
|
|
|
multiply(gateF, cInternal, gateF); // f_t (*) c_{t-1}
|
|
|
|
multiply(gateI, gateG, gateI); // i_t (*) g_t
|
|
|
|
multiply(gateI, gateG, gateI); // i_t (*) g_t
|
|
|
|
add(gateF, gateI, cInternal); // c_t = f_t (*) c_{t-1} + i_t (*) g_t
|
|
|
|
add(gateF, gateI, cInternal); // c_t = f_t (*) c_{t-1} + i_t (*) g_t
|
|
|
|
|
|
|
|
|
|
|
|
if (useCellClip) |
|
|
|
if (useCellClip) |
|
|
|
{ |
|
|
|
{ |
|
|
|
min(cInternal, cellClip, cInternal); |
|
|
|
min(cInternal, cellClip, cInternal); |
|
|
|
max(cInternal, -cellClip, cInternal); |
|
|
|
max(cInternal, -cellClip, cInternal); |
|
|
|
} |
|
|
|
} |
|
|
|
if (usePeephole) |
|
|
|
if (usePeephole) |
|
|
|
{ |
|
|
|
{ |
|
|
|
gemm(cInternal, blobs[5], 1, gateO, 1, gateO); |
|
|
|
gemm(cInternal, blobs[5], 1, gateO, 1, gateO); |
|
|
|
sigmoid(gateO, gateO); |
|
|
|
sigmoid(gateO, gateO); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
//compute h_t
|
|
|
|
//compute h_t
|
|
|
|
tanh(cInternal, hInternal); |
|
|
|
tanh(cInternal, hInternal); |
|
|
|
multiply(gateO, hInternal, hInternal); |
|
|
|
multiply(gateO, hInternal, hInternal); |
|
|
|
|
|
|
|
|
|
|
|
//save results in output blobs
|
|
|
|
//save results in output blobs
|
|
|
|
hInternal.copyTo(hOutTs.rowRange(curRowRange)); |
|
|
|
hInternal.copyTo(hOutTs.rowRange(curRowRange)); |
|
|
|
if (produceCellOutput) |
|
|
|
if (produceCellOutput) |
|
|
|
cInternal.copyTo(cOutTs.rowRange(curRowRange)); |
|
|
|
cInternal.copyTo(cOutTs.rowRange(curRowRange)); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
}; |
|
|
|