Add a Where node with broadcasting support

pull/23485/head
zihaomu 2 years ago
parent 097891e311
commit 0513741a85
  1. 123
      modules/dnn/src/layers/nary_eltwise_layers.cpp
  2. 8
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  3. 1
      modules/dnn/src/onnx/onnx_importer.cpp
  4. 5
      modules/dnn/test/test_onnx_importer.cpp

@ -48,6 +48,7 @@ public:
SUM,
ADD,
DIV,
WHERE,
} op;
NaryEltwiseLayerImpl(const LayerParams& params)
@ -94,6 +95,8 @@ public:
op = OPERATION::OR;
else if (operation == "xor")
op = OPERATION::XOR;
else if (operation == "where")
op = OPERATION::WHERE;
else
CV_Error(cv::Error::StsBadArg, "Unknown operation type \"" + operation + "\"");
}
@ -499,6 +502,120 @@ public:
f, scale, ninputs, max_ndims, shapes[0], inp, out, (const size_t **) steps, ptrs);
}
template <typename T, typename Functor>
void trinary_forward(const Functor& f, const std::vector<Mat>& inputs, std::vector<Mat>& outputs)
{
const Mat& a = inputs[0];
const Mat& b = inputs[1];
const Mat& c = inputs[2];
Mat& out = outputs[0];
// collect info of inputs and output
const int* in_shape[] = {a.size.p, b.size.p, c.size.p};
const size_t* in_step[] = {a.step.p, b.step.p, c.step.p};
const int* out_shape = out.size.p;
const size_t* out_step = out.step.p;
const int in_ndims[] = {a.dims, b.dims, c.dims};
int out_ndims = out.dims;
int max_ndims = std::max(a.dims, std::max(b.dims, std::max(c.dims, out.dims)));
AutoBuffer<size_t> buf(4 * (2 * max_ndims + 6));
int** orig_shapes = (int**)(buf.data());
int** shapes = orig_shapes + 4;
size_t** orig_steps = (size_t**)(shapes + 4);
size_t** steps = orig_steps + 4;
int* shape_buf = (int*)(steps + 4);
size_t* step_buf = (size_t*)(shape_buf + 4 * max_ndims);
int* all_ndims = (int*)(step_buf + 4 * max_ndims);
size_t* all_type_sizes = (size_t*)(all_ndims + 4);
// assign orig_shapes, shapes, orig_steps, steps, all_ndims, all_type_sizes
for (int i = 0; i < 4; i++)
{
orig_shapes[i] = (int*)(i == 0 ? out_shape : in_shape[i-1]);
orig_steps[i] = (size_t*)(i == 0 ? out_step : in_step[i-1]);
shapes[i] = shape_buf + i * max_ndims;
steps[i] = step_buf + i * max_ndims;
all_ndims[i] = i == 0 ? out_ndims : in_ndims[i-1];
all_type_sizes[i] = sizeof(T);
}
if (!prepare_for_broadcast_op(4, max_ndims, all_type_sizes,
all_ndims, (const int**)orig_shapes,
(const size_t**)orig_steps,
shapes, steps))
return;
trinary_forward_impl<T, Functor>(
max_ndims, shapes[0], a.ptr<char>(), steps[1], b.ptr<char>(), steps[2],
c.ptr<char>(), steps[3], out.ptr<char>(), steps[0],
f);
}
// Element-wise kernel for a three-input operation over an ndims-dimensional
// iteration space.
//
//   ndims, shape   - number of dimensions (must be >= 2) and extent of each one
//   dataN, stepN   - base pointer and per-dimension steps of input N (N = 1..3)
//   data, step     - base pointer and per-dimension steps of the output
//   op             - callable invoked as op(v1, v2, v3); its result is stored as T
//
// Steps are byte offsets: they are added to char pointers for all outer
// dimensions, and the innermost step is divided by sizeof(T) to get an element
// stride. A step of 0 makes the same element be re-read along that dimension.
template <typename T, typename Functor>
void trinary_forward_impl(
    int ndims, const int* shape,
    const char* data1, const size_t* step1,
    const char* data2, const size_t* step2,
    const char* data3, const size_t* step3,
    char* data, const size_t* step,
    const Functor& op)
{
    assert(ndims >= 2);

    const int last = ndims - 1;   // innermost dimension
    const int prev = ndims - 2;   // row dimension

    // Innermost strides expressed in elements rather than bytes.
    const size_t stride1 = step1[last] / sizeof(T);
    const size_t stride2 = step2[last] / sizeof(T);
    const size_t stride3 = step3[last] / sizeof(T);
    const size_t stride_out = step[last] / sizeof(T);

    const int row_len = shape[last];
    const int row_count = shape[prev];

    // All dimensions except the last two are flattened into a single plane counter.
    size_t plane_count = 1;
    for (int d = 0; d < prev; ++d)
        plane_count *= shape[d];

    for (size_t plane = 0; plane < plane_count; ++plane)
    {
        const char* row1 = data1;
        const char* row2 = data2;
        const char* row3 = data3;
        char* row_out = data;

        // Decode the flat plane index into per-dimension coordinates
        // (last outer dimension varies fastest) and offset every base pointer.
        size_t rest = plane;
        for (int d = prev - 1; d >= 0; --d)
        {
            const size_t quotient = rest / shape[d];
            const int coord = (int)(rest - quotient * shape[d]);
            row1 += coord * step1[d];
            row2 += coord * step2[d];
            row3 += coord * step3[d];
            row_out += coord * step[d];
            rest = quotient;
        }

        for (int r = 0; r < row_count; ++r)
        {
            const T* src1 = (const T*)row1;
            const T* src2 = (const T*)row2;
            const T* src3 = (const T*)row3;
            T* dst = (T*)row_out;

            if (stride1 == 1 && stride2 == 1 && stride3 == 1 && stride_out == 1)
            {
                // Every row is densely packed: plain indexed loop.
                for (int i = 0; i < row_len; ++i)
                    dst[i] = op(src1[i], src2[i], src3[i]);
            }
            else
            {
                // At least one operand is strided or broadcast along the row.
                for (int i = 0; i < row_len; ++i)
                {
                    *dst = op(*src1, *src2, *src3);
                    src1 += stride1;
                    src2 += stride2;
                    src3 += stride3;
                    dst += stride_out;
                }
            }

            row1 += step1[prev];
            row2 += step2[prev];
            row3 += step3[prev];
            row_out += step[prev];
        }
    }
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
@ -637,6 +754,12 @@ public:
binary_forward<T>(op_xor, std::forward<Args>(args)...);
break;
}
case OPERATION::WHERE:
{
auto op_where = [](const T &a, const T &b, const T &c) { return a ? b : c; };
trinary_forward<T>(op_where, std::forward<Args>(args)...);
break;
}
default:
CV_Error(Error::StsBadArg, "Unsupported operation.");
};

@ -1168,8 +1168,12 @@ Mat getMatFromTensor(const opencv_onnx::TensorProto& tensor_proto)
else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
{
const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
CV_Assert(!field.empty());
char* val = (char *)field.data();
char* val = nullptr;
if (!field.empty())
val = (char *)field.data();
else
val = const_cast<char*>(tensor_proto.raw_data().c_str()); // sometime, the double will be stored at raw_data.
#if CV_STRONG_ALIGNMENT
// Aligned pointer is required.
AutoBuffer<double, 16> aligned_val;

@ -4058,6 +4058,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
dispatch["LessOrEqual"] = &ONNXImporter::parseElementWise;
dispatch["Sum"] = dispatch["Min"] = dispatch["Max"] = &ONNXImporter::parseElementWise;
dispatch["Where"] = &ONNXImporter::parseElementWise;
dispatch["Range"] = &ONNXImporter::parseRange;
std::vector<std::string> simpleLayers{"Acos", "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil", "Celu", "Cos",

@ -2492,6 +2492,11 @@ TEST_P(Test_ONNX_layers, OpenAI_CLIP_head)
testONNXModels("clip-vit-base-head");
}
// Runs the shared testONNXModels helper on the "where_layer" model.
// NOTE(review): presumably this is the test data for the ONNX Where operator
// added in this change - confirm the model exists in the test-data repository.
TEST_P(Test_ONNX_layers, where_node)
{
testONNXModels("where_layer");
}
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
}} // namespace

Loading…
Cancel
Save