From 55e1dfb778b806272262871c3060ae32a4ff20d0 Mon Sep 17 00:00:00 2001
From: SamFC10
Date: Sun, 20 Jun 2021 13:19:29 +0530
Subject: [PATCH] Fix BatchNorm reinitialization

---
 modules/dnn/src/layers/batch_norm_layer.cpp | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/modules/dnn/src/layers/batch_norm_layer.cpp b/modules/dnn/src/layers/batch_norm_layer.cpp
index 27c3db6c44..42676c7938 100644
--- a/modules/dnn/src/layers/batch_norm_layer.cpp
+++ b/modules/dnn/src/layers/batch_norm_layer.cpp
@@ -29,6 +29,7 @@ namespace dnn
 class BatchNormLayerImpl CV_FINAL : public BatchNormLayer
 {
 public:
+    Mat origin_weights, origin_bias;
     Mat weights_, bias_;
     UMat umat_weight, umat_bias;
     mutable int dims;
@@ -82,11 +83,11 @@ public:
         const float* weightsData = hasWeights ? blobs[weightsBlobIndex].ptr<float>() : 0;
         const float* biasData = hasBias ? blobs[biasBlobIndex].ptr<float>() : 0;
 
-        weights_.create(1, (int)n, CV_32F);
-        bias_.create(1, (int)n, CV_32F);
+        origin_weights.create(1, (int)n, CV_32F);
+        origin_bias.create(1, (int)n, CV_32F);
 
-        float* dstWeightsData = weights_.ptr<float>();
-        float* dstBiasData = bias_.ptr<float>();
+        float* dstWeightsData = origin_weights.ptr<float>();
+        float* dstBiasData = origin_bias.ptr<float>();
 
         for (size_t i = 0; i < n; ++i)
         {
@@ -94,15 +95,12 @@ public:
             dstWeightsData[i] = w;
             dstBiasData[i] = (hasBias ? biasData[i] : 0.0f) - w * meanData[i] * varMeanScale;
         }
-        // We will use blobs to store origin weights and bias to restore them in case of reinitialization.
-        weights_.copyTo(blobs[0].reshape(1, 1));
-        bias_.copyTo(blobs[1].reshape(1, 1));
     }
 
     virtual void finalize(InputArrayOfArrays, OutputArrayOfArrays) CV_OVERRIDE
     {
-        blobs[0].reshape(1, 1).copyTo(weights_);
-        blobs[1].reshape(1, 1).copyTo(bias_);
+        origin_weights.reshape(1, 1).copyTo(weights_);
+        origin_bias.reshape(1, 1).copyTo(bias_);
    }
 
     void getScaleShift(Mat& scale, Mat& shift) const CV_OVERRIDE
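
Note (not part of the patch): the loop shown in the hunks folds the layer's
normalization statistics into one per-channel scale/shift pair, and the fix
keeps that folded pair in the new origin_weights/origin_bias members so that
finalize() can rebuild weights_/bias_ on reinitialization without writing the
folded values back into blobs[0]/blobs[1]. Below is a minimal, self-contained
C++ sketch of that folding. It is illustrative only: the diff shows the shift
term (bias - w * mean * varMeanScale) verbatim, but the computation of w and
the names gamma, beta, mean, var and eps are assumptions based on the standard
BatchNorm formula, not taken from this patch.

    // Sketch only: fold per-channel BatchNorm statistics into scale/shift,
    // mirroring the loop in the patch. gamma/beta/mean/var/eps are assumed
    // names and values; the real layer reads them from its blobs.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<float> gamma = {1.0f, 0.5f};   // learned scale (weights)
        std::vector<float> beta  = {0.0f, 0.1f};   // learned shift (bias)
        std::vector<float> mean  = {0.2f, -0.3f};  // running mean
        std::vector<float> var   = {1.5f, 0.9f};   // running variance
        const float varMeanScale = 1.0f;           // normalization factor, as in the patch
        const float eps = 1e-5f;                   // epsilon (assumed, not shown in the diff)

        for (size_t i = 0; i < gamma.size(); ++i)
        {
            // y = gamma * (x - mean) / sqrt(var + eps) + beta  ==  w * x + b
            float w = gamma[i] / std::sqrt(var[i] * varMeanScale + eps);
            float b = beta[i] - w * mean[i] * varMeanScale;
            std::printf("channel %zu: scale = %f, shift = %f\n", i, w, b);
        }
        return 0;
    }

Keeping the folded pair in dedicated members (rather than overwriting the
blobs, as the removed lines did) means the layer's original blobs stay intact
and finalize() can restore weights_/bias_ from origin_weights/origin_bias as
often as the network is reinitialized.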