diff --git a/modules/dnn/src/layers/recurrent_layers.cpp b/modules/dnn/src/layers/recurrent_layers.cpp index 0cd68e2b5..65545fee7 100644 --- a/modules/dnn/src/layers/recurrent_layers.cpp +++ b/modules/dnn/src/layers/recurrent_layers.cpp @@ -44,6 +44,7 @@ #include "op_blas.hpp" #include #include +#include namespace cv { @@ -60,6 +61,7 @@ static void tanh(const Mat &src, Mat &dst) *itDst = std::tanh(*itSrc); } +//TODO: make utils method static void tanh(const Mat &src, Mat &dst) { dst.create(src.dims, (const int*)src.size, src.type()); @@ -86,9 +88,9 @@ class LSTMLayerImpl : public LSTMLayer int dtype; bool allocated; - BlobShape outTailShape; //shape of single output sample - BlobShape outTsMatShape, outTsShape; //shape of N output samples - BlobShape outResShape; //shape of T timestamps and N output samples + Shape outTailShape; //shape of single output sample + Shape outTsMatShape, outTsShape; //shape of N output samples + Shape outResShape; //shape of T timestamps and N output samples bool useTimestampDim; bool produceCellOutput; @@ -101,7 +103,7 @@ public: useTimestampDim = true; produceCellOutput = false; allocated = false; - outTailShape = BlobShape::empty(); + outTailShape = Shape::empty(); } void setUseTimstampsDim(bool use) @@ -120,7 +122,7 @@ public: { CV_Assert(cInternal.empty() || C.total() == cInternal.total()); if (!cInternal.empty()) - C.reshaped(BlobShape::like(cInternal)).matRefConst().copyTo(cInternal); + C.reshaped(Shape::like(cInternal)).matRefConst().copyTo(cInternal); else C.matRefConst().copyTo(cInternal); } @@ -129,7 +131,7 @@ public: { CV_Assert(hInternal.empty() || H.total() == hInternal.total()); if (!hInternal.empty()) - H.reshaped(BlobShape::like(hInternal)).matRefConst().copyTo(hInternal); + H.reshaped(Shape::like(hInternal)).matRefConst().copyTo(hInternal); else H.matRefConst().copyTo(hInternal); } @@ -153,7 +155,7 @@ public: return res; } - void setOutShape(const BlobShape &outTailShape_) + void setOutShape(const Shape &outTailShape_) { CV_Assert(!allocated || outTailShape_.total() == outTailShape.total()); outTailShape = outTailShape_; @@ -171,7 +173,7 @@ public: blobs[0] = Wh; blobs[1] = Wx; blobs[2] = bias; - blobs[2].reshape(BlobShape(1, (int)bias.total())); + blobs[2].reshape(Shape(1, (int)bias.total())); } void allocate(const std::vector &input, std::vector &output) @@ -186,24 +188,24 @@ public: if (!outTailShape.isEmpty()) CV_Assert(outTailShape.total() == numOut); else - outTailShape = BlobShape(numOut); + outTailShape = Shape(numOut); if (useTimestampDim) { CV_Assert(input[0]->dims() >= 2 && (int)input[0]->total(2) == numInp); numTimeStamps = input[0]->size(0); numSamples = input[0]->size(1); - outResShape = BlobShape(numTimeStamps, numSamples) + outTailShape; + outResShape = Shape(numTimeStamps, numSamples) + outTailShape; } else { CV_Assert(input[0]->dims() >= 1 && (int)input[0]->total(1) == numInp); numTimeStamps = 1; numSamples = input[0]->size(0); - outResShape = BlobShape(numSamples) + outTailShape; + outResShape = Shape(numSamples) + outTailShape; } - outTsMatShape = BlobShape(numSamples, numOut); - outTsShape = BlobShape(numSamples) + outTailShape; + outTsMatShape = Shape(numSamples, numOut); + outTsShape = Shape(numSamples) + outTailShape; dtype = input[0]->type(); CV_Assert(dtype == CV_32F || dtype == CV_64F); @@ -246,25 +248,25 @@ public: void forward(std::vector &input, std::vector &output) { - const Mat &Wh = blobs[0].matRefConst(); - const Mat &Wx = blobs[1].matRefConst(); - const Mat &bias = blobs[2].matRefConst(); + const Mat &Wh = blobs[0].getRefConst(); + const Mat &Wx = blobs[1].getRefConst(); + const Mat &bias = blobs[2].getRefConst(); int numSamplesTotal = numTimeStamps*numSamples; - Mat xTs = input[0]->reshaped(BlobShape(numSamplesTotal, numInp)).matRefConst(); + Mat xTs = reshaped(input[0]->getRefConst(), Shape(numSamplesTotal, numInp)); - BlobShape outMatShape(numSamplesTotal, numOut); - Mat hOutTs = output[0].reshaped(outMatShape).matRef(); - Mat cOutTs = (produceCellOutput) ? output[1].reshaped(outMatShape).matRef() : Mat(); + Shape outMatShape(numSamplesTotal, numOut); + Mat hOutTs = reshaped(output[0].getRef(), outMatShape); + Mat cOutTs = (produceCellOutput) ? reshaped(output[1].getRef(), outMatShape) : Mat(); for (int ts = 0; ts < numTimeStamps; ts++) { Range curRowRange(ts*numSamples, (ts + 1)*numSamples); Mat xCurr = xTs.rowRange(curRowRange); - gemmCPU(xCurr, Wx, 1, gates, 0, GEMM_2_T); // Wx * x_t - gemmCPU(hInternal, Wh, 1, gates, 1, GEMM_2_T); //+Wh * h_{t-1} - gemmCPU(dummyOnes, bias, 1, gates, 1); //+b + dnn::gemm(xCurr, Wx, 1, gates, 0, GEMM_2_T); // Wx * x_t + dnn::gemm(hInternal, Wh, 1, gates, 1, GEMM_2_T); //+Wh * h_{t-1} + dnn::gemm(dummyOnes, bias, 1, gates, 1); //+b Mat getesIFO = gates.colRange(0, 3*numOut); Mat gateI = gates.colRange(0*numOut, 1*numOut); @@ -394,30 +396,30 @@ public: void reshapeOutput(std::vector &output) { output.resize((produceH) ? 2 : 1); - output[0].create(BlobShape(numTimestamps, numSamples, numO), dtype); + output[0].create(Shape(numTimestamps, numSamples, numO), dtype); if (produceH) - output[1].create(BlobShape(numTimestamps, numSamples, numH), dtype); + output[1].create(Shape(numTimestamps, numSamples, numH), dtype); } void forward(std::vector &input, std::vector &output) { - Mat xTs = input[0]->reshaped(BlobShape(numSamplesTotal, numX)).matRefConst(); - Mat oTs = output[0].reshaped(BlobShape(numSamplesTotal, numO)).matRef(); - Mat hTs = (produceH) ? output[1].reshaped(BlobShape(numSamplesTotal, numH)).matRef() : Mat(); + Mat xTs = reshaped(input[0]->getRefConst(), Shape(numSamplesTotal, numX)); + Mat oTs = reshaped(output[0].getRef(), Shape(numSamplesTotal, numO)); + Mat hTs = (produceH) ? reshaped(output[1].getRef(), Shape(numSamplesTotal, numH)) : Mat(); for (int ts = 0; ts < numTimestamps; ts++) { Range curRowRange = Range(ts * numSamples, (ts + 1) * numSamples); Mat xCurr = xTs.rowRange(curRowRange); - gemmCPU(hPrev, Whh, 1, hCurr, 0, GEMM_2_T); // W_{hh} * h_{prev} - gemmCPU(xCurr, Wxh, 1, hCurr, 1, GEMM_2_T); //+W_{xh} * x_{curr} - gemmCPU(dummyBiasOnes, bh, 1, hCurr, 1); //+bh + dnn::gemm(hPrev, Whh, 1, hCurr, 0, GEMM_2_T); // W_{hh} * h_{prev} + dnn::gemm(xCurr, Wxh, 1, hCurr, 1, GEMM_2_T); //+W_{xh} * x_{curr} + dnn::gemm(dummyBiasOnes, bh, 1, hCurr, 1); //+bh tanh(hCurr, hPrev); Mat oCurr = oTs.rowRange(curRowRange); - gemmCPU(hPrev, Who, 1, oCurr, 0, GEMM_2_T); // W_{ho} * h_{prev} - gemmCPU(dummyBiasOnes, bo, 1, oCurr, 1); //+b_o + dnn::gemm(hPrev, Who, 1, oCurr, 0, GEMM_2_T); // W_{ho} * h_{prev} + dnn::gemm(dummyBiasOnes, bo, 1, oCurr, 1); //+b_o tanh(oCurr, oCurr); if (produceH)