|
|
|
@ -32,6 +32,16 @@ public: |
|
|
|
|
return exclusive_raw == 0; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void getTypes(const std::vector<MatType>& inputs, |
|
|
|
|
const int requiredOutputs, |
|
|
|
|
const int requiredInternals, |
|
|
|
|
std::vector<MatType>& outputs, |
|
|
|
|
std::vector<MatType>& internals) const CV_OVERRIDE |
|
|
|
|
{ |
|
|
|
|
CV_CheckType(inputs[0], inputs[0] == CV_32F || inputs[0] == CV_32S || inputs[0] == CV_64S || inputs[0] == CV_16F, ""); |
|
|
|
|
outputs.assign(1, inputs[0]); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE |
|
|
|
|
{ |
|
|
|
|
CV_TRACE_FUNCTION(); |
|
|
|
@ -47,9 +57,30 @@ public: |
|
|
|
|
inputs_arr.getMatVector(inputs); |
|
|
|
|
outputs_arr.getMatVector(outputs); |
|
|
|
|
|
|
|
|
|
CV_CheckTypeEQ(inputs[0].depth(), outputs[0].depth(), ""); |
|
|
|
|
|
|
|
|
|
switch(inputs[0].depth()) |
|
|
|
|
{ |
|
|
|
|
case CV_32F: |
|
|
|
|
forwardImpl<float>(inputs, outputs); |
|
|
|
|
break; |
|
|
|
|
case CV_32S: |
|
|
|
|
forwardImpl<int32_t>(inputs, outputs); |
|
|
|
|
break; |
|
|
|
|
case CV_64S: |
|
|
|
|
forwardImpl<int64_t>(inputs, outputs); |
|
|
|
|
break; |
|
|
|
|
default: |
|
|
|
|
CV_Error(Error::BadDepth, ""); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
template <typename T> |
|
|
|
|
void forwardImpl(const std::vector<Mat>& inputs, std::vector<Mat>& outputs) |
|
|
|
|
{ |
|
|
|
|
// Get input tensor.
|
|
|
|
|
const auto& src_mat = inputs[0]; |
|
|
|
|
const auto* src_ptr = src_mat.ptr<float>(); |
|
|
|
|
const T* src_ptr = src_mat.ptr<T>(); |
|
|
|
|
|
|
|
|
|
// Get target axis.
|
|
|
|
|
int axis = inputs.size() > 1 ? parseAxis(inputs[1]) : axis_raw; |
|
|
|
@ -58,7 +89,7 @@ public: |
|
|
|
|
|
|
|
|
|
// Get output tensor.
|
|
|
|
|
auto& dst_mat = outputs[0]; |
|
|
|
|
auto* dst_ptr = dst_mat.ptr<float>(); |
|
|
|
|
T* dst_ptr = dst_mat.ptr<T>(); |
|
|
|
|
|
|
|
|
|
// Get flags.
|
|
|
|
|
const auto exclusive = exclusive_raw == 1; |
|
|
|
@ -89,7 +120,7 @@ public: |
|
|
|
|
size_t first_inner_offset = target_offset + target_start * inner_size; |
|
|
|
|
if (exclusive) |
|
|
|
|
for (size_t inner_idx = 0; inner_idx < inner_size; inner_idx++) |
|
|
|
|
dst_ptr[first_inner_offset + inner_idx] = 0.0f; |
|
|
|
|
dst_ptr[first_inner_offset + inner_idx] = 0; |
|
|
|
|
else |
|
|
|
|
for (size_t inner_idx = 0; inner_idx < inner_size; inner_idx++) |
|
|
|
|
dst_ptr[first_inner_offset + inner_idx] = src_ptr[first_inner_offset + inner_idx]; |
|
|
|
|