|
|
|
@ -563,5 +563,214 @@ CV_EXPORTS_W Ptr<RNNLayer> RNNLayer::create(const LayerParams& params) |
|
|
|
|
return Ptr<RNNLayer>(new RNNLayerImpl(params)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/**
 * GRU recurrent layer implementation.
 *
 * Parameters are read from `blobs`:
 *   blobs[0] - Wh, hidden-state weights, (numDirs * 3 * numOut) x numOut
 *   blobs[1] - Wx, input weights, same number of rows as Wh
 *   blobs[2] - bias, flattened to one row holding 2 * Wh.rows elements
 *   blobs[3] - initial hidden state, with numOut columns
 * where numOut = Wh.cols and numDirs is 2 when `bidirectional` is set, 1 otherwise.
 *
 * The input is treated as [numTimeStamps, numSamples, ...] with the trailing
 * dimensions flattening to Wx.cols elements per sample.
 */
class GRULayerImpl CV_FINAL : public GRULayer
{
    int numTimeStamps, numSamples;
    bool allocated;

    MatShape outTailShape;  // shape of single output sample
    MatShape outTsShape;    // shape of N output samples
    bool bidirectional;     // If true, produces both forward and reversed directions along time axis

public:

    /**
     * Reads the layer parameters and, when blobs are already attached,
     * validates their shapes and types.
     */
    GRULayerImpl(const LayerParams& params) : numTimeStamps(0), numSamples(0)
    {
        setParamsFrom(params);

        bidirectional = params.get<bool>("bidirectional", false);
        if (!blobs.empty())
        {
            // Four blobs are read below (Wh, Wx, bias, initial hidden state),
            // so all four must exist before blobs[3] is accessed.
            CV_Assert(blobs.size() >= 4);

            // Flatten the bias into a single-channel, single-row matrix.
            blobs[2] = blobs[2].reshape(1, 1);

            const Mat& Wh = blobs[0];
            const Mat& Wx = blobs[1];
            const Mat& bias = blobs[2];
            const Mat& hInternal = blobs[3];
            CV_CheckEQ(Wh.dims, 2, "");
            CV_CheckEQ(Wx.dims, 2, "");
            CV_CheckEQ(Wh.rows, Wx.rows, "");
            CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional)) * 3 * Wh.cols, "");
            CV_CheckEQ(Wh.rows * 2, (int)bias.total(), "");
            CV_CheckEQ(hInternal.cols, Wh.cols, "");
            CV_CheckTypeEQ(Wh.type(), Wx.type(), "");
            CV_CheckTypeEQ(Wx.type(), bias.type(), "");
        }

        allocated = false;
        outTailShape.clear();
    }

    /**
     * Computes the output shape and the shapes of the six internal buffers.
     *
     * The output is [inp0[0], inp0[1], <outTailShape>], with the last
     * dimension multiplied by the number of directions.
     *
     * @return false (always).
     */
    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int requiredOutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Assert(inputs.size() == 1);
        // blobs[0] and blobs[1] are indexed below; the constructor permits empty blobs.
        CV_Assert(blobs.size() >= 2);
        const MatShape& inp0 = inputs[0];

        const Mat &Wh = blobs[0], &Wx = blobs[1];
        int _numOut = Wh.size[1];
        int _numInp = Wx.size[1];
        MatShape outTailShape_(outTailShape), outResShape;

        if (!outTailShape_.empty())
            CV_Assert(total(outTailShape_) == _numOut);
        else
            outTailShape_.assign(1, _numOut);

        int _numSamples;
        CV_Assert(inp0.size() >= 2 && total(inp0, 2) == _numInp);
        _numSamples = inp0[1];
        outResShape.push_back(inp0[0]);

        outResShape.push_back(_numSamples);
        outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
        outResShape.back() *= (1 + static_cast<int>(bidirectional));

        outputs.assign(1, outResShape);

        // The order below is the order in which forward() reads internals[0..5].
        internals.assign(1, shape(_numSamples, _numOut));     // hInternal
        internals.push_back(shape(_numSamples, 1));           // dummyOnes
        internals.push_back(shape(_numSamples, 2 * _numOut)); // gates
        internals.push_back(shape(_numSamples, 2 * _numOut)); // gates_b
        internals.push_back(shape(_numSamples, 1 * _numOut)); // h_linear
        internals.push_back(shape(_numSamples, _numOut));     // ones

        return false;
    }

    /**
     * Caches numTimeStamps / numSamples from the actual input and builds
     * outTsShape ([numSamples, <outTailShape>] with the last dimension
     * multiplied by the number of directions).
     */
    void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE
    {
        std::vector<Mat> input;
        inputs_arr.getMatVector(input);

        CV_Assert(input.size() == 1);
        // blobs[0] and blobs[1] are indexed below; the constructor permits empty blobs.
        CV_Assert(blobs.size() >= 2);
        const Mat& inp0 = input[0];

        Mat &Wh = blobs[0], &Wx = blobs[1];
        int numOut = Wh.size[1];
        int numInp = Wx.size[1];

        if (!outTailShape.empty())
            CV_Assert(total(outTailShape) == numOut);
        else
            outTailShape.assign(1, numOut);

        CV_Assert(inp0.dims >= 2 && (int)inp0.total(2) == numInp);
        numTimeStamps = inp0.size[0];
        numSamples = inp0.size[1];

        outTsShape.clear();
        outTsShape.push_back(numSamples);
        outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());
        outTsShape.back() *= (1 + static_cast<int>(bidirectional));

        allocated = true;
    }

    /**
     * Runs the recurrence over all time steps, once per direction.
     *
     * Direction 0 walks time steps in increasing order; direction 1 (only when
     * `bidirectional` is set) walks them in decreasing order. Each direction
     * uses its own slice of every blob and writes its own column slice of the
     * output. CV_16S inputs are delegated to forward_fallback().
     */
    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        if (inputs_arr.depth() == CV_16S)
        {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> input, output, internals;
        inputs_arr.getMatVector(input);
        outputs_arr.getMatVector(output);
        internals_arr.getMatVector(internals);

        // blobs[0..3] and internals[0..5] are indexed unconditionally below.
        CV_Assert(blobs.size() >= 4);
        CV_Assert(internals.size() >= 6);

        const int numDirs = 1 + static_cast<int>(bidirectional);
        for (int i = 0; i < numDirs; ++i)
        {
            // Per-direction slices of the parameter blobs (Mat headers only, no data copy).
            const Mat Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
            const Mat Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
            const Mat bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
            const Mat h_0 = blobs[3].rowRange(i * blobs[3].rows / numDirs, (i + 1) * blobs[3].rows / numDirs);

            // First half of the bias is applied with Wx, second half with Wh.
            const Mat bx = bias.colRange(0, bias.cols / 2);
            const Mat bh = bias.colRange(bias.cols / 2, bias.cols);

            Mat hInternal = internals[0], dummyOnes = internals[1], gates = internals[2],
                b_rz = internals[3], n_t = internals[4], ones = internals[5];
            h_0.copyTo(hInternal);
            dummyOnes.setTo(1.);
            ones.setTo(1.);

            int numOut = Wh.size[1];
            const Mat wx_rz = Wx.rowRange(0, 2 * numOut);
            const Mat wh_rz = Wh.rowRange(0, 2 * numOut);
            b_rz = bx.colRange(0, 2 * numOut) + bh.colRange(0, 2 * numOut);
            const Mat wx_n = Wx.rowRange(2 * numOut, 3 * numOut);
            const Mat wh_n = Wh.rowRange(2 * numOut, 3 * numOut);
            const Mat b_in = bx.colRange(2 * numOut, 3 * numOut);
            const Mat b_hn = bh.colRange(2 * numOut, 3 * numOut);

            int numSamplesTotal = numTimeStamps * numSamples;
            Mat xTs = input[0].reshape(1, numSamplesTotal);

            Mat hOutTs = output[0].reshape(1, numSamplesTotal);
            hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);

            int tsStart, tsEnd, tsInc;
            if (i == 1) {
                // Reverse direction: last time step down to the first.
                tsStart = numTimeStamps - 1;
                tsEnd = -1;
                tsInc = -1;
            }
            else {
                tsStart = 0;
                tsEnd = numTimeStamps;
                tsInc = 1;
            }
            for (int ts = tsStart; ts != tsEnd; ts += tsInc)
            {
                Range curRowRange(ts * numSamples, (ts + 1) * numSamples);
                Mat xCurr = xTs.rowRange(curRowRange);

                // calculate r_t = sigmoid(x * Wx_r + h_(t-1) * Wh_r + b_r)
                // calculate z_t = sigmoid(x * Wx_z + h_(t-1) * Wh_z + b_z)
                gemm(xCurr, wx_rz, 1, gates, 0, gates, GEMM_2_T);      // x * Wx_rz
                gemm(hInternal, wh_rz, 1, gates, 1, gates, GEMM_2_T);  // + h_(t-1) * Wh_rz
                gemm(dummyOnes, b_rz, 1, gates, 1, gates);             // + b_rz
                sigmoid(gates, gates);                                 // sigmoid()

                // z occupies the first half of the gate columns, r the second half.
                Mat z = gates.colRange(0, gates.cols / 2);
                Mat r = gates.colRange(gates.cols / 2, gates.cols);

                // calculate n_t = tanh(r (*) (h_(t-1) * Wh_n + b_hn) + x * Wx_n + b_in)
                gemm(hInternal, wh_n, 1, n_t, 0, n_t, GEMM_2_T);       // h_(t-1) * Wh_n
                gemm(dummyOnes, b_hn, 1, n_t, 1, n_t);                 // + b_hn
                multiply(r, n_t, n_t);                                 // r (*) (h_(t-1) * Wh_n + b_hn)

                gemm(xCurr, wx_n, 1, n_t, 1, n_t, GEMM_2_T);           // + x * Wx_n
                gemm(dummyOnes, b_in, 1, n_t, 1, n_t);                 // + b_in
                tanh(n_t, n_t);                                        // tanh()

                // compute next h_t = z (*) h_(t-1) + (1 - z) (*) n_t
                multiply(z, hInternal, hInternal);                     // z (*) h_{t-1}
                subtract(ones, z, z);                                  // 1 - z
                multiply(z, n_t, z);                                   // (1 - z) * n
                add(z, hInternal, hInternal);                          // z (*) h_(t-1) + (1 - z) (*) n_t

                // save results in output blobs
                hInternal.copyTo(hOutTs.rowRange(curRowRange));
            }
        }
    }
};
|
|
|
|
|
|
|
|
|
// Factory: constructs a GRULayerImpl from the given layer parameters.
Ptr<GRULayer> GRULayer::create(const LayerParams& params)
{
    return Ptr<GRULayer>(new GRULayerImpl(params));
}
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|