|
|
|
@@ -244,6 +244,10 @@ static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protob
     return DictValue::arrayInt(&dst[0], src.size());
 }
 
+static DictValue parseStr(const ::google::protobuf::RepeatedPtrField< ::std::string>& src) {
+    return DictValue::arrayString(src.begin(), static_cast<int>(src.size()));
+}
+
 LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
 {
     LayerParams lp;
@@ -301,6 +305,10 @@ LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_prot
             CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
             lp.set("dilation", parse(attribute_proto.ints()));
         }
+        else if(attribute_name == "activations" && node_proto.op_type() == "LSTM")
+        {
+            lp.set(attribute_name, parseStr(attribute_proto.strings()));
+        }
         else if (attribute_proto.has_i())
         {
             ::google::protobuf::int64 src = attribute_proto.i();
|
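Context for the two hunks above: parseStr wraps a repeated protobuf string field into a DictValue string array, which is how the ONNX LSTM "activations" attribute (per the ONNX spec the default is Sigmoid, Tanh, Tanh per direction) gets stored on the layer parameters. Below is a minimal sketch of the same DictValue::arrayString call, fed from a plain std::vector instead of a RepeatedPtrField so it compiles on its own; it is an illustration, not part of the patch.

// Illustrative sketch only: mirrors what parseStr() produces, but takes the
// strings from a std::vector rather than a protobuf RepeatedPtrField.
#include <opencv2/dnn.hpp>
#include <iostream>
#include <string>
#include <vector>

int main()
{
    using cv::dnn::DictValue;

    // A typical value of the LSTM "activations" attribute for one direction.
    std::vector<std::string> acts = {"Sigmoid", "Tanh", "Tanh"};

    // Same call the patch makes: arrayString(begin_iterator, count).
    DictValue dv = DictValue::arrayString(acts.begin(), static_cast<int>(acts.size()));

    for (int i = 0; i < dv.size(); ++i)
        std::cout << dv.get<cv::String>(i) << std::endl;  // prints Sigmoid, Tanh, Tanh
    return 0;
}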
|
@@ -997,18 +1005,32 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr
     lstmParams.name += "/lstm";
 
     // https://pytorch.org/docs/stable/nn.html#lstm
-    CV_Assert(node_proto.input_size() == 7);
+    CV_Assert(node_proto.input_size() >= 7);
     Mat Wx = getBlob(node_proto, 1);
     Mat Wh = getBlob(node_proto, 2);
     Mat b = getBlob(node_proto, 3);
-    Mat h0 = getBlob(node_proto, 5);
-    Mat c0 = getBlob(node_proto, 6);
-
-    b = b.reshape(1, b.size[0]);
 
     const int numHidden = lstmParams.get<int>("hidden_size");
     const int numDirs = Wx.size[0];  // Is 1 for forward only and 2 for bidirectional LSTM.
     const int numFeatures = Wx.size[2];
+
+    Mat h0, c0;
+    if (!node_proto.input(5).empty()) {
+        h0 = getBlob(node_proto, 5);
+        h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
+    } else {
+        // initial_h attribute can be empty in case of keras2onnx producer. fill it with zeros
+        h0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1);
+    }
+    if (!node_proto.input(6).empty()) {
+        c0 = getBlob(node_proto, 6);
+        c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
+    } else {
+        // initial_c attribute can be empty in case of keras2onnx producer. fill it with zeros
+        c0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1);
+    }
+
+    b = b.reshape(1, b.size[0]);
     Mat bx = b.colRange(0, b.cols / 2);
     Mat bh = b.colRange(b.cols / 2, b.cols);
     b = bx + bh;
|
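Note on the hunk above: inputs 5 and 6 (initial_h and initial_c) become optional because some exporters, keras2onnx among them, leave those input names empty, and the importer then falls back to zero-filled state matrices. A minimal sketch of that fallback pattern on a bare cv::Mat follows; getInitialState and its arguments are hypothetical stand-ins, not code from the patch.

// Hypothetical helper illustrating the "use the blob if present, else zeros" fallback.
#include <opencv2/core.hpp>

// Returns the provided state reshaped to 2D, or a zero matrix when the ONNX
// input slot was left empty by the exporter.
static cv::Mat getInitialState(const cv::Mat& provided, int rows, int cols)
{
    if (!provided.empty())
        return provided.reshape(1, provided.size[0] * provided.size[1]);  // e.g. [numDirs, N, H] -> 2D
    return cv::Mat::zeros(rows, cols, CV_32FC1);                          // same default type as the patch
}

int main()
{
    cv::Mat missing;                                   // exporter omitted initial_h
    cv::Mat h0 = getInitialState(missing, 2 * 16, 8);  // rows and cols are made-up sizes
    CV_Assert(cv::countNonZero(h0) == 0);              // zero-initialized state
    return 0;
}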
|
@@ -1036,8 +1058,7 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr
     }
     Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
     Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
-    h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
-    c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
+
 
     lstmParams.blobs.resize(5);
     lstmParams.blobs[0] = Wh;
@@ -1045,6 +1066,9 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr
     lstmParams.blobs[2] = b;
     lstmParams.blobs[3] = h0;
     lstmParams.blobs[4] = c0;
 
+    // read direction attribute
+    lstmParams.set("reverse", lstmParams.get<String>("direction", "") == "reverse");
+    lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
 
     node_proto.set_output(0, lstmParams.name);  // set different name so output shapes will be registered on that name
|
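On the direction handling added in the last hunk: the ONNX direction attribute ("forward", "reverse" or "bidirectional") is translated into the boolean reverse/bidirectional entries that the OpenCV LSTM layer reads from its LayerParams. A standalone sketch of that mapping on a bare LayerParams object follows; the attribute value is a made-up example, not taken from the patch.

// Sketch of the direction-attribute mapping, outside the importer.
#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    cv::dnn::LayerParams lp;
    lp.set("direction", "bidirectional");  // as it would arrive from the ONNX attribute

    // Same expressions the patch uses; a missing attribute defaults to "".
    lp.set("reverse", lp.get<cv::String>("direction", "") == "reverse");
    lp.set("bidirectional", lp.get<cv::String>("direction", "") == "bidirectional");

    std::cout << lp.get<bool>("reverse") << " " << lp.get<bool>("bidirectional") << std::endl;  // 0 1
    return 0;
}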
|
|
|