@ -244,6 +244,10 @@ static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protob |
return DictValue::arrayInt(&dst[0], src.size()); |
} |
// Packs a protobuf repeated string field into a DictValue string array.
// The element count is taken from the field itself and the strings are
// read in order starting at src.begin().
static DictValue parseStr(const ::google::protobuf::RepeatedPtrField< ::std::string>& src)
{
    const int numStrings = static_cast<int>(src.size());
    return DictValue::arrayString(src.begin(), numStrings);
}
LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto) |
{ |
LayerParams lp; |
@ -301,6 +305,10 @@ LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_prot |
CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3); |
lp.set("dilation", parse(attribute_proto.ints())); |
} |
else if(attribute_name == "activations" && node_proto.op_type() == "LSTM") |
{ |
lp.set(attribute_name, parseStr(attribute_proto.strings())); |
} |
else if (attribute_proto.has_i()) |
{ |
::google::protobuf::int64 src = attribute_proto.i(); |
@ -997,18 +1005,32 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr |
lstmParams.name += "/lstm"; |
// https://pytorch.org/docs/stable/nn.html#lstm
CV_Assert(node_proto.input_size() == 7); |
CV_Assert(node_proto.input_size() >= 7); |
Mat Wx = getBlob(node_proto, 1); |
Mat Wh = getBlob(node_proto, 2); |
Mat b = getBlob(node_proto, 3); |
Mat h0 = getBlob(node_proto, 5); |
Mat c0 = getBlob(node_proto, 6); |
b = b.reshape(1, b.size[0]); |
const int numHidden = lstmParams.get<int>("hidden_size"); |
const int numDirs = Wx.size[0]; // Is 1 for forward only and 2 for bidirectional LSTM.
const int numFeatures = Wx.size[2]; |
Mat h0, c0; |
if (!node_proto.input(5).empty()) { |
h0 = getBlob(node_proto, 5); |
h0 = h0.reshape(1, h0.size[0] * h0.size[1]); |
} else { |
// The initial_h input may be left empty (e.g. by the keras2onnx producer); fill it with zeros.
h0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1); |
} |
if (!node_proto.input(6).empty()) { |
c0 = getBlob(node_proto, 6); |
c0 = c0.reshape(1, c0.size[0] * c0.size[1]); |
} else { |
// The initial_c input may be left empty (e.g. by the keras2onnx producer); fill it with zeros.
c0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1); |
} |
b = b.reshape(1, b.size[0]); |
Mat bx = b.colRange(0, b.cols / 2); |
Mat bh = b.colRange(b.cols / 2, b.cols); |
b = bx + bh; |
@ -1036,8 +1058,7 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr |
} |
Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]); |
Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]); |
h0 = h0.reshape(1, h0.size[0] * h0.size[1]); |
c0 = c0.reshape(1, c0.size[0] * c0.size[1]); |
lstmParams.blobs.resize(5); |
lstmParams.blobs[0] = Wh; |
@ -1045,6 +1066,9 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr |
lstmParams.blobs[2] = b; |
lstmParams.blobs[3] = h0; |
lstmParams.blobs[4] = c0; |
// read direction attribute
lstmParams.set("reverse", lstmParams.get<String>("direction", "") == "reverse"); |
lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional"); |
node_proto.set_output(0, lstmParams.name); // set different name so output shapes will be registered on that name