|
|
|
@ -321,20 +321,28 @@ int LSTMLayer::outputNameToIndex(String outputName) |
|
|
|
|
|
|
|
|
|
class RNNLayerImpl : public RNNLayer |
|
|
|
|
{ |
|
|
|
|
int nX, nH, nO, nSamples; |
|
|
|
|
int numX, numH, numO; |
|
|
|
|
int numSamples, numTimestamps, numSamplesTotal; |
|
|
|
|
int dtype; |
|
|
|
|
Mat Whh, Wxh, bh; |
|
|
|
|
Mat Who, bo; |
|
|
|
|
Mat hPrevInternal, dummyBiasOnes; |
|
|
|
|
Mat hCurr, hPrev, dummyBiasOnes; |
|
|
|
|
bool produceH; |
|
|
|
|
|
|
|
|
|
public: |
|
|
|
|
|
|
|
|
|
RNNLayerImpl() |
|
|
|
|
{ |
|
|
|
|
type = "RNN"; |
|
|
|
|
produceH = false; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void setWeights(const Blob &W_hh, const Blob &W_xh, const Blob &b_h, const Blob &W_ho, const Blob &b_o) |
|
|
|
|
void setProduceHiddenOutput(bool produce = false) |
|
|
|
|
{ |
|
|
|
|
produceH = produce; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void setWeights(const Blob &W_xh, const Blob &b_h, const Blob &W_hh, const Blob &W_ho, const Blob &b_o) |
|
|
|
|
{ |
|
|
|
|
CV_Assert(W_hh.dims() == 2 && W_xh.dims() == 2); |
|
|
|
|
CV_Assert(W_hh.size(0) == W_xh.size(0) && W_hh.size(0) == W_hh.size(1) && (int)b_h.total() == W_xh.size(0)); |
|
|
|
@ -342,9 +350,9 @@ public: |
|
|
|
|
CV_Assert(W_ho.size(1) == W_hh.size(1)); |
|
|
|
|
|
|
|
|
|
blobs.resize(5); |
|
|
|
|
blobs[0] = W_hh; |
|
|
|
|
blobs[1] = W_xh; |
|
|
|
|
blobs[2] = b_h; |
|
|
|
|
blobs[0] = W_xh; |
|
|
|
|
blobs[1] = b_h; |
|
|
|
|
blobs[2] = W_hh; |
|
|
|
|
blobs[3] = W_ho; |
|
|
|
|
blobs[4] = b_o; |
|
|
|
|
} |
|
|
|
@ -353,72 +361,68 @@ public: |
|
|
|
|
{ |
|
|
|
|
CV_Assert(input.size() >= 1 && input.size() <= 2); |
|
|
|
|
|
|
|
|
|
Whh = blobs[0].matRefConst(); |
|
|
|
|
Wxh = blobs[1].matRefConst(); |
|
|
|
|
bh = blobs[2].matRefConst(); |
|
|
|
|
Wxh = blobs[0].matRefConst(); |
|
|
|
|
bh = blobs[1].matRefConst(); |
|
|
|
|
Whh = blobs[2].matRefConst(); |
|
|
|
|
Who = blobs[3].matRefConst(); |
|
|
|
|
bo = blobs[4].matRefConst(); |
|
|
|
|
|
|
|
|
|
nH = Wxh.rows; |
|
|
|
|
nX = Wxh.cols; |
|
|
|
|
nO = Who.rows; |
|
|
|
|
numH = Wxh.rows; |
|
|
|
|
numX = Wxh.cols; |
|
|
|
|
numO = Who.rows; |
|
|
|
|
|
|
|
|
|
CV_Assert(input[0]->size(-1) == Wxh.cols); |
|
|
|
|
nSamples = input[0]->total(0, input[0]->dims() - 1); |
|
|
|
|
BlobShape xShape = input[0]->shape(); |
|
|
|
|
BlobShape hShape = xShape; |
|
|
|
|
BlobShape oShape = xShape; |
|
|
|
|
hShape[-1] = nH; |
|
|
|
|
oShape[-1] = nO; |
|
|
|
|
|
|
|
|
|
if (input.size() == 2) |
|
|
|
|
{ |
|
|
|
|
CV_Assert(input[1]->shape() == hShape); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
hPrevInternal.create(nSamples, nH, input[0]->type()); |
|
|
|
|
hPrevInternal.setTo(0); |
|
|
|
|
} |
|
|
|
|
CV_Assert(input[0]->dims() >= 2); |
|
|
|
|
CV_Assert((int)input[0]->total(2) == numX); |
|
|
|
|
CV_Assert(input[0]->type() == CV_32F || input[0]->type() == CV_64F); |
|
|
|
|
dtype = input[0]->type(); |
|
|
|
|
numTimestamps = input[0]->size(0); |
|
|
|
|
numSamples = input[0]->size(1); |
|
|
|
|
numSamplesTotal = numTimestamps * numSamples; |
|
|
|
|
|
|
|
|
|
output.resize(2); |
|
|
|
|
output[0].create(oShape, input[0]->type()); |
|
|
|
|
output[1].create(hShape, input[0]->type()); |
|
|
|
|
hCurr.create(numSamples, numH, dtype); |
|
|
|
|
hPrev.create(numSamples, numH, dtype); |
|
|
|
|
hPrev.setTo(0); |
|
|
|
|
|
|
|
|
|
dummyBiasOnes.create(nSamples, 1, bh.type()); |
|
|
|
|
dummyBiasOnes.create(numSamples, 1, dtype); |
|
|
|
|
dummyBiasOnes.setTo(1); |
|
|
|
|
bh = bh.reshape(1, 1); //is 1 x nH mat
|
|
|
|
|
bo = bo.reshape(1, 1); //is 1 x nO mat
|
|
|
|
|
bh = bh.reshape(1, 1); //is 1 x numH Mat
|
|
|
|
|
bo = bo.reshape(1, 1); //is 1 x numO Mat
|
|
|
|
|
|
|
|
|
|
reshapeOutput(output); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void reshapeOutput(std::vector<Blob> &output) |
|
|
|
|
{ |
|
|
|
|
output.resize((produceH) ? 2 : 1); |
|
|
|
|
output[0].create(BlobShape(numTimestamps, numSamples, numO), dtype); |
|
|
|
|
if (produceH) |
|
|
|
|
output[1].create(BlobShape(numTimestamps, numSamples, numH), dtype); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void forward(std::vector<Blob*> &input, std::vector<Blob> &output) |
|
|
|
|
{ |
|
|
|
|
Mat xCurr = input[0]->matRefConst(); |
|
|
|
|
Mat hPrev = (input.size() >= 2) ? input[1]->matRefConst() : hPrevInternal; |
|
|
|
|
Mat oCurr = output[0].matRef(); |
|
|
|
|
Mat hCurr = output[1].matRef(); |
|
|
|
|
|
|
|
|
|
//TODO: Check types
|
|
|
|
|
|
|
|
|
|
int xsz[] = {nSamples, nX}; |
|
|
|
|
int hsz[] = {nSamples, nH}; |
|
|
|
|
int osz[] = {nSamples, nO}; |
|
|
|
|
if (xCurr.dims != 2) xCurr = xCurr.reshape(1, 2, xsz); |
|
|
|
|
if (hPrev.dims != 2) hPrev = hPrev.reshape(1, 2, hsz); |
|
|
|
|
if (oCurr.dims != 2) oCurr = oCurr.reshape(1, 2, osz); |
|
|
|
|
if (hCurr.dims != 2) hCurr = hCurr.reshape(1, 2, hsz); |
|
|
|
|
|
|
|
|
|
gemmCPU(hPrev, Whh, 1, hCurr, 0, GEMM_2_T); // W_{hh} * h_{prev}
|
|
|
|
|
gemmCPU(xCurr, Wxh, 1, hCurr, 1, GEMM_2_T); //+W_{xh} * x_{curr}
|
|
|
|
|
gemmCPU(dummyBiasOnes, bh, 1, hCurr, 1); //+bh
|
|
|
|
|
tanh(hCurr, hCurr); |
|
|
|
|
|
|
|
|
|
gemmCPU(hPrev, Who, 1, oCurr, 0, GEMM_2_T); // W_{ho} * h_{prev}
|
|
|
|
|
gemmCPU(dummyBiasOnes, bo, 1, oCurr, 1); //+b_o
|
|
|
|
|
tanh(oCurr, oCurr); |
|
|
|
|
|
|
|
|
|
if (input.size() < 2) //save h_{prev}
|
|
|
|
|
hCurr.copyTo(hPrevInternal); |
|
|
|
|
Mat xTs = input[0]->reshaped(BlobShape(numSamplesTotal, numX)).matRefConst(); |
|
|
|
|
Mat oTs = output[0].reshaped(BlobShape(numSamplesTotal, numO)).matRef(); |
|
|
|
|
Mat hTs = (produceH) ? output[1].reshaped(BlobShape(numSamplesTotal, numH)).matRef() : Mat(); |
|
|
|
|
|
|
|
|
|
for (int ts = 0; ts < numTimestamps; ts++) |
|
|
|
|
{ |
|
|
|
|
Range curRowRange = Range(ts * numSamples, (ts + 1) * numSamples); |
|
|
|
|
Mat xCurr = xTs.rowRange(curRowRange); |
|
|
|
|
|
|
|
|
|
gemmCPU(hPrev, Whh, 1, hCurr, 0, GEMM_2_T); // W_{hh} * h_{prev}
|
|
|
|
|
gemmCPU(xCurr, Wxh, 1, hCurr, 1, GEMM_2_T); //+W_{xh} * x_{curr}
|
|
|
|
|
gemmCPU(dummyBiasOnes, bh, 1, hCurr, 1); //+bh
|
|
|
|
|
tanh(hCurr, hPrev); |
|
|
|
|
|
|
|
|
|
Mat oCurr = oTs.rowRange(curRowRange); |
|
|
|
|
gemmCPU(hPrev, Who, 1, oCurr, 0, GEMM_2_T); // W_{ho} * h_{prev}
|
|
|
|
|
gemmCPU(dummyBiasOnes, bo, 1, oCurr, 1); //+b_o
|
|
|
|
|
tanh(oCurr, oCurr); |
|
|
|
|
|
|
|
|
|
if (produceH) |
|
|
|
|
hPrev.copyTo(hTs.rowRange(curRowRange)); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|