diff --git a/modules/dnn/src/layers/batch_norm_layer.cpp b/modules/dnn/src/layers/batch_norm_layer.cpp index d42face4ec..3b472328c8 100644 --- a/modules/dnn/src/layers/batch_norm_layer.cpp +++ b/modules/dnn/src/layers/batch_norm_layer.cpp @@ -96,6 +96,46 @@ public: shift = bias_; } + virtual bool tryFuse(Ptr& top) CV_OVERRIDE + { + Mat w, b; + top->getScaleShift(w, b); + if (w.empty() && b.empty()) + return false; + + const int numChannels = weights_.total(); + const int numFusedWeights = w.total(); + const int numFusedBias = b.total(); + + if ((numFusedWeights != numChannels && numFusedWeights != 1 && !w.empty()) || + (numFusedBias != numChannels && numFusedBias != 1 && !b.empty())) + return false; + + if (!w.empty()) + { + w = w.reshape(1, 1); + if (numFusedWeights == 1) + { + multiply(weights_, w.at(0), weights_); + multiply(bias_, w.at(0), bias_); + } + else + { + multiply(weights_, w, weights_); + multiply(bias_, w, bias_); + } + } + if (!b.empty()) + { + b = b.reshape(1, 1); + if (numFusedBias == 1) + add(bias_, b.at(0), bias_); + else + add(bias_, b.reshape(1, 1), bias_); + } + return true; + } + bool getMemoryShapes(const std::vector &inputs, const int requiredOutputs, std::vector &outputs,