Merge pull request #21372 from zihaomu:dnn_quantize_per_tensor

Add per_tensor_quantize to int8 quantize

* add per_tensor_quantize to dnn int8 module.

* change api flag from perTensor to perChannel, and recognize quantize type and onnx importer.

* change the default to hpp
pull/22216/head
Zihao Mu 2 years ago committed by GitHub
parent 16b5fd4bf2
commit a80fcacd90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      modules/dnn/include/opencv2/dnn/all_layers.hpp
  2. 4
      modules/dnn/include/opencv2/dnn/dnn.hpp
  3. 1
      modules/dnn/src/int8layers/convolution_layer.cpp
  4. 2
      modules/dnn/src/int8layers/fully_connected_layer.cpp
  5. 34
      modules/dnn/src/layers/convolution_layer.cpp
  6. 36
      modules/dnn/src/layers/fully_connected_layer.cpp
  7. 11
      modules/dnn/src/layers/layers_common.cpp
  8. 5
      modules/dnn/src/layers/layers_common.hpp
  9. 4
      modules/dnn/src/net.cpp
  10. 2
      modules/dnn/src/net_impl.hpp
  11. 6
      modules/dnn/src/net_quantization.cpp
  12. 4
      modules/dnn/src/onnx/onnx_importer.cpp
  13. 87
      modules/dnn/test/test_int8_layers.cpp

@ -263,6 +263,10 @@ CV__DNN_INLINE_NS_BEGIN
public:
int input_zp, output_zp;
float input_sc, output_sc;
// quantization type flag. The perChannel default is true, that means it contains the parameters
// of per-Channel quantization. Otherwise, that means this layer contains per-Tensor quantized parameters.
bool per_channel;
static Ptr<BaseConvolutionLayer> create(const LayerParams& params);
};
@ -368,6 +372,10 @@ CV__DNN_INLINE_NS_BEGIN
public:
int input_zp, output_zp;
float input_sc, output_sc;
// quantization type flag. The perChannel default is true, that means it contains the parameters
// of per-Channel quantization. Otherwise, that means this layer contains per-Tensor quantized parameters.
bool per_channel;
static Ptr<InnerProductLayerInt8> create(const LayerParams& params);
};

@ -621,8 +621,10 @@ CV__DNN_INLINE_NS_BEGIN
* @param calibData Calibration data to compute the quantization parameters.
* @param inputsDtype Datatype of quantized net's inputs. Can be CV_32F or CV_8S.
* @param outputsDtype Datatype of quantized net's outputs. Can be CV_32F or CV_8S.
* @param perChannel Quantization granularity of quantized Net. The default is true, that means quantize model
* in per-channel way (channel-wise). Set it false to quantize model in per-tensor way (or tensor-wise).
*/
CV_WRAP Net quantize(InputArrayOfArrays calibData, int inputsDtype, int outputsDtype);
CV_WRAP Net quantize(InputArrayOfArrays calibData, int inputsDtype, int outputsDtype, bool perChannel=true);
/** @brief Returns input scale and zeropoint for a quantized Net.
* @param scales output parameter for returning input scales.

@ -51,6 +51,7 @@ public:
input_zp = params.get<int>("input_zeropoint");
output_zp = params.get<int>("zeropoints");
output_sc = params.get<float>("scales");
per_channel = params.get<bool>("per_channel", true);
if (kernel_size.size() == 2) {
kernel = Size(kernel_size[1], kernel_size[0]);

@ -26,6 +26,8 @@ public:
output_zp = params.get<int>("zeropoints");
output_sc = params.get<float>("scales");
axis = params.get<int>("axis", 1);
per_channel = params.get<bool>("per_channel", true);
if (blobs.size() == 3)
{
// blobs[0] - Weights

@ -2226,26 +2226,36 @@ public:
Mat weightsQuantized(weightsMat.rows, weightsMat.cols, CV_8S);
Mat biasQuantized(1, numOutput, CV_32S);
Mat outputMultiplier(1, numOutput, CV_32F);
double realMin, realMax, weightsScale;
bool perChannel = params.get<bool>("per_channel", true);
for( int i = 0; i < numOutput; i++ )
if (perChannel) // per-Channel quantization.
{
// Quantize weights
cv::minMaxIdx(weightsMat.row(i), &realMin, &realMax);
realMin = std::min(realMin, 0.0);
realMax = std::max(realMax, 0.0);
weightsScale = (realMax == realMin) ? 1.0 : std::max(-realMin, realMax)/127;
weightsMat.row(i).convertTo(weightsQuantized.row(i), CV_8S, 1.f/weightsScale);
for (int i = 0; i < numOutput; i++)
{
double weightsScale = getWeightScale(weightsMat.row(i));
// Quantize biases
weightsMat.row(i).convertTo(weightsQuantized.row(i), CV_8S, 1.f/weightsScale);
float biasScale = inputScale * weightsScale;
biasQuantized.at<int>(i) = cvRound(biasvec[i]/biasScale) - inputZp*(cv::sum(weightsQuantized.row(i))[0]);
outputMultiplier.at<float>(i) = biasScale / outputScale;
}
}
else // per-Tensor quantization.
{
double weightsScale = getWeightScale(weightsMat);
weightsMat.convertTo(weightsQuantized, CV_8S, 1.f/weightsScale);
float biasScale = inputScale * weightsScale;
biasQuantized.at<int>(i) = (int)std::round(biasvec[i]/biasScale) - inputZp*(cv::sum(weightsQuantized.row(i))[0]);
// Store multiplier
outputMultiplier.at<float>(i) = biasScale / outputScale;
for (int i = 0; i < numOutput; i++)
{
biasQuantized.at<int>(i) = cvRound(biasvec[i]/biasScale) - inputZp*(cv::sum(weightsQuantized.row(i))[0]);
outputMultiplier.at<float>(i) = biasScale / outputScale;
}
}
params.blobs.clear();
params.set("per_channel", perChannel);
params.blobs.push_back(weightsQuantized.reshape(1, shape(blobs[0])));
params.blobs.push_back(biasQuantized);
params.blobs.push_back(outputMultiplier);

@ -619,26 +619,36 @@ public:
Mat weightsQuantized(weightsMat.rows, weightsMat.cols, CV_8S);
Mat biasQuantized(1, numOutput, CV_32S);
Mat outputMultiplier(1, numOutput, CV_32F);
bool perChannel = params.get<bool>("per_channel", true);
double realMin, realMax, weightsScale;
for( int i = 0; i < numOutput; i++ )
if (perChannel) // per-Channel quantization.
{
// Quantize weights
cv::minMaxIdx(weightsMat.row(i), &realMin, &realMax);
realMin = std::min(realMin, 0.0);
realMax = std::max(realMax, 0.0);
weightsScale = (realMax == realMin) ? 1.0 : std::max(-realMin, realMax)/127;
weightsMat.row(i).convertTo(weightsQuantized.row(i), CV_8S, 1.f/weightsScale);
// Quantize biases
for (int i = 0; i < numOutput; i++)
{
double weightsScale = getWeightScale(weightsMat.row(i));
weightsMat.row(i).convertTo(weightsQuantized.row(i), CV_8S, 1.f/weightsScale);
float biasScale = inputScale * weightsScale;
biasQuantized.at<int>(i) = cvRound(biasMat.at<float>(i)/biasScale) - inputZp*(cv::sum(weightsQuantized.row(i))[0]);
outputMultiplier.at<float>(i) = biasScale / outputScale;
}
}
else // per-Tensor quantization.
{
double weightsScale = getWeightScale(weightsMat);
weightsMat.convertTo(weightsQuantized, CV_8S, 1.f/weightsScale);
float biasScale = inputScale * weightsScale;
biasQuantized.at<int>(i) = (int)std::round(biasMat.at<float>(i)/biasScale) - inputZp*(cv::sum(weightsQuantized.row(i))[0]);
// Store multiplier
outputMultiplier.at<float>(i) = biasScale / outputScale;
for (int i = 0; i < numOutput; i++)
{
biasQuantized.at<int>(i) = cvRound(biasMat.at<float>(i)/biasScale) - inputZp*(cv::sum(weightsQuantized.row(i))[0]);
outputMultiplier.at<float>(i) = biasScale / outputScale;
}
}
params.blobs.clear();
params.set("per_channel", perChannel);
params.blobs.push_back(weightsQuantized.reshape(1, shape(blobs[0])));
params.blobs.push_back(biasQuantized);
params.blobs.push_back(outputMultiplier);

@ -250,5 +250,16 @@ void getConvPoolPaddings(const std::vector<int>& inp, const std::vector<size_t>&
}
}
double getWeightScale(const Mat& weightsMat)
{
double realMin, realMax;
cv::minMaxIdx(weightsMat, &realMin, &realMax);
realMin = std::min(realMin, 0.0);
realMax = std::max(realMax, 0.0);
return (realMax == realMin) ? 1.0 : std::max(-realMin, realMax)/127;
}
}
}

@ -70,9 +70,12 @@ void getConvPoolOutParams(const std::vector<int>& inp, const std::vector<size_t>
const std::vector<size_t>& stride, const String &padMode,
const std::vector<size_t>& dilation, std::vector<int>& out);
void getConvPoolPaddings(const std::vector<int>& inp, const std::vector<size_t>& kernel,
void getConvPoolPaddings(const std::vector<int>& inp, const std::vector<size_t>& kernel,
const std::vector<size_t>& strides, const String &padMode,
std::vector<size_t>& pads_begin, std::vector<size_t>& pads_end);
// Used in quantized model. It will return the (Max_element - Min_element)/127.
double getWeightScale(const Mat& weightsMat);
}
}

@ -115,12 +115,12 @@ void Net::forward(std::vector<std::vector<Mat>>& outputBlobs,
}
// FIXIT drop from inference API
Net Net::quantize(InputArrayOfArrays calibData, int inputsDtype, int outputsDtype)
Net Net::quantize(InputArrayOfArrays calibData, int inputsDtype, int outputsDtype, bool perChannel)
{
CV_TRACE_FUNCTION();
CV_Assert(impl);
CV_Assert(!empty());
return impl->quantize(calibData, inputsDtype, outputsDtype);
return impl->quantize(calibData, inputsDtype, outputsDtype, perChannel);
}
// FIXIT drop from inference API

@ -258,7 +258,7 @@ struct Net::Impl : public detail::NetImplBase
void dumpNetworkToFile() const;
// FIXIT drop from inference API
Net quantize(InputArrayOfArrays calibData, int inputsDtype, int outputsDtype) /*const*/;
Net quantize(InputArrayOfArrays calibData, int inputsDtype, int outputsDtype, bool perChannel) /*const*/;
void getInputDetails(std::vector<float>& scales, std::vector<int>& zeropoints) /*const*/;
void getOutputDetails(std::vector<float>& scales, std::vector<int>& zeropoints) /*const*/;

@ -33,7 +33,7 @@ void getQuantizationParams(const Mat& src, std::vector<float>& scales, std::vect
}
// FIXIT drop from inference API
Net Net::Impl::quantize(InputArrayOfArrays calibData, int inputsDtype, int outputsDtype)
Net Net::Impl::quantize(InputArrayOfArrays calibData, int inputsDtype, int outputsDtype, bool perChannel)
{
// Net can be quantized only once.
if (netWasQuantized)
@ -192,6 +192,10 @@ Net Net::Impl::quantize(InputArrayOfArrays calibData, int inputsDtype, int outpu
inp_out_sc[1] = scales[ld.id];
inp_out_zp[1] = zeropoints[ld.id];
// Set the quantization type, per-tensor quantize or per-channel quantize.
// Especially for Convolution layer and Fully connection layer.
ld.params.set("per_channel", perChannel);
// Quantize layer
Ptr<Layer> layer = ld.layerInstance;
if (layer->tryQuantize(inp_out_sc, inp_out_zp, ld.params))

@ -3401,6 +3401,7 @@ void ONNXImporter::parseQConv(LayerParams& layerParams, const opencv_onnx::NodeP
int outCn = weights.size[0];
Mat w_scale = getBlob(node_proto, 4);
CV_Assert(w_scale.total() == 1 || w_scale.total() == outCn);
bool per_channel = w_scale.total() == outCn ? true : false;
Mat wt_sc = (w_scale.total() == outCn) ? w_scale : Mat(1, outCn, CV_32F, Scalar(w_scale.at<float>(0)));
Mat out_sc = getBlob(node_proto, 6);
@ -3419,6 +3420,7 @@ void ONNXImporter::parseQConv(LayerParams& layerParams, const opencv_onnx::NodeP
layerParams.set("num_output", outCn);
layerParams.set("input_zeropoint", inp_zp.at<int8_t>(0));
layerParams.set("input_scale",inp_sc.at<float>(0));
layerParams.set("per_channel", per_channel);
layerParams.blobs.push_back(weights);
layerParams.blobs.push_back(biasFused);
layerParams.blobs.push_back(outputMultiplier);
@ -3444,6 +3446,7 @@ void ONNXImporter::parseQMatMul(LayerParams& layerParams, const opencv_onnx::Nod
Mat w_scale = getBlob(node_proto, 4);
CV_Assert(w_scale.total() == 1 || w_scale.total() == outCn);
bool per_channel = w_scale.total() == outCn ? true : false;
Mat wt_sc = (w_scale.total() == outCn) ? w_scale : Mat(1, outCn, CV_32F, Scalar(w_scale.at<float>(0)));
Mat out_sc = getBlob(node_proto, 6);
@ -3460,6 +3463,7 @@ void ONNXImporter::parseQMatMul(LayerParams& layerParams, const opencv_onnx::Nod
layerParams.set("axis", firstInpDims - secondInpDims + 1);
layerParams.set("input_scale", inp_sc.at<float>(0));
layerParams.set("input_zeropoint", inp_zp.at<int8_t>(0));
layerParams.set("per_channel", per_channel);
layerParams.blobs.push_back(weights);
layerParams.blobs.push_back(bias);

@ -29,7 +29,7 @@ class Test_Int8_layers : public DNNTestLayer
public:
void testLayer(const String& basename, const String& importer, double l1, double lInf,
int numInps = 1, int numOuts = 1, bool useCaffeModel = false,
bool useCommonInputBlob = true, bool hasText = false)
bool useCommonInputBlob = true, bool hasText = false, bool perChannel = true)
{
CV_Assert_N(numInps >= 1, numInps <= 10, numOuts >= 1, numOuts <= 10);
std::vector<Mat> inps(numInps), inps_int8(numInps);
@ -75,7 +75,7 @@ public:
for (int i = 0; i < numOuts; i++)
refs[i] = blobFromNPY(outPath + ((numOuts > 1) ? cv::format("_%d.npy", i) : ".npy"));
qnet = net.quantize(inps, CV_8S, CV_8S);
qnet = net.quantize(inps, CV_8S, CV_8S, perChannel);
qnet.getInputDetails(inputScale, inputZp);
qnet.getOutputDetails(outputScale, outputZp);
@ -103,6 +103,12 @@ TEST_P(Test_Int8_layers, Convolution1D)
{
testLayer("conv1d", "ONNX", 0.00302, 0.00909);
testLayer("conv1d_bias", "ONNX", 0.00306, 0.00948);
{
SCOPED_TRACE("Per-tensor quantize");
testLayer("conv1d", "ONNX", 0.00302, 0.00909, 1, 1, false, true, false, false);
testLayer("conv1d_bias", "ONNX", 0.00319, 0.00948, 1, 1, false, true, false, false);
}
}
TEST_P(Test_Int8_layers, Convolution2D)
@ -130,6 +136,18 @@ TEST_P(Test_Int8_layers, Convolution2D)
applyTestTag(CV_TEST_TAG_DNN_SKIP_TIMVX);
testLayer("layer_convolution", "Caffe", 0.0174, 0.0758, 1, 1, true);
testLayer("depthwise_conv2d", "TensorFlow", 0.0388, 0.169);
{
SCOPED_TRACE("Per-tensor quantize");
testLayer("single_conv", "TensorFlow", 0.00413, 0.02301, 1, 1, false, true, false, false);
testLayer("atrous_conv2d_valid", "TensorFlow", 0.027967, 0.07808, 1, 1, false, true, false, false);
testLayer("atrous_conv2d_same", "TensorFlow", 0.01945, 0.1322, 1, 1, false, true, false, false);
testLayer("keras_atrous_conv2d_same", "TensorFlow", 0.005677, 0.03327, 1, 1, false, true, false, false);
testLayer("convolution", "ONNX", 0.00538, 0.01517, 1, 1, false, true, false, false);
testLayer("two_convolution", "ONNX", 0.00295, 0.00926, 1, 1, false, true, false, false);
testLayer("layer_convolution", "Caffe", 0.0175, 0.0759, 1, 1, true, true, false, false);
testLayer("depthwise_conv2d", "TensorFlow", 0.041847, 0.18744, 1, 1, false, true, false, false);
}
}
TEST_P(Test_Int8_layers, Convolution3D)
@ -144,6 +162,13 @@ TEST_P(Test_Int8_layers, Flatten)
testLayer("flatten", "TensorFlow", 0.0036, 0.0069, 1, 1, false, true, true);
testLayer("unfused_flatten", "TensorFlow", 0.0014, 0.0028);
testLayer("unfused_flatten_unknown_batch", "TensorFlow", 0.0043, 0.0051);
{
SCOPED_TRACE("Per-tensor quantize");
testLayer("conv3d", "TensorFlow", 0.00734, 0.02434, 1, 1, false, true, false, false);
testLayer("conv3d", "ONNX", 0.00377, 0.01362, 1, 1, false, true, false, false);
testLayer("conv3d_bias", "ONNX", 0.00201, 0.0039, 1, 1, false, true, false, false);
}
}
TEST_P(Test_Int8_layers, Padding)
@ -349,6 +374,20 @@ TEST_P(Test_Int8_layers, InnerProduct)
testLayer("constant", "ONNX", 0.00021, 0.0006);
testLayer("lin_with_constant", "ONNX", 0.0011, 0.0016);
{
SCOPED_TRACE("Per-tensor quantize");
testLayer("layer_inner_product", "Caffe", 0.0055, 0.02, 1, 1, true, true, false, false);
testLayer("matmul", "TensorFlow", 0.0075, 0.019, 1, 1, false, true, false, false);
testLayer("nhwc_transpose_reshape_matmul", "TensorFlow", 0.0009, 0.0091, 1, 1, false, true, false, false);
testLayer("nhwc_reshape_matmul", "TensorFlow", 0.037, 0.071, 1, 1, false, true, false, false);
testLayer("matmul_layout", "TensorFlow", 0.035, 0.095, 1, 1, false, true, false, false);
testLayer("tf2_dense", "TensorFlow", 0, 0, 1, 1, false, true, false, false);
testLayer("matmul_add", "ONNX", 0.041, 0.082, 1, 1, false, true, false, false);
testLayer("linear", "ONNX", 0.0022, 0.004, 1, 1, false, true, false, false);
testLayer("constant", "ONNX", 0.00038, 0.0012, 1, 1, false, true, false, false);
testLayer("lin_with_constant", "ONNX", 0.0011, 0.0016, 1, 1, false, true, false, false);
}
}
TEST_P(Test_Int8_layers, Reshape)
@ -465,9 +504,9 @@ INSTANTIATE_TEST_CASE_P(/**/, Test_Int8_layers, dnnBackendsAndTargetsInt8());
class Test_Int8_nets : public DNNTestLayer
{
public:
void testClassificationNet(Net baseNet, const Mat& blob, const Mat& ref, double l1, double lInf)
void testClassificationNet(Net baseNet, const Mat& blob, const Mat& ref, double l1, double lInf, bool perChannel = true)
{
Net qnet = baseNet.quantize(blob, CV_32F, CV_32F);
Net qnet = baseNet.quantize(blob, CV_32F, CV_32F, perChannel);
qnet.setPreferableBackend(backend);
qnet.setPreferableTarget(target);
@ -477,9 +516,9 @@ public:
}
void testDetectionNet(Net baseNet, const Mat& blob, const Mat& ref,
double confThreshold, double scoreDiff, double iouDiff)
double confThreshold, double scoreDiff, double iouDiff, bool perChannel = true)
{
Net qnet = baseNet.quantize(blob, CV_32F, CV_32F);
Net qnet = baseNet.quantize(blob, CV_32F, CV_32F, perChannel);
qnet.setPreferableBackend(backend);
qnet.setPreferableTarget(target);
@ -488,14 +527,14 @@ public:
normAssertDetections(ref, out, "", confThreshold, scoreDiff, iouDiff);
}
void testFaster(Net baseNet, const Mat& ref, double confThreshold, double scoreDiff, double iouDiff)
void testFaster(Net baseNet, const Mat& ref, double confThreshold, double scoreDiff, double iouDiff, bool perChannel = true)
{
Mat inp = imread(_tf("dog416.png"));
resize(inp, inp, Size(800, 600));
Mat blob = blobFromImage(inp, 1.0, Size(), Scalar(102.9801, 115.9465, 122.7717), false, false);
Mat imInfo = (Mat_<float>(1, 3) << inp.rows, inp.cols, 1.6f);
Net qnet = baseNet.quantize(std::vector<Mat>{blob, imInfo}, CV_32F, CV_32F);
Net qnet = baseNet.quantize(std::vector<Mat>{blob, imInfo}, CV_32F, CV_32F, perChannel);
qnet.setPreferableBackend(backend);
qnet.setPreferableTarget(target);
@ -505,7 +544,7 @@ public:
normAssertDetections(ref, out, "", confThreshold, scoreDiff, iouDiff);
}
void testONNXNet(const String& basename, double l1, double lInf, bool useSoftmax = false)
void testONNXNet(const String& basename, double l1, double lInf, bool useSoftmax = false, bool perChannel = true)
{
String onnxmodel = findDataFile("dnn/onnx/models/" + basename + ".onnx", false);
@ -515,7 +554,7 @@ public:
baseNet.setPreferableBackend(backend);
baseNet.setPreferableTarget(target);
Net qnet = baseNet.quantize(blob, CV_32F, CV_32F);
Net qnet = baseNet.quantize(blob, CV_32F, CV_32F, perChannel);
qnet.setInput(blob);
Mat out = qnet.forward();
@ -538,7 +577,7 @@ public:
void testDarknetModel(const std::string& cfg, const std::string& weights,
const cv::Mat& ref, double scoreDiff, double iouDiff,
float confThreshold = 0.24, float nmsThreshold = 0.4)
float confThreshold = 0.24, float nmsThreshold = 0.4, bool perChannel = true)
{
CV_Assert(ref.cols == 7);
std::vector<std::vector<int> > refClassIds;
@ -578,7 +617,7 @@ public:
Mat inp = blobFromImages(samples, 1.0/255, Size(416, 416), Scalar(), true, false);
Net baseNet = readNetFromDarknet(findDataFile("dnn/" + cfg), findDataFile("dnn/" + weights, false));
Net qnet = baseNet.quantize(inp, CV_32F, CV_32F);
Net qnet = baseNet.quantize(inp, CV_32F, CV_32F, perChannel);
qnet.setPreferableBackend(backend);
qnet.setPreferableTarget(target);
qnet.setInput(inp);
@ -720,6 +759,11 @@ TEST_P(Test_Int8_nets, ResNet50)
float l1 = 3e-4, lInf = 0.05;
testClassificationNet(net, blob, ref, l1, lInf);
{
SCOPED_TRACE("Per-tensor quantize");
testClassificationNet(net, blob, ref, l1, lInf, false);
}
}
TEST_P(Test_Int8_nets, DenseNet121)
@ -954,6 +998,11 @@ TEST_P(Test_Int8_nets, EfficientDet)
float confThreshold = 0.65, scoreDiff = 0.3, iouDiff = 0.18;
testDetectionNet(net, blob, ref, confThreshold, scoreDiff, iouDiff);
{
SCOPED_TRACE("Per-tensor quantize");
testDetectionNet(net, blob, ref, 0.85, scoreDiff, iouDiff, false);
}
}
TEST_P(Test_Int8_nets, FasterRCNN_resnet50)
@ -1147,11 +1196,20 @@ TEST_P(Test_Int8_nets, TinyYoloVoc)
{
SCOPED_TRACE("batch size 1");
testDarknetModel(config_file, weights_file, ref.rowRange(0, 2), scoreDiff, iouDiff);
{
SCOPED_TRACE("Per-tensor quantize");
testDarknetModel(config_file, weights_file, ref.rowRange(0, 2), 0.1, 0.2, 0.24, 0.6, false);
}
}
{
SCOPED_TRACE("batch size 2");
testDarknetModel(config_file, weights_file, ref, scoreDiff, iouDiff);
{
SCOPED_TRACE("Per-tensor quantize");
testDarknetModel(config_file, weights_file, ref, 0.1, 0.2, 0.24, 0.6, false);
}
}
}
@ -1269,6 +1327,11 @@ TEST_P(Test_Int8_nets, YOLOv4_tiny)
{
SCOPED_TRACE("batch size 1");
testDarknetModel(config_file, weights_file, ref.rowRange(0, N0), scoreDiff, iouDiff, confThreshold);
{
SCOPED_TRACE("Per-tensor quantize");
testDarknetModel(config_file, weights_file, ref.rowRange(0, N0), scoreDiff, 0.16, 0.7, 0.4, false);
}
}
throw SkipTestException("batch2: bad accuracy on second image");

Loading…
Cancel
Save