From 0513741a857daa47cedd2eed8e76b8577cfcd7ec Mon Sep 17 00:00:00 2001
From: zihaomu <zihaomu@outlook.com>
Date: Fri, 5 May 2023 11:16:19 +0800
Subject: [PATCH] add broadcast where node

---
 .../dnn/src/layers/nary_eltwise_layers.cpp    | 123 ++++++++++++++++++
 .../dnn/src/onnx/onnx_graph_simplifier.cpp    |   8 +-
 modules/dnn/src/onnx/onnx_importer.cpp        |   1 +
 modules/dnn/test/test_onnx_importer.cpp       |   5 +
 4 files changed, 135 insertions(+), 2 deletions(-)

diff --git a/modules/dnn/src/layers/nary_eltwise_layers.cpp b/modules/dnn/src/layers/nary_eltwise_layers.cpp
index 280920af35..9ad6a61f10 100644
--- a/modules/dnn/src/layers/nary_eltwise_layers.cpp
+++ b/modules/dnn/src/layers/nary_eltwise_layers.cpp
@@ -48,6 +48,7 @@ public:
         SUM,
         ADD,
         DIV,
+        WHERE,
     } op;
 
     NaryEltwiseLayerImpl(const LayerParams& params)
@@ -94,6 +95,8 @@ public:
             op = OPERATION::OR;
         else if (operation == "xor")
             op = OPERATION::XOR;
+        else if (operation == "where")
+            op = OPERATION::WHERE;
         else
             CV_Error(cv::Error::StsBadArg, "Unknown operation type \"" + operation + "\"");
     }
@@ -499,6 +502,120 @@ public:
                       f, scale, ninputs, max_ndims, shapes[0], inp, out, (const size_t **) steps, ptrs);
     }
 
+    template <typename T, typename Functor>
+    void trinary_forward(const Functor& f, const std::vector<Mat>& inputs, std::vector<Mat>& outputs)
+    {
+        const Mat& a = inputs[0];
+        const Mat& b = inputs[1];
+        const Mat& c = inputs[2];
+        Mat& out = outputs[0];
+
+        // collect info of inputs and output
+        const int* in_shape[] = {a.size.p, b.size.p, c.size.p};
+        const size_t* in_step[] = {a.step.p, b.step.p, c.step.p};
+        const int* out_shape = out.size.p;
+        const size_t* out_step = out.step.p;
+        const int in_ndims[] = {a.dims, b.dims, c.dims};
+        int out_ndims = out.dims;
+
+        int max_ndims = std::max(a.dims, std::max(b.dims, std::max(c.dims, out.dims)));
+
+        AutoBuffer<size_t> buf(4 * (2 * max_ndims + 6));
+
+        int** orig_shapes = (int**)(buf.data());
+        int** shapes = orig_shapes + 4;
+        size_t** orig_steps = (size_t**)(shapes + 4);
+        size_t** steps = orig_steps + 4;
+
+        int* shape_buf = (int*)(steps + 4);
+        size_t* step_buf = (size_t*)(shape_buf + 4 * max_ndims);
+
+        int* all_ndims = (int*)(step_buf + 4 * max_ndims);
+        size_t* all_type_sizes = (size_t*)(all_ndims + 4);
+
+        // assign orig_shapes, shapes, orig_steps, steps, all_ndims, all_type_sizes
+        for (int i = 0; i < 4; i++)
+        {
+            orig_shapes[i] = (int*)(i == 0 ? out_shape : in_shape[i-1]);
+            orig_steps[i] = (size_t*)(i == 0 ? out_step : in_step[i-1]);
+            shapes[i] = shape_buf + i * max_ndims;
+            steps[i] = step_buf + i * max_ndims;
+            all_ndims[i] = i == 0 ? out_ndims : in_ndims[i-1];
+            all_type_sizes[i] = sizeof(T);
+        }
+
+        if (!prepare_for_broadcast_op(4, max_ndims, all_type_sizes,
+                                      all_ndims, (const int**)orig_shapes,
+                                      (const size_t**)orig_steps,
+                                      shapes, steps))
+            return;
+
+        trinary_forward_impl<T>(
+                max_ndims, shapes[0], a.ptr<char>(), steps[1], b.ptr<char>(), steps[2],
+                c.ptr<char>(), steps[3], out.ptr<char>(), steps[0],
+                f);
+    }
+
+    template <typename T, typename Functor>
+    void trinary_forward_impl(
+            int ndims, const int* shape,
+            const char* data1, const size_t* step1,
+            const char* data2, const size_t* step2,
+            const char* data3, const size_t* step3,
+            char* data, const size_t* step,
+            const Functor& op)
+    {
+        assert(ndims >= 2);
+        size_t dp1 = step1[ndims-1]/sizeof(T);
+        size_t dp2 = step2[ndims-1]/sizeof(T);
+        size_t dp3 = step3[ndims-1]/sizeof(T);
+        size_t dp = step[ndims-1]/sizeof(T);
+        int k, n1 = shape[ndims-1], n2 = shape[ndims-2];
+        size_t plane_idx, nplanes = 1;
+        for (k = 0; k < ndims-2; k++) nplanes *= shape[k];
+
+        for (plane_idx = 0; plane_idx < nplanes; plane_idx++)
+        {
+            const char* ptr1_ = data1;
+            const char* ptr2_ = data2;
+            const char* ptr3_ = data3;
+            char* ptr_ = data;
+            size_t idx = plane_idx;
+            for (k = ndims-3; k >= 0; k--)
+            {
+                size_t next_idx = idx/shape[k];
+                int i_k = (int)(idx - next_idx*shape[k]);
+                ptr1_ += i_k*step1[k];
+                ptr2_ += i_k*step2[k];
+                ptr3_ += i_k*step3[k];
+                ptr_ += i_k*step[k];
+                idx = next_idx;
+            }
+
+            for (int i2 = 0; i2 < n2; i2++, ptr1_ += step1[ndims-2],
+                                            ptr2_ += step2[ndims-2],
+                                            ptr3_ += step3[ndims-2],
+                                            ptr_ += step[ndims-2])
+            {
+                const T* ptr1 = (const T*)ptr1_;
+                const T* ptr2 = (const T*)ptr2_;
+                const T* ptr3 = (const T*)ptr3_;
+                T* ptr = (T*)ptr_;
+
+                if (dp1 == 1 && dp2 == 1 && dp3 == 1 && dp == 1)
+                {
+                    for(int i1 = 0; i1 < n1; i1++)
+                        ptr[i1] = op(ptr1[i1], ptr2[i1], ptr3[i1]);
+                }
+                else
+                {
+                    for(int i1 = 0; i1 < n1; i1++, ptr1 += dp1, ptr2 += dp2, ptr3 += dp3, ptr += dp)
+                        *ptr = op(*ptr1, *ptr2, *ptr3);
+                }
+            }
+        }
+    }
+
     void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
     {
         CV_TRACE_FUNCTION();
@@ -637,6 +754,12 @@ public:
                 binary_forward<T>(op_xor, std::forward<Args>(args)...);
                 break;
             }
+            case OPERATION::WHERE:
+            {
+                auto op_where = [](const T &a, const T &b, const T &c) { return a ? b : c; };
+                trinary_forward<T>(op_where, std::forward<Args>(args)...);
+                break;
+            }
             default:
                 CV_Error(Error::StsBadArg, "Unsupported operation.");
         };
diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
index d88b630e6f..c17b4fdb09 100644
--- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
+++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
@@ -1168,8 +1168,12 @@ Mat getMatFromTensor(const opencv_onnx::TensorProto& tensor_proto)
     else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
     {
         const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
-        CV_Assert(!field.empty());
-        char* val = (char *)field.data();
+        char* val = nullptr;
+        if (!field.empty())
+            val = (char *)field.data();
+        else
+            val = const_cast<char*>(tensor_proto.raw_data().c_str()); // sometimes, the double data is stored in raw_data.
+
 #if CV_STRONG_ALIGNMENT
         // Aligned pointer is required.
         AutoBuffer<double> aligned_val;
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index 651d1b1571..4460ca6f4b 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -4058,6 +4058,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
 
         dispatch["LessOrEqual"] = &ONNXImporter::parseElementWise;
         dispatch["Sum"] = dispatch["Min"] = dispatch["Max"] = &ONNXImporter::parseElementWise;
+        dispatch["Where"] = &ONNXImporter::parseElementWise;
         dispatch["Range"] = &ONNXImporter::parseRange;
 
         std::vector<std::string> simpleLayers{"Acos", "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil", "Celu", "Cos",
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index e566acd827..0f4ce7d7aa 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -2492,6 +2492,11 @@ TEST_P(Test_ONNX_layers, OpenAI_CLIP_head)
     testONNXModels("clip-vit-base-head");
 }
 
+TEST_P(Test_ONNX_layers, where_node)
+{
+    testONNXModels("where_layer");
+}
+
 INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
 
 }} // namespace