diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 37af0ddea5..ae1d909660 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -1085,6 +1085,16 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<TileLayer> create(const LayerParams& params);
     };

+    class CV_EXPORTS LayerNormLayer : public Layer
+    {
+    public:
+        bool hasBias;
+        int axis;
+        float epsilon;
+
+        static Ptr<LayerNormLayer> create(const LayerParams& params);
+    };
+
     //! @}
     //! @}
 CV__DNN_INLINE_NS_END
diff --git a/modules/dnn/perf/perf_layer.cpp b/modules/dnn/perf/perf_layer.cpp
index ffe0240a18..38e35f1258 100644
--- a/modules/dnn/perf/perf_layer.cpp
+++ b/modules/dnn/perf/perf_layer.cpp
@@ -417,6 +417,212 @@ PERF_TEST_P_(Layer_ScatterND, DISABLED_ScatterND_add)
     test_layer({N, C, H , W}, "add");
 }

+struct Layer_LayerNorm : public TestBaseWithParam<tuple<Backend, Target> >
+{
+    void test_layer(const std::vector<int>& x_shape)
+    {
+        int backendId = get<0>(GetParam());
+        int targetId = get<1>(GetParam());
+
+        Mat x(x_shape, CV_32FC1);
+        Mat scale(x_shape.back(), 1, CV_32FC1);
+        Mat b(x_shape.back(), 1, CV_32FC1);
+
+        randu(x, 0.f, 1.f);
+        randu(scale, 0.f, 1.f);
+        randu(b, 0.f, 1.f);
+
+
+        Net net;
+        LayerParams lp;
+        lp.type = "LayerNormalization";
+        lp.name = "testLayer";
+        lp.set("axis", 2);
+        lp.set("hasBias", true);
+        int id = net.addLayerToPrev(lp.name, lp.type, lp);
+        net.connect(0, 0, id, 0);
+        net.connect(0, 1, id, 1);
+        net.connect(0, 2, id, 2);
+
+        // warmup
+        {
+            std::vector<std::string> inpNames(3);
+            inpNames[0] = "x";
+            inpNames[1] = "scale";
+            inpNames[2] = "b";
+            net.setInputsNames(inpNames);
+            net.setInput(x, inpNames[0]);
+            net.setInput(scale, inpNames[1]);
+            net.setInput(b, inpNames[2]);
+
+            net.setPreferableBackend(backendId);
+            net.setPreferableTarget(targetId);
+            Mat out = net.forward();
+        }
+
+        TEST_CYCLE()
+        {
+            Mat res = net.forward();
+        }
+
+        SANITY_CHECK_NOTHING();
+    }
+
+    int N = 1;
+    int H = 50;
+    int W = 768;
+};
+
+PERF_TEST_P_(Layer_LayerNorm, LayerNorm)
+{
+    test_layer({N, H, W});
+}
+
+struct Layer_LayerNormExpanded : public TestBaseWithParam<tuple<Backend, Target> >
+{
+    void test_layer(const std::vector<int>& x_shape)
+    {
+        int backendId = get<0>(GetParam());
+        int targetId = get<1>(GetParam());
+
+        Mat x(x_shape, CV_32FC1);
+        Mat scale(1, x_shape.back(), CV_32FC1); // transpose to pass shape check
+        Mat b(1, x_shape.back(), CV_32FC1);     // transpose to pass shape check
+
+        randu(x, 0.f, 1.f);
+        randu(scale, 0.f, 1.f);
+        randu(b, 0.f, 1.f);
+
+        // sub graph structure:
+        //   -> ReduceMean ->     -> Pow(2) -> ReduceMean -> Add(epsilon) -> Sqrt ->
+        // x                  Sub                                                    Div -> Mul(scale) -> Add(bias)
+        //   --------------->     ------------------------------------------------->
+
+        Net net;
+
+        LayerParams lp_rm;
+        lp_rm.type = "Reduce";
+        lp_rm.name = "reducemean1";
+        lp_rm.set("reduce", "AVE");
+        std::vector<int> deleteDims(1, x_shape.back());
+        lp_rm.set("deleted_dims", DictValue::arrayInt(&deleteDims[0], deleteDims.size()));
+        std::vector<int> targetDims(x_shape.begin(), x_shape.end());
+        targetDims[x_shape.size() - 1] = 1;
+        lp_rm.set("target_dims", DictValue::arrayInt(&targetDims[0], targetDims.size()));
+        int id_rm = net.addLayerToPrev(lp_rm.name, lp_rm.type, lp_rm);
+        net.connect(0, 0, id_rm, 0);
+
+        LayerParams lp_sub;
+        lp_sub.type = "NaryEltwise";
+        lp_sub.name = "sub1";
+        lp_sub.set("operation", "sub");
+        int id_sub = net.addLayer(lp_sub.name, lp_sub.type, lp_sub);
+        net.connect(0, 0, id_sub, 0);
+        net.connect(id_rm, 0, id_sub, 1);
+
+        Mat pow_const(1, 1, CV_32FC1);
+        pow_const.at<float>(0) = 2.f;
+        LayerParams lp_pow_const;
+        lp_pow_const.type = "Const";
+        lp_pow_const.name = "const1";
+        lp_pow_const.blobs.push_back(pow_const);
+        int id_pow_const = net.addLayer(lp_pow_const.name, lp_pow_const.type, lp_pow_const);
+        LayerParams lp_pow;
+        lp_pow.type = "NaryEltwise";
+        lp_pow.name = "pow1";
+        lp_pow.set("operation", "pow");
+        int id_pow = net.addLayer(lp_pow.name, lp_pow.type, lp_pow);
+        net.connect(id_sub, 0, id_pow, 0);
+        net.connect(id_pow_const, 0, id_pow, 1);
+
+        LayerParams lp_rm1;
+        lp_rm1.type = "Reduce";
+        lp_rm1.name = "reducemean2";
+        lp_rm1.set("reduce", "AVE");
+        lp_rm1.set("deleted_dims", DictValue::arrayInt(&deleteDims[0], deleteDims.size()));
+        lp_rm1.set("target_dims", DictValue::arrayInt(&targetDims[0], targetDims.size()));
+        int id_rm1 = net.addLayer(lp_rm1.name, lp_rm1.type, lp_rm1);
+        net.connect(id_pow, 0, id_rm1, 0);
+
+        Mat add_const(1, 1, CV_32F);
+        add_const.at<float>(0) = 1e-5f;
+        LayerParams lp_add_const;
+        lp_add_const.type = "Const";
+        lp_add_const.name = "const2";
+        lp_add_const.blobs.push_back(add_const);
+        int id_add_const = net.addLayer(lp_add_const.name, lp_add_const.type, lp_add_const);
+        LayerParams lp_add;
+        lp_add.type = "NaryEltwise";
+        lp_add.name = "add1";
+        lp_add.set("operation", "add");
+        int id_add = net.addLayer(lp_add.name, lp_add.type, lp_add);
+        net.connect(id_rm1, 0, id_add, 0);
+        net.connect(id_add_const, 0, id_add, 1);
+
+        LayerParams lp_sqrt;
+        lp_sqrt.type = "Sqrt";
+        lp_sqrt.name = "sqrt1";
+        int id_sqrt = net.addLayer(lp_sqrt.name, lp_sqrt.type, lp_sqrt);
+        net.connect(id_add, 0, id_sqrt, 0);
+
+        LayerParams lp_div;
+        lp_div.type = "NaryEltwise";
+        lp_div.name = "div1";
+        lp_div.set("operation", "div");
+        int id_div = net.addLayer(lp_div.name, lp_div.type, lp_div);
+        net.connect(id_sub, 0, id_div, 0);
+        net.connect(id_sqrt, 0, id_div, 1);
+
+        LayerParams lp_mul;
+        lp_mul.type = "NaryEltwise";
+        lp_mul.name = "mul1";
+        lp_mul.set("operation", "mul");
+        int id_mul = net.addLayer(lp_mul.name, lp_mul.type, lp_mul);
+        net.connect(id_div, 0, id_mul, 0);
+        net.connect(0, 1, id_mul, 1);
+
+        LayerParams lp_add1;
+        lp_add1.type = "NaryEltwise";
+        lp_add1.name = "add2";
+        lp_add1.set("operation", "add");
+        int id_add1 = net.addLayer(lp_add1.name, lp_add1.type, lp_add1);
+        net.connect(id_mul, 0, id_add1, 0);
+        net.connect(0, 2, id_add1, 1);
+
+        // warmup
+        {
+            std::vector<std::string> inpNames(3);
+            inpNames[0] = "x";
+            inpNames[1] = "scale";
+            inpNames[2] = "b";
+            net.setInputsNames(inpNames);
+            net.setInput(x, inpNames[0]);
+            net.setInput(scale, inpNames[1]);
+            net.setInput(b, inpNames[2]);
+
+            net.setPreferableBackend(backendId);
+            net.setPreferableTarget(targetId);
+            Mat out = net.forward();
+        }
+
+        TEST_CYCLE()
+        {
+            Mat res = net.forward();
+        }
+
+        SANITY_CHECK_NOTHING();
+    }
+
+    int N = 1;
+    int H = 50;
+    int W = 768;
+};
+
+PERF_TEST_P_(Layer_LayerNormExpanded, DISABLED_LayerNormExpanded)
+{
+    test_layer({N, H, W});
+}
+
 INSTANTIATE_TEST_CASE_P(/**/, Layer_Slice, dnnBackendsAndTargets(false, false));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_NaryEltwise, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 #ifdef HAVE_CUDA
@@ -424,5 +630,7 @@ INSTANTIATE_TEST_CASE_P(CUDA, Layer_NaryEltwise, testing::Values(std::make_tuple
 #endif
 INSTANTIATE_TEST_CASE_P(/**/, Layer_Scatter, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_ScatterND, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
+INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNorm, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
+INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNormExpanded, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));

 } // namespace
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index d20f9dff8d..72eca9ed4e 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -154,6 +154,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Arg, ArgLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Reciprocal, ReciprocalLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Gather, GatherLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(LayerNormalization, LayerNormLayer);

     CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
diff --git a/modules/dnn/src/layers/layer_norm.cpp b/modules/dnn/src/layers/layer_norm.cpp
new file mode 100644
index 0000000000..a760766a3f
--- /dev/null
+++ b/modules/dnn/src/layers/layer_norm.cpp
@@ -0,0 +1,176 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+namespace cv { namespace dnn {
+
+class LayerNormLayerImpl CV_FINAL : public LayerNormLayer
+{
+public:
+    LayerNormLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+
+        // standard ONNX attributes
+        axis = params.get<int>("axis", 0);
+        epsilon = params.get<float>("epsilon", 1e-5f);
+
+        // OpenCV-specific attribute
+        hasBias = params.get<bool>("hasBias", false);
+    }
+
+    virtual bool supportBackend(int backendId) CV_OVERRIDE
+    {
+        return backendId == DNN_BACKEND_OPENCV;
+    }
+
+    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                                 const int requiredOutputs,
+                                 std::vector<MatShape> &outputs,
+                                 std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        // check the shapes of weight and bias if they are given
+        // inputs >= 2 (x and weight are required; bias is optional)
+        CV_Check(inputs.size(), inputs.size() >= 2 && inputs.size() <= 3, "LayerNorm: require two (x, weight) or three (x, weight, bias) inputs");
+
+        auto x_shape = inputs[0];
+        int x_ndims = static_cast<int>(x_shape.size());
+
+        auto w_shape = inputs[1];
+        // if axis == last_dim, scale and b are both 1d tensors (represented as 2d Mats of shape nx1)
+        int w_ndims = static_cast<int>(w_shape.size());
+        w_ndims = (axis == x_ndims - 1 && w_ndims == 2) ? w_ndims - 1 : w_ndims;
+        CV_CheckEQ(x_ndims - axis, w_ndims, "LayerNorm: shape of weight does not match the given axis and the shape of input");
+        for (int i = 0; i < w_ndims; ++i)
+            CV_CheckEQ(x_shape[axis+i], w_shape[i], "LayerNorm: weight dimensions do not match input dimensions");
+        if (hasBias)
+        {
+            CV_CheckEQ(inputs.size(), (size_t)3, "");
+            auto b_shape = inputs[2];
+            CV_CheckEQ(w_shape.size(), b_shape.size(), "LayerNorm: shape of weight does not match shape of bias");
+            for (size_t i = 0; i < w_shape.size(); ++i)
+                CV_CheckEQ(w_shape[i], b_shape[i], "LayerNorm: bias dimensions do not match weight dimensions");
+        }
+
+        // only one output is needed; Mean & InvStdDev are not needed
+        // in inference and should be omitted in the ONNX importer
+        outputs.assign(1, inputs[0]);
+        return false;
+    }
+
+    template<bool hasBias>
+    class LayerNormInvoker : public ParallelLoopBody
+    {
+    public:
+        const Mat& src;
+        const float* scaleData;
+        const float* biasData;
+        Mat& dst;
+
+        float epsilon;
+
+        int total;
+        int normSize;
+        float invNormSize;
+
+        LayerNormInvoker(const Mat& src_, const Mat& scale, const Mat* b, Mat& dst_, int axis, float epsilon_)
+            : src(src_), scaleData(scale.ptr<const float>()), biasData(nullptr), dst(dst_), epsilon(epsilon_)
+        {
+            if (hasBias)
+            {
+                CV_Assert(b != nullptr);
+                CV_Assert(b->isContinuous());
+                biasData = b->ptr<const float>();
+            }
+
+            auto dstShape = shape(dst);
+            total = std::accumulate(dstShape.begin(), dstShape.begin() + axis, 1, std::multiplies<int>());
+            normSize = std::accumulate(dstShape.begin() + axis, dstShape.end(), 1, std::multiplies<int>());
+            invNormSize = 1.0f / normSize;
+        }
+
+        static void run(const Mat& src, const Mat& scale, const Mat* b, Mat& dst, int axis, float epsilon)
+        {
+            CV_Assert(src.isContinuous());
+            CV_Assert(dst.isContinuous());
+            CV_CheckTypeEQ(src.type(), CV_32F, "DNN/LayerNorm: only float32 is supported");
+            CV_CheckTypeEQ(src.type(), dst.type(), "");
+            CV_Assert(scale.isContinuous());
+
+            CV_CheckGE(epsilon, 0.0f, "");
+
+            LayerNormInvoker p(src, scale, b, dst, axis, epsilon);
+
+            double nstripes = ((size_t)p.total * p.normSize) * (1 / 1024.0);
+            // double nstripes = ((size_t)p.total) * (1 / 1024.0);
+            parallel_for_(Range(0, p.total), p, nstripes);
+        }
+
+        void operator()(const Range& r) const CV_OVERRIDE
+        {
+            int stripeStart = r.start;
+            int stripeEnd = r.end;
+
+            const float* srcData = src.ptr<const float>();
+            float* dstData = dst.ptr<float>();
+
+            for (int ofs = stripeStart; ofs < stripeEnd; ++ofs)
+            {
+                const float* first = srcData + ofs * normSize;
+                float* dstFirst = dstData + ofs * normSize;
+
+                float mean = 0;
+                float meanSquare = 0;
+                for (int h = 0; h < normSize; ++h)
+                {
+                    float v = first[h];
+                    mean += v;
+                    meanSquare += v * v;
+                }
+                mean *= invNormSize;
+                meanSquare = std::sqrt(std::max(0.f, meanSquare * invNormSize - mean * mean) + epsilon); // sqrt(var + epsilon), var = E[x^2] - E[x]^2
+                float invMeanSquare = 1.0f / meanSquare;
+                for (int h = 0; h < normSize; ++h)
+                {
+                    float v = (first[h] - mean) * invMeanSquare * scaleData[h];
+                    if (hasBias) {
+                        v = v + biasData[h];
+                    }
+                    dstFirst[h] = v;
+                }
+            }
+        }
+    };
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        if (inputs_arr.depth() == CV_16S)
+        {
+            forward_fallback(inputs_arr, outputs_arr, internals_arr);
+            return;
+        }
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        if (hasBias) {
+            LayerNormInvoker<true>::run(inputs[0], inputs[1], &inputs[2], outputs[0], axis, epsilon);
+        } else {
+            LayerNormInvoker<false>::run(inputs[0], inputs[1], nullptr, outputs[0], axis, epsilon);
+        }
+    }
+};
+
+Ptr<LayerNormLayer> LayerNormLayer::create(const LayerParams& params)
+{
+    return makePtr<LayerNormLayerImpl>(params);
+}
+
+}} // cv::dnn
diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
index 730c08b25c..c977a4761d 100644
--- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
+++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
@@ -75,6 +75,28 @@ public:
         return makePtr<ONNXNodeWrapper>(node);
     }

+    int getInputInitializerId(int node_id, int node_input_id)
+    {
+        auto node = getNode(node_id);
+        std::string node_input_name = node->getInputName(node_input_id);
+        for (int i = 0; i < numInitializers; ++i)
+            if (net.initializer(i).name() == node_input_name)
+                return i;
+        CV_Error(Error::StsParseError, "Initializer with name " + node_input_name + " not found");
+    }
+
+    Mat getMatFromInitializer(int idx)
+    {
+        const opencv_onnx::TensorProto& tensor_proto = net.initializer(idx);
+        return getMatFromTensor(tensor_proto);
+    }
+
+    std::string getNameOfInitializer(int idx) const
+    {
+        const opencv_onnx::TensorProto& tensor_proto = net.initializer(idx);
+        return tensor_proto.name();
+    }
+
     virtual int getNumNodes() const CV_OVERRIDE
     {
         return numInputs + numInitializers + net.node_size();
@@ -110,6 +132,142 @@ private:
     opencv_onnx::GraphProto& net;
 };

+class LayerNormSubGraph : public Subgraph
+{
+public:
+    LayerNormSubGraph() : axis(-1), epsilon(1e-5f)
+    {
+        //   -> ReduceMean ->     -> Pow(2) -> ReduceMean -> Add(epsilon) -> Sqrt ->
+        // x                  Sub                                                    Div -> Mul(scale) -> Add(bias)
+        //   --------------->     ------------------------------------------------->
+        // NOTE: the constants of Pow(2), Add(epsilon), Mul(scale) and Add(bias) can come either from Constant nodes or from initializers
+        int input = addNodeToMatch("");
+        int mean = addNodeToMatch("ReduceMean", input);
+
+        int sub = addNodeToMatch("Sub", input, mean);
+
+        int pow = addNodeToMatch("Pow", sub, addNodeToMatch(""));
+        int mean1 = addNodeToMatch("ReduceMean", pow);
+        int add = addNodeToMatch("Add", mean1, addNodeToMatch(""));
+        int sqrt = addNodeToMatch("Sqrt", add);
+
+        int div = addNodeToMatch("Div", sub, sqrt);
+        int mul = addNodeToMatch("Mul", div, addNodeToMatch(""));
+        addNodeToMatch("Add", mul, addNodeToMatch(""));
+
+        setFusedNode("LayerNormalization", input);
+    }
+
+    static bool isWithInitializer(const std::vector<int>& matchedNodesIds)
+    {
+        // if the constants come from Constant nodes, those nodes are matched in between the other nodes
+        if (matchedNodesIds[2] - matchedNodesIds[1] != 1)
+            return false;
+        // with initializers there are no constant nodes in between
+        return true;
+    }
+
+    static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
+    {
+        if (withInitializer)
+        {
+            auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
+            int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
+            Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
+            return *const_mat.ptr<float>();
+        } else {
+            const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
+            int constant_id = getInputNodeId(net, node, input_id);
+            Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
+            opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
+            opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
+            Mat constant_mat = getMatFromTensor(constant_proto);
+            return *constant_mat.ptr<float>();
+        }
+    }
+
+    static int extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id)
+    {
+        Ptr<ImportNodeWrapper> mean_ptr = net->getNode(node_id);
+        opencv_onnx::NodeProto* mean_node = mean_ptr.dynamicCast<ONNXNodeWrapper>()->node;
+        int axis_ = -1;
+        for (int i = 0; i < mean_node->attribute_size(); i++)
+        {
+            opencv_onnx::AttributeProto attr = mean_node->attribute(i);
+            if (attr.name() != "axes")
+                continue;
+            axis_ = static_cast<int>(attr.ints(0));
+        }
+        return axis_;
+    }
+
+    static std::string getInputName(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
+    {
+        if (withInitializer)
+        {
+            auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
+            int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
+            return onnx_net->getNameOfInitializer(initializer_id);
+        } else {
+            const auto node = net->getNode(node_id);
+            return node->getInputName(input_id);
+        }
+    }
+
+    virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
+                       std::vector<int>& matchedNodesIds,
+                       std::vector<int>& targetNodesIds) CV_OVERRIDE
+    {
+        if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
+        {
+            withInitializer = isWithInitializer(matchedNodesIds);
+
+            float pow_exp = extractConstant(net, matchedNodesIds[2], 1, withInitializer);
+            if (std::fabs(pow_exp - 2.f) > 1e-5f) // not Pow(2)
+                return false;
+
+            int axis_mean1 = extractAxis(net, matchedNodesIds[0]);
+            int axis_mean2 = extractAxis(net, matchedNodesIds[3]);
+            if (axis_mean1 != axis_mean2)
+                return false;
+            axis = axis_mean1;
+
+            epsilon = extractConstant(net, matchedNodesIds[4], 1, withInitializer);
+
+            weight_name = getInputName(net, matchedNodesIds[7], 1, withInitializer);
+            bias_name = getInputName(net, matchedNodesIds[8], 1, withInitializer);
+
+            return true;
+        }
+        return false;
+    }
+
+    virtual void finalize(const Ptr<ImportGraphWrapper>&,
+                          const Ptr<ImportNodeWrapper>& fusedNode,
+                          std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
+    {
+        opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
+        // axis
+        opencv_onnx::AttributeProto* attr_axis = node->add_attribute();
+        attr_axis->set_name("axis");
+        attr_axis->set_i(axis);
+        // epsilon
+        opencv_onnx::AttributeProto* attr_epsilon = node->add_attribute();
+        attr_epsilon->set_name("epsilon");
+        attr_epsilon->set_f(epsilon);
+        // add scale and bias as inputs
+        node->add_input(weight_name);
+        node->add_input(bias_name);
+    }
+
+protected:
+    int axis;
+    float epsilon;
+    bool withInitializer;
+    std::string weight_name;
+    std::string bias_name;
+};
+
 class SoftMaxSubgraphBase : public Subgraph
 {
 public:
@@ -746,6 +904,7 @@ public:
 void simplifySubgraphs(opencv_onnx::GraphProto& net)
 {
     std::vector<Ptr<Subgraph> > subgraphs;
+    subgraphs.push_back(makePtr<LayerNormSubGraph>());
     subgraphs.push_back(makePtr<GatherCastSubgraph>());
     subgraphs.push_back(makePtr<MulCastSubgraph>());
     subgraphs.push_back(makePtr<UpsampleSubgraph>());
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index fe4d4660f3..6fd40d0d16 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -190,6 +190,7 @@ private:
     void parseRange              (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseScatter            (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseTile               (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseLayerNorm          (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseSimpleLayers       (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);

     // Domain: com.microsoft
@@ -3285,6 +3286,56 @@ void ONNXImporter::parseTile(LayerParams& layerParams, const opencv_onnx::NodePr
     }
 }

+void ONNXImporter::parseLayerNorm(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    // validate axis and make it non-negative
+    auto inputDims = static_cast<int>(outShapes[node_proto.input(0)].size());
+    int axis = layerParams.get<int>("axis", -1);
+    // axis: [-dims, dims)
+    CV_CheckGE(axis, -inputDims, "DNN/ONNXImporter: axis of LayerNormalization is out of range");
+    CV_CheckLT(axis, inputDims, "DNN/ONNXImporter: axis of LayerNormalization is out of range");
+    axis = (axis + inputDims) % inputDims;
+    layerParams.set("axis", axis);
+
+    // check whether bias is given
+    bool hasBias = false;
+    if (node_proto.input_size() > 2)
+        hasBias = true;
+    layerParams.set("hasBias", hasBias);
+
+    // register constant inputs as Const layers
+    for (int i = 1; i < node_proto.input_size(); i++)
+    {
+        if (layer_id.find(node_proto.input(i)) == layer_id.end())
+        {
+            Mat blob = getBlob(node_proto, i);
+
+            LayerParams constParams;
+            constParams.name = node_proto.input(i);
+            constParams.type = "Const";
+            constParams.blobs.push_back(blob);
+
+            opencv_onnx::NodeProto proto;
+            proto.add_output(constParams.name);
+            addLayer(constParams, proto);
+        }
+    }
+
+    // remove the optional outputs (Mean, InvStdDev); they are not produced in inference
+    if (node_proto.output_size() > 1)
+    {
+        auto outputName = node_proto.output(0);
+        opencv_onnx::NodeProto node_proto_ = node_proto;
+        node_proto_.clear_output();
+        node_proto_.add_output(outputName);
+        addLayer(layerParams, node_proto_);
+    }
+    else
+    {
+        addLayer(layerParams, node_proto);
+    }
+}
+
 void ONNXImporter::parseSimpleLayers(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
 {
     bool is_all_input_const = true;
@@ -3987,6 +4038,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
     dispatch["SpaceToDepth"] = dispatch["DepthToSpace"] = &ONNXImporter::parseDepthToSpace;
     dispatch["ScatterElements"] = dispatch["Scatter"] = dispatch["ScatterND"] = &ONNXImporter::parseScatter;
     dispatch["Tile"] = &ONNXImporter::parseTile;
+    dispatch["LayerNormalization"] = &ONNXImporter::parseLayerNorm;
     dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] =
             dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] =
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index 0ffa252c71..ad2a3849d6 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -2416,6 +2416,36 @@ TEST_P(Test_ONNX_layers, Tile)
     testONNXModels("tile", pb);
 }

+TEST_P(Test_ONNX_layers, LayerNorm)
+{
+    testONNXModels("test_layer_normalization_2d_axis0", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_2d_axis1", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_2d_axis_negative_1", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_2d_axis_negative_2", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_3d_axis0_epsilon", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_3d_axis1_epsilon", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_3d_axis2_epsilon", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_3d_axis_negative_1_epsilon", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_3d_axis_negative_2_epsilon", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_3d_axis_negative_3_epsilon", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_4d_axis0", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_4d_axis1", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_4d_axis2", pb, 0, 0, false, true, 3);
+    testONNXModels("test_layer_normalization_4d_axis3", pb, 0, 0, false, true, 3);
testONNXModels("test_layer_normalization_4d_axis_negative_1", pb, 0, 0, false, true, 3); + testONNXModels("test_layer_normalization_4d_axis_negative_2", pb, 0, 0, false, true, 3); + testONNXModels("test_layer_normalization_4d_axis_negative_3", pb, 0, 0, false, true, 3); + testONNXModels("test_layer_normalization_4d_axis_negative_4", pb, 0, 0, false, true, 3); + testONNXModels("test_layer_normalization_default_axis", pb, 0, 0, false, true, 3); +} + +// for testing graph simplification +TEST_P(Test_ONNX_layers, LayerNormExpanded) +{ + testONNXModels("layer_norm_expanded"); + testONNXModels("layer_norm_expanded_with_initializers"); +} + INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets()); }} // namespace