Merge pull request #21522 from rogday:lstm

Fix LSTM support in ONNX

* fix LSTM and add peephole support

* disable old tests

* turn lambdas into functions

* more hacks for  c++98

* add assertions

* slice fixes

* backport of cuda-related fixes

* address review comments
pull/21750/head
rogday 3 years ago committed by GitHub
parent 5d8134ed32
commit 93353aea70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 151
      modules/dnn/src/layers/recurrent_layers.cpp
  2. 300
      modules/dnn/src/onnx/onnx_importer.cpp
  3. 13
      modules/dnn/test/test_onnx_importer.cpp

@ -103,7 +103,7 @@ static ActivationFunction get_activation_function(const String& activation) {
class LSTMLayerImpl CV_FINAL : public LSTMLayer
{
int numTimeStamps, numSamples;
int numTimeStamps, numSamples, numHidden;
bool allocated;
MatShape outTailShape; //shape of single output sample
@ -127,6 +127,10 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
bool useAVX2;
#endif
// CUDA needs input blobs to be rearranged in a specific way, but some transformations
// in ONNXImporter are destructive, so we keep a copy.
std::vector<Mat> originalBlobs;
public:
LSTMLayerImpl(const LayerParams& params)
@ -140,6 +144,13 @@ public:
{
setParamsFrom(params);
if (params.get<bool>("is_onnx", false))
{
// collect copies of onnx blobs
originalBlobs.insert(originalBlobs.begin(), blobs.begin(), blobs.begin() + 3);
blobs.erase(blobs.begin(), blobs.begin() + 3);
}
bidirectional = params.get<bool>("bidirectional", false);
if (!blobs.empty())
{
@ -181,6 +192,7 @@ public:
useCellClip = params.get<bool>("use_cell_clip", false);
usePeephole = params.get<bool>("use_peephole", false);
reverse = params.get<bool>("reverse", false);
numHidden = params.get<int>("hidden_size", 1);
CV_Assert(!reverse || !bidirectional);
// read activations
@ -269,8 +281,21 @@ public:
outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
outResShape.back() *= (1 + static_cast<int>(bidirectional));
size_t noutputs = produceCellOutput ? 2 : 1;
outputs.assign(noutputs, outResShape);
outputs.assign(1, outResShape);
if (produceCellOutput)
{
// the producer is ONNX, so CellState is different
if (!originalBlobs.empty())
{
int shp[] = {(1 + static_cast<int>(bidirectional)), _numSamples, numHidden};
MatShape newShape(shp, shp + sizeof(shp)/sizeof(shp[0]));
outputs.push_back(newShape);
}
else
{
outputs.push_back(outResShape);
}
}
internals.assign(1, shape(_numSamples, _numOut)); // hInternal
internals.push_back(shape(_numSamples, _numOut)); // cInternal
@ -335,14 +360,39 @@ public:
outputs_arr.getMatVector(output);
internals_arr.getMatVector(internals);
Mat cOut = produceCellOutput ? output[0].clone() : Mat();
const bool needYcTransform = !originalBlobs.empty(); // if the producer is onnx
const int numDirs = 1 + static_cast<int>(bidirectional);
for (int i = 0; i < numDirs; ++i)
{
const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
const Mat &h_0 = blobs[3].rowRange(i * blobs[3].rows / numDirs, (i + 1) * blobs[3].rows / numDirs);
const Mat &c_0 = blobs[4].rowRange(i * blobs[4].rows / numDirs, (i + 1) * blobs[4].rows / numDirs);
Mat Wh = blobs[0];
Mat Wx = blobs[1];
Mat bias = blobs[2];
Mat h_0 = blobs[3];
Mat c_0 = blobs[4];
Mat pI, pF, pO;
Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs);
Wx = Wx.rowRange(i * Wx.rows / numDirs, (i + 1) * Wx.rows / numDirs);
bias = bias.colRange(i * bias.cols / numDirs, (i + 1) * bias.cols / numDirs);
h_0 = h_0.rowRange(i * h_0.rows / numDirs, (i + 1) * h_0.rows / numDirs);
c_0 = c_0.rowRange(i * c_0.rows / numDirs, (i + 1) * c_0.rows / numDirs);
if (usePeephole)
{
pI = blobs[5];
pF = blobs[6];
pO = blobs[7];
pI = pI.rowRange(i * pI.rows / numDirs, (i + 1) * pI.rows / numDirs);
pI = pI.colRange(i * pI.cols / numDirs, (i + 1) * pI.cols / numDirs);
pF = pF.rowRange(i * pF.rows / numDirs, (i + 1) * pF.rows / numDirs);
pF = pF.colRange(i * pF.cols / numDirs, (i + 1) * pF.cols / numDirs);
pO = pO.rowRange(i * pO.rows / numDirs, (i + 1) * pO.rows / numDirs);
pO = pO.colRange(i * pO.cols / numDirs, (i + 1) * pO.cols / numDirs);
}
int numOut = Wh.size[1];
Mat hInternal = internals[0], cInternal = internals[1],
@ -356,7 +406,12 @@ public:
Mat hOutTs = output[0].reshape(1, numSamplesTotal);
hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);
Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
Mat cOutTs;
if (produceCellOutput)
{
cOutTs = cOut.reshape(1, numSamplesTotal);
cOutTs = cOutTs.colRange(i * cOutTs.cols / numDirs, (i + 1) * cOutTs.cols / numDirs);
}
#if CV_TRY_AVX2 || CV_TRY_AVX
bool canUseAvx = gates.isContinuous() && bias.isContinuous()
@ -471,8 +526,8 @@ public:
if (usePeephole)
{
Mat gatesIF = gates.colRange(0, 2*numOut);
gemm(cInternal, blobs[5], 1, gateI, 1, gateI);
gemm(cInternal, blobs[6], 1, gateF, 1, gateF);
gemm(cInternal, pI, 1, gateI, 1, gateI);
gemm(cInternal, pF, 1, gateF, 1, gateF);
f_activation(gatesIF, gatesIF);
}
else
@ -495,7 +550,7 @@ public:
}
if (usePeephole)
{
gemm(cInternal, blobs[7], 1, gateO, 1, gateO);
gemm(cInternal, pO, 1, gateO, 1, gateO);
f_activation(gateO, gateO);
}
@ -509,6 +564,78 @@ public:
cInternal.copyTo(cOutTs.rowRange(curRowRange));
}
}
if (needYcTransform && produceCellOutput)
{
fixCellState(cOut, numDirs);
}
if (produceCellOutput)
{
cOut.copyTo(output[1]);
}
}
void fixCellState(Mat& cOut, int numDirs)
{
// seq, batch, dirs, hidden
int shp[] = {0, numSamples, numDirs, numHidden};
cOut = cOut.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
// permute to {0, 2, 1, 3};
std::vector<int> newShape = shape(cOut);
std::swap(newShape[1], newShape[2]);
cv::Mat newCellState(newShape, CV_32FC1);
const float* src = cOut.ptr<const float>();
float* dst = newCellState.ptr<float>();
size_t sj = newCellState.size[3];
size_t sk = newCellState.size[2] * sj;
size_t si = newCellState.size[1] * sk;
for (size_t i = 0; i < newCellState.size[0]; i++)
{
for (size_t j = 0; j < newCellState.size[2]; j++)
{
for (size_t k = 0; k < newCellState.size[1]; k++)
{
std::memcpy(dst, src, sizeof(float) * newCellState.size[3]);
src += cOut.size[3];
dst += sk;
}
dst = dst + sj - si;
}
dst = dst + si - sk;
}
cOut = newCellState;
if (numDirs == 1)
{
// Slice: Yh = Y[-1, :, :, :]
Range ranges[] = {cv::Range(cOut.size[0] - 1, cOut.size[0]), cv::Range::all(), cv::Range::all(), cv::Range::all()};
cOut = cOut(ranges);
// Reshape: 1x1xBxH -> 1xBxH
int shp[] = {1, numSamples, numHidden};
cOut = cOut.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
}
else
{
// Slice: SxDxBxH -> last sequence, first direction
Range ranges1[] = {cv::Range(cOut.size[0] - 1, cOut.size[0]), cv::Range(0, 1), cv::Range::all(), cv::Range::all()};
Mat part1 = cOut(ranges1);
// Slice: SxDxBxH -> first sequence, last direction
Range ranges2[] = {cv::Range(0, 1), cv::Range(cOut.size[1] - 1, cOut.size[1]), cv::Range::all(), cv::Range::all()};
Mat part2 = cOut(ranges2);
int shp[] = {1, part1.size[2] * part1.size[3]};
part1 = part1.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
part2 = part2.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
vconcat(part1, part2, cOut);
// Reshape: 1x2xBxH -> 2xBxH
int finalShape[] = {2, numSamples, numHidden};
cOut = cOut.reshape(1, sizeof(finalShape)/sizeof(finalShape[0]), finalShape);
}
}
};

@ -65,6 +65,14 @@ class ONNXImporter
void expandMid(const std::string& prefix, opencv_onnx::NodeProto& node_proto,
const std::string& input, size_t n);
void addNegation(const LayerParams& layerParams, opencv_onnx::NodeProto& node_proto, int input_id);
void lstm_extractConsts(LayerParams& layerParams, const opencv_onnx::NodeProto& lstm_proto, size_t idx, int* blobShape_, int size);
void lstm_add_reshape(const std::string& input_name, const std::string& output_name, int* layerShape, size_t n);
std::string lstm_add_slice(int index, const std::string& input_name, int* begin, int* end, size_t n);
std::string lstm_fix_dims(LayerParams& layerParams, const opencv_onnx::NodeProto& lstm_proto,
int batch_size, int num_directions, int hidden_size, bool need_y, const std::string& y_name,
const int index);
void lstm_add_transform(int num_directions, int batch_size, int hidden_size,
int index, const std::string& input_name, const std::string& output_name);
public:
ONNXImporter(Net& net, const char *onnxFile)
@ -1298,38 +1306,24 @@ void ONNXImporter::parseConstant(LayerParams& layerParams, const opencv_onnx::No
addConstant(node_proto.output(0), layerParams.blobs[0]);
}
void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
void transformBlobs(std::vector<Mat>& blobs)
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string output_name = node_proto.output(0);
LayerParams lstmParams = layerParams;
lstmParams.name += "/lstm";
Mat Wx = blobs[0];
Mat Wh = blobs[1];
Mat b = blobs[2];
std::vector<Mat> cudaWorkaround;
cudaWorkaround.push_back(Wx.clone());
cudaWorkaround.push_back(Wh.clone());
cudaWorkaround.push_back(b.clone());
// https://pytorch.org/docs/stable/nn.html#lstm
CV_Assert(node_proto.input_size() >= 7);
Mat Wx = getBlob(node_proto, 1);
Mat Wh = getBlob(node_proto, 2);
Mat b = getBlob(node_proto, 3);
const int numHidden = lstmParams.get<int>("hidden_size");
const int numHidden = Wh.size[2];
const int numDirs = Wx.size[0]; // Is 1 for forward only and 2 for bidirectional LSTM.
const int numFeatures = Wx.size[2];
Mat h0, c0;
if (!node_proto.input(5).empty()) {
h0 = getBlob(node_proto, 5);
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
} else {
// initial_h attribute can be empty in case of keras2onnx producer. fill it with zeros
h0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1);
}
if (!node_proto.input(6).empty()) {
c0 = getBlob(node_proto, 6);
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
} else {
// initial_c attribute can be empty in case of keras2onnx producer. fill it with zeros
c0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1);
}
Mat h0 = blobs[3];
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
Mat c0 = blobs[4];
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
b = b.reshape(1, b.size[0]);
Mat bx = b.colRange(0, b.cols / 2);
@ -1360,31 +1354,245 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr
Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
blobs[0] = Wh;
blobs[1] = Wx;
blobs[2] = b.reshape(1, 1);
blobs[3] = h0;
blobs[4] = c0;
lstmParams.blobs.resize(5);
lstmParams.blobs[0] = Wh;
lstmParams.blobs[1] = Wx;
lstmParams.blobs[2] = b;
lstmParams.blobs[3] = h0;
lstmParams.blobs[4] = c0;
if (blobs.size() == 5) {
// so that future patch removing copies can leave all indexing as is
blobs.insert(blobs.begin(), cudaWorkaround.begin(), cudaWorkaround.end());
return;
}
// read direction attribute
lstmParams.set("reverse", lstmParams.get<String>("direction", "") == "reverse");
lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
Mat P = blobs[5];
blobs[5] = P.colRange(0, numHidden);
blobs[5] = blobs[5].clone().reshape(1, blobs[5].total()); // Single column.
blobs[5] = Mat::diag(blobs[5]);
node_proto.set_output(0, lstmParams.name); // set different name so output shapes will be registered on that name
addLayer(lstmParams, node_proto);
blobs.push_back(P.colRange(numHidden, 2 * numHidden));
blobs[6] = blobs[6].clone().reshape(1, blobs[6].total()); // Single column.
blobs[6] = Mat::diag(blobs[6]);
MatShape lstmShape = outShapes[node_proto.output(0)];
blobs.push_back(P.colRange(2 * numHidden, 3 * numHidden));
blobs[7] = blobs[7].clone().reshape(1, blobs[7].total()); // Single column.
blobs[7] = Mat::diag(blobs[7]);
// Add fake 1 as it is done in ONNX
lstmShape.insert(lstmShape.begin() + 1, 1);
// so that future patch removing copies can leave all indexing as is
blobs.insert(blobs.begin(), cudaWorkaround.begin(), cudaWorkaround.end());
}
layerParams.type = "Reshape";
layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
node_proto.set_input(0, lstmParams.name); // redirect input to LSTM
node_proto.set_output(0, output_name); // keep origin LSTM's name
addLayer(layerParams, node_proto);
void ONNXImporter::lstm_extractConsts(LayerParams& layerParams, const opencv_onnx::NodeProto& lstm_proto, size_t idx, int* blobShape_, int size)
{
MatShape blobShape(blobShape_, blobShape_ + size);
Mat blob;
if (idx < lstm_proto.input_size() && !lstm_proto.input(idx).empty())
{
blob = getBlob(lstm_proto, idx);
CV_Assert(shape(blob) == blobShape);
}
else
{
blob = Mat(blobShape, CV_32FC1, 0.);
}
layerParams.blobs.push_back(blob);
};
void ONNXImporter::lstm_add_reshape(const std::string& input_name, const std::string& output_name, int* layerShape, size_t n)
{
LayerParams reshapeLp;
reshapeLp.name = cv::format("%s/reshape", input_name.c_str());
reshapeLp.type = "Reshape";
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
reshapeLp.set("dim", DictValue::arrayInt(layerShape, n));
opencv_onnx::NodeProto reshape_proto;
reshape_proto.add_input(input_name);
reshape_proto.add_output(output_name);
addLayer(reshapeLp, reshape_proto);
};
std::string ONNXImporter::lstm_add_slice(int index, const std::string& input_name, int* begin, int* end, size_t n)
{
LayerParams sliceLP;
sliceLP.name = cv::format("%s/slice_%d", input_name.c_str(), index);
sliceLP.type = "Slice";
CV_Assert(layer_id.find(sliceLP.name) == layer_id.end());
sliceLP.set("begin", DictValue::arrayInt(begin, n));
sliceLP.set("end", DictValue::arrayInt(end, n));
sliceLP.set("axis", 0);
opencv_onnx::NodeProto slice_proto;
slice_proto.add_input(input_name);
slice_proto.add_output(sliceLP.name);
addLayer(sliceLP, slice_proto);
return slice_proto.output(0);
};
std::string ONNXImporter::lstm_fix_dims(LayerParams& layerParams, const opencv_onnx::NodeProto& lstm_proto,
int batch_size, int num_directions, int hidden_size, bool need_y, const std::string& y_name,
const int index)
{
std::string reshape_output = cv::format("%s/reshape_%d", layerParams.name.c_str(), index);
// reshape from Seq, Batch, Dirs*Hidden to Seq, Batch, Dirs, Hidden
// to not confuse reshape with dynamic first dimension, zero means 'leave unchanged'
int layerShape[] = {0, batch_size, num_directions, hidden_size};
lstm_add_reshape(lstm_proto.output(index), reshape_output, layerShape, sizeof(layerShape) / sizeof(layerShape[0]));
// permute from Seq, Batch, Dirs, Hidden to Seq, Dirs, Batch, Hidden
LayerParams permuteLP;
permuteLP.name = reshape_output + "/permute";
permuteLP.type = "Permute";
CV_Assert(layer_id.find(permuteLP.name) == layer_id.end());
int order[] = {0, 2, 1, 3};
permuteLP.set("order", DictValue::arrayInt(order, 4));
opencv_onnx::NodeProto permute_proto;
permute_proto.add_input(reshape_output);
permute_proto.add_output((need_y && index == 0) ? y_name : static_cast<std::string>(permuteLP.name));
addLayer(permuteLP, permute_proto);
return permute_proto.output(0);
};
void ONNXImporter::lstm_add_transform(int num_directions, int batch_size, int hidden_size,
int index, const std::string& input_name, const std::string& output_name)
{
if (num_directions == 1)
{
// Slice: Yh = Y[-1, :, :, :]
int begin[] = {-1}, end[] = {INT_MAX};
std::string slice_output = lstm_add_slice(index, input_name, begin, end, sizeof(begin) / sizeof(begin[0]));
// Reshape: 1x1xBxH -> 1xBxH
int layerShape[] = {1, batch_size, hidden_size};
lstm_add_reshape(slice_output, output_name, layerShape, sizeof(layerShape) / sizeof(layerShape[0]));
}
else
{
// Slice: SxDxBxH -> last sequence, first direction
int begin0[] = {-1, 0}, end0[] = {INT_MAX, 1};
std::string slice_0 = lstm_add_slice(0, input_name, begin0, end0, sizeof(begin0) / sizeof(begin0[0]));
// Slice: SxDxBxH -> first sequence, last direction
int begin1[] = {0, -1}, end1[] = {1, INT_MAX};
std::string slice_1 = lstm_add_slice(1, input_name, begin1, end1, sizeof(begin1) / sizeof(begin1[0]));
LayerParams concatLP;
concatLP.name = cv::format("%s/concat", input_name.c_str());
concatLP.type = "Concat";
CV_Assert(layer_id.find(concatLP.name) == layer_id.end());
concatLP.set("axis", 1); // 1x1xBxH -> 1x2xBxH
opencv_onnx::NodeProto concat_proto;
concat_proto.add_input(slice_0);
concat_proto.add_input(slice_1);
concat_proto.add_output(concatLP.name);
addLayer(concatLP, concat_proto);
// Reshape: 1x2xBxH -> 2xBxH
int layerShape[] = {2, batch_size, hidden_size};
lstm_add_reshape(concat_proto.output(0), output_name, layerShape, sizeof(layerShape) / sizeof(layerShape[0]));
}
};
void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto lstm_proto = node_proto_;
layerParams.name += "/lstm";
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#LSTM
CV_Assert(lstm_proto.input_size() >= 3);
for (size_t i = 1; i < 3; ++i)
{
const std::string& name = lstm_proto.input(i);
CV_Assert(!name.empty() && constBlobs.count(name) == 1);
}
IterShape_t shapeIt = outShapes.find(lstm_proto.input(0));
CV_Assert(shapeIt != outShapes.end());
const MatShape x_shape = shapeIt->second;
const int seq_length = x_shape[0];
const int batch_size = x_shape[1];
const int input_size = x_shape[2];
const int hidden_size = layerParams.get<int>("hidden_size");
const int num_directions = constBlobs[lstm_proto.input(1)].size[0];
int w_size[] = {num_directions, 4*hidden_size, input_size};
lstm_extractConsts(layerParams, lstm_proto, 1, w_size, sizeof(w_size) / sizeof(w_size[0])); // W
int r_size[] = {num_directions, 4*hidden_size, hidden_size};
lstm_extractConsts(layerParams, lstm_proto, 2, r_size, sizeof(r_size) / sizeof(r_size[0])); // R
int b_size[] = {num_directions, 8*hidden_size};
lstm_extractConsts(layerParams, lstm_proto, 3, b_size, sizeof(b_size) / sizeof(b_size[0])); // B
if (4 < lstm_proto.input_size() && !lstm_proto.input(4).empty())
{
Mat blob = getBlob(lstm_proto, 4);
CV_Assert(blob.total() == batch_size);
for (MatIterator_<int32_t> it = blob.begin<int32_t>(); it != blob.end<int32_t>(); ++it)
{
CV_Assert(*it == seq_length);
}
}
int h_size[] = {num_directions, batch_size, hidden_size};
lstm_extractConsts(layerParams, lstm_proto, 5, h_size, sizeof(h_size) / sizeof(h_size[0])); // initial_h
int c_size[] = {num_directions, batch_size, hidden_size};
lstm_extractConsts(layerParams, lstm_proto, 6, c_size, sizeof(c_size) / sizeof(c_size[0])); // initial_c
if (lstm_proto.input_size() > 7 && !lstm_proto.input(7).empty())
{
layerParams.set("use_peephole", true);
int p_size[] = {num_directions, 3 * hidden_size};
lstm_extractConsts(layerParams, lstm_proto, 7, p_size, sizeof(p_size) / sizeof(p_size[0])); // P
}
transformBlobs(layerParams.blobs);
layerParams.set("is_onnx", true);
layerParams.set("reverse", layerParams.get<String>("direction", "") == "reverse");
layerParams.set("bidirectional", layerParams.get<String>("direction", "") == "bidirectional");
bool need_yc = lstm_proto.output_size() > 2 && !lstm_proto.output(2).empty();
bool need_yh = lstm_proto.output_size() > 1 && !lstm_proto.output(1).empty();
bool need_y = lstm_proto.output_size() > 0 && !lstm_proto.output(0).empty();
const std::string y_name = need_y ? lstm_proto.output(0) : "";
const std::string yh_name = need_yh ? lstm_proto.output(1) : "";
const std::string yc_name = need_yc ? lstm_proto.output(2) : "";
layerParams.set("produce_cell_output", need_yc);
lstm_proto.clear_output();
if (need_y || need_yh)
{
// give random names to LSTMLayer's outputs because every output needs postprocessing
lstm_proto.add_output(cv::format("%s_y", layerParams.name.c_str()));
}
if (need_yc)
{
lstm_proto.add_output(yc_name);
}
addLayer(layerParams, lstm_proto);
std::string y_output = lstm_fix_dims(layerParams, lstm_proto, batch_size, num_directions, hidden_size, need_y,
y_name, 0);
if (need_yh)
{
lstm_add_transform(num_directions, batch_size, hidden_size, 0, y_output, yh_name);
}
}
void ONNXImporter::parseImageScaler(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)

@ -765,12 +765,14 @@ TEST_P(Test_ONNX_layers, LSTM_Activations)
testONNXModels("lstm_cntk_tanh", pb, 0, 0, false, false);
}
TEST_P(Test_ONNX_layers, LSTM)
// disabled due to poor handling of 1-d mats
TEST_P(Test_ONNX_layers, DISABLED_LSTM)
{
testONNXModels("lstm", npy, 0, 0, false, false);
}
TEST_P(Test_ONNX_layers, LSTM_bidirectional)
// disabled due to poor handling of 1-d mats
TEST_P(Test_ONNX_layers, DISABLED_LSTM_bidirectional)
{
testONNXModels("lstm_bidirectional", npy, 0, 0, false, false);
}
@ -785,6 +787,13 @@ TEST_P(Test_ONNX_layers, LSTM_hidden_bidirectional)
testONNXModels("hidden_lstm_bi", npy, 0, 0, false, false);
}
TEST_P(Test_ONNX_layers, LSTM_cell)
{
testONNXModels("lstm_cell_forward", npy, 0, 0, false, false);
testONNXModels("lstm_cell_bidirectional", npy, 0, 0, false, false);
testONNXModels("lstm_cell_with_peepholes", npy, 0, 0, false, false);
}
TEST_P(Test_ONNX_layers, Pad2d_Unfused)
{
testONNXModels("ReflectionPad2d");

Loading…
Cancel
Save