diff --git a/modules/dnn/src/layers/einsum_layer.cpp b/modules/dnn/src/layers/einsum_layer.cpp
index 2cfb36da13..baf4297c0e 100644
--- a/modules/dnn/src/layers/einsum_layer.cpp
+++ b/modules/dnn/src/layers/einsum_layer.cpp
@@ -32,15 +32,14 @@ static bool IsTransposeReshapeForEinsum(const std::vector<size_t>& perm,
     return true;
 }
 
-Mat batchwiseMatMul(
+static Mat batchwiseMatMul(
     const Mat& input1,
     const MatShape& input1ShapeOverride,
     const Mat& input2,
     const MatShape& input2ShapeOverride)
 {
     // Sanity checks before the actual MatMul
-    //input_1.DataType() == input_2.DataType(), "Data types of the inputs must match for MatMul");
-
+    CV_CheckType(input1.type(), input2.type(), "Data types of the inputs must match for MatMul");
     CV_CheckEQ(input1ShapeOverride.size(), (size_t) 3, "Only 1 batch dimension is allowed for MatMul");
     CV_CheckEQ(input2ShapeOverride.size(), (size_t) 3, "Only 1 batch dimension is allowed for MatMul");
     CV_CheckEQ((size_t) input1ShapeOverride[0], (size_t) input2ShapeOverride[0], "Batch dimension should match for MatMul;");
@@ -51,8 +50,6 @@ Mat batchwiseMatMul(
     size_t K = input1ShapeOverride[2];
     size_t N = input2ShapeOverride[2];
 
-    //TODO: deal with dynamic shapes
-    //TODO: deal with reshaping operation (it might not always be needed)
     std::vector<Mat> output;
     if (batches > 1)
     {
@@ -141,26 +138,19 @@ Mat batchwiseMatMul(
     return output_buffer;
 };
 
-Mat Transpose(
-    const cv::Mat& input,
+static Mat Transpose(
+    const Mat& input,
     const MatShape& input_shape_override,
     const std::vector<size_t> permutation)
 {
     int input_rank = input_shape_override.size();
-
     CV_Assert(input_rank == permutation.size());
 
-    // TODO: ouptimize
-    bool reshape = false;
-    if (input.dims != input_shape_override.size())
-    {
-        reshape = true;
-    }
+    bool reshape = input.dims != input_rank;
 
     Mat input_reshaped;
-    if(reshape)
-    {
+    if(reshape){
         input_reshaped = input.reshape(1, input_shape_override.size(), input_shape_override.data());
     }
@@ -170,13 +160,9 @@ Mat Transpose(
         outputDims.emplace_back(input_shape_override[dim]);
 
     Mat output;
-    // TODO: ouptimize
-    MatShape tmp_perm;
-    tmp_perm.reserve(permutation.size());
-    for (int i = 0; i < permutation.size(); i++)
-        tmp_perm.emplace_back(static_cast<int>(permutation[i]));
+    MatShape order(permutation.begin(), permutation.end());
 
-    cv::transposeND((reshape ? input_reshaped : input), tmp_perm, output);
+    cv::transposeND((reshape ? input_reshaped : input), order, output);
     return output;
 }
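For reviewers, this is the contraction batchwiseMatMul performs once the shape overrides are validated: with overrides [B, M, K] and [B, K, N], each batch slice is an independent 2D GEMM. A minimal standalone sketch, not the layer's exact code (the helper name naiveBatchMatMul and the contiguous CV_32F assumption are mine; the actual implementation also has a batches == 1 fast path and reuses buffers):

```cpp
#include <opencv2/core.hpp>

// out[b] = a[b] * b[b] for blobs viewable as [B, M, K] and [B, K, N].
static cv::Mat naiveBatchMatMul(const cv::Mat& a, const cv::Mat& b,
                                int B, int M, int K, int N)
{
    cv::Mat out(std::vector<int>{B, M, N}, CV_32F);
    for (int i = 0; i < B; ++i)
    {
        // Wrap each batch slice as a 2D Mat header over the same data (no copy).
        cv::Mat a2d(M, K, CV_32F, const_cast<float*>(a.ptr<float>()) + (size_t)i * M * K);
        cv::Mat b2d(K, N, CV_32F, const_cast<float*>(b.ptr<float>()) + (size_t)i * K * N);
        cv::Mat o2d(M, N, CV_32F, out.ptr<float>() + (size_t)i * M * N);
        cv::gemm(a2d, b2d, 1.0, cv::noArray(), 0.0, o2d); // o2d = a2d * b2d
    }
    return out;
}
```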
@@ -201,12 +187,183 @@ bool IsTransposeRequired(size_t input_rank, const std::vector<size_t>& permutation)
     return transpose_required;
 }
 
-Mat Diagonal(
-    const cv::Mat& input,
-    int subscriptIndicesToInputIndex,
-    int dimIndexInIreprocessedInput)
+
+bool IsTransposeRequiredForDiagonal(int dim1, int dim2, int rank) {
+    // If the input is 2D, we don't need a transpose
+    if (rank == 2)
+        return false;
+
+    // If the two dims are the innermost dims, no transpose is required
+    if ((dim1 == rank - 1 && dim2 == rank - 2) ||
+        (dim1 == rank - 2 && dim2 == rank - 1))
+        return false;
+
+    // Transpose is required
+    return true;
+}
+
+template <typename T>
+Mat DiagonalDataAssignment(Mat input) {
+
+    int rank = input.dims;
+    CV_Assert(rank >= 2);
+    CV_Assert(input.size[rank - 1] == input.size[rank - 2]);
+    MatShape original_dims = shape(input);
+
+    if (rank > 3){
+        // reshape to a 3D mat
+        int collapsed_size = 1;
+        for (int i = 0; i < rank - 2; ++i) {
+            collapsed_size *= input.size[i];
+        }
+        std::vector<int> reshaped_dims = {collapsed_size, input.size[rank - 2], input.size[rank - 1]};
+        input = input.reshape(1, reshaped_dims);
+    }
+
+    // Compute total number of higher-dimensional slices
+    int total_slices = input.size[0];
+
+    original_dims[rank - 1] = 1; // Set the last dimension to 1, as we have extracted the diagonal
+    Mat output = Mat(original_dims, input.type());
+
+    int inner_stride = input.size[input.dims - 1];
+    auto inputPtr = input.ptr<T>();
+    auto outputPtr = output.ptr<T>();
+    for (int slice = 0; slice < total_slices; ++slice) {
+        for (int j = 0; j < inner_stride; ++j) {
+            // Direct memory access using raw pointers
+            outputPtr[slice * inner_stride + j] = inputPtr[slice * inner_stride * inner_stride + j * inner_stride + j];
+        }
+    }
+    return output;
+}
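A quick check of the flat-offset arithmetic in the assignment loop above, using the collapsed [slices, n, n] layout (values are illustrative):

```cpp
#include <cassert>

// Element (s, j, j) of a [slices, n, n] input maps to flat offset s*n*n + j*n + j;
// the matching diagonal element (s, j) of the output maps to s*n + j.
static void diagonalOffsetCheck()
{
    int n = 3, s = 1, j = 2;              // hypothetical sizes/indices
    assert(s * n * n + j * n + j == 17);  // input offset of element (1, 2, 2)
    assert(s * n + j == 5);               // output offset of element (1, 2)
}
```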
+
+/* Extract the diagonal elements from the last two dimensions of the tensor.
+For instance, given an input_shape of [1, 2, 3, 3]:
+
+The flexibility in this implementation allows one to choose which of the two
+last dimensions retains its value, determined by the `preserve_innermost_dim_val` parameter.
+
+When preserve_innermost_dim_val == true:
+    The resulting shape is [1, 2, 1, 3], indicating the diagonal has 3 elements,
+    and it keeps the dimension value of the innermost dimension.
+
+When preserve_innermost_dim_val == false:
+    The resulting shape is [1, 2, 3, 1], indicating the diagonal also has 3 elements,
+    but it retains the dimension value of the penultimate dimension. */
+Mat DiagonalInnermostDims(const Mat& input, bool preserve_innermost_dim_val) {
+    const MatShape input_dims = shape(input);
+    int rank = input_dims.size();
+
+    // This is an internal method and all validations have already been done in the calling method,
+    // so we proceed without duplicating them here.
+
+    // We have a minimalistic check here to make sure the innermost dims have the same dim value,
+    // as the calling method may have done a transpose before calling this method
+    CV_CheckEQ(input.size[rank - 1], input.size[rank - 2],
+               "innermost dims should have the same dim value to parse the diagonal elements");
+
+    MatShape output_dims = input_dims; // Copy the original dims
+    if (preserve_innermost_dim_val) {
+        output_dims[rank - 2] = 1;
+    } else {
+        output_dims[rank - 1] = 1;
+    }
+
+    // TODO: handle different types
+    Mat output = DiagonalDataAssignment<float>(input);
+
+    if (output_dims != shape(output)){
+        CV_Error(Error::StsError, "Output shape does not match with calculated shape");
+    }
+    return output;
+}
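One detail worth calling out before the Diagonal implementation that follows: undoing a transpose only requires inverting the index mapping of the permutation. A tiny self-contained illustration (the helper name and values are hypothetical):

```cpp
#include <vector>

// If perm moves source axis perm[i] into destination axis i, the inverse
// permutation inv, defined by inv[perm[i]] = i, moves every axis back.
static std::vector<size_t> inversePermutation(const std::vector<size_t>& perm)
{
    std::vector<size_t> inv(perm.size());
    for (size_t i = 0; i < perm.size(); ++i)
        inv[perm[i]] = i;
    return inv;
}
// inversePermutation({2, 0, 1}) == {1, 2, 0}
```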
+Mat Diagonal(const Mat& input, int dim1, int dim2)
 {
-    CV_Error(Error::StsNotImplemented, "Diagonal Not Implemented Yet");
+    const MatShape input_dims = shape(input);
+    int rank = input_dims.size();
+
+    if (!(rank >= 2 && dim1 != dim2 && input_dims[dim1] == input_dims[dim2])){
+        std::string input_dims_str = std::accumulate(std::next(input_dims.begin()), input_dims.end(), std::to_string(input_dims[0]),
+                                        [](const std::string& a, int b) {
+                                            return a + ' ' + std::to_string(b);
+                                        });
+        CV_Error(Error::StsError, cv::format("Cannot parse the diagonal elements along dims %d and %d for input shape %s", dim1, dim2, input_dims_str.c_str()));
+    }
+
+    int first_dim = std::min(dim1, dim2);
+    int second_dim = std::max(dim1, dim2);
+
+    Mat output;
+    bool preserve_innermost_dim_val = false;
+
+    bool is_transpose_required = IsTransposeRequiredForDiagonal(dim1, dim2, rank);
+    if (is_transpose_required)
+    {
+        std::vector<size_t> permutation(rank, 0);
+        int first_dim_axis = -1; // This is the axis eventually occupied by the first_dim
+
+        // If one of the diagonal dimensions is one of the 2 innermost dims, then leave it as such
+        // so as to avoid transpose overhead
+        if (first_dim == rank - 2) { // If rank - 2 is occupied by first_dim, keep it there
+            permutation[rank - 2] = first_dim;
+            first_dim_axis = rank - 2;
+        } else {
+            if (second_dim != rank - 2) { // If rank - 2 is not occupied by second_dim, then put first_dim there
+                permutation[rank - 2] = first_dim;
+                first_dim_axis = rank - 2;
+            } else { // If rank - 2 is occupied by second_dim, then put first_dim in rank - 1
+                permutation[rank - 1] = first_dim;
+                first_dim_axis = rank - 1;
+                preserve_innermost_dim_val = true; // We always want to preserve the dim value of the first_dim
+            }
+        }
+
+        // Put the second_dim in the dim not occupied by the first_dim
+        if (first_dim_axis != rank - 1) {
+            permutation[rank - 1] = second_dim;
+        } else {
+            permutation[rank - 2] = second_dim;
+        }
+
+        size_t iter = 0;
+        for (int i = 0; i < rank; ++i) {
+            if (i != first_dim && i != second_dim) {
+                permutation[iter++] = i;
+            }
+        }
+
+        // Permute the input so that the dims from which we need the diagonal form the innermost dims
+        Mat transposed = Transpose(input, input_dims, permutation);
+
+        // Parse the diagonal from the innermost dims
+        output = DiagonalInnermostDims(transposed, preserve_innermost_dim_val);
+
+        // Swap back the dimensions to the original axes ordering using a "reverse permutation"
+        // Find the "reverse" permutation
+        iter = 0;
+        std::vector<size_t> reverse_permutation(rank, 0);
+        for (const auto& perm : permutation) {
+            reverse_permutation[perm] = iter++;
+        }
+
+        // Permute using the reverse permutation to get back the original axes ordering
+        // (Pass in the CPU Transpose function here, as this Diagonal method is only used for CPU-based diagonal parsing)
+        output = Transpose(output, shape(output), reverse_permutation);
+    } else {
+        // No transposing required
+        output = DiagonalInnermostDims(input, preserve_innermost_dim_val);
+    }
+
+    // Make a copy of the output dims
+    MatShape output_dims = shape(output);
+
+    // Squeeze out the reduced dim
+    auto iter = output_dims.begin() + second_dim;
+    output_dims.erase(iter);
+    output = output.reshape(1, output_dims);
+
+    return output;
 }
 
 /**
@@ -299,7 +456,7 @@ public:
     void parseEquation(String equation);
     void processEquation(const std::vector<MatShape>& inputs);
     void processBroadcastedDims();
-    void createOutputSubsctipt();
+    void validateOutputSubscript();
     void calculateOutputShape();
     void preProcessInputs(InputArrayOfArrays& inputs);
     Mat reduceSum(Mat& src, MatShape& reduceAxis);
@@ -358,7 +515,7 @@ public:
         processBroadcastedDims();
 
         // calculate output shape
-        createOutputSubsctipt();
+        validateOutputSubscript();
         calculateOutputShape();
     }
@@ -624,7 +781,7 @@ void LayerEinsumImpl::calculateOutputShape()
 {
     // Traverse through each of the subscript labels within the output subscript.
     bool middleOfEllipsis = false;
-    // int64_t ellipsisCharCount = 0;
+    int ellipsisCharCount = 0;
 
     subscriptIndicesToOutputIndices.resize(numLetterIndices, -1);
@@ -636,7 +793,21 @@ void LayerEinsumImpl::calculateOutputShape()
 
         if(letter == '.')
         {
-            CV_Error(Error::StsNotImplemented, "Ellipsis are not supported yet");
+            middleOfEllipsis = true;
+            // Make sure there aren't more than 3 '.'s in the current subscript
+            if (++ellipsisCharCount > 3) {
+                CV_Error(Error::StsError, "Found a '.' not part of an ellipsis in the output subscript provided");
+            }
+
+            if (ellipsisCharCount == 3) { // Ellipsis is complete. Process it.
+                middleOfEllipsis = false;
+                for (size_t i = 0; i < numOfEllipsisDims; ++i) {
+                    einsumOutDims.emplace_back(subscriptIndicesToDimValue[i]);
+                    // The ellipsis is seen in the output, hence the corresponding dims are not to be reduced
+                    subscriptIndicesToLastInput[i] = -1;
+                    subscriptIndicesToOutputIndices[i] = outputDimCounter++;
+                }
+            }
         } else {
             CV_CheckEQ(middleOfEllipsis, false, "Encountered '.' character that is not part of output subscript");
@@ -666,7 +837,7 @@ void LayerEinsumImpl::calculateOutputShape()
     }
 }
 
-void LayerEinsumImpl::createOutputSubsctipt()
+void LayerEinsumImpl::validateOutputSubscript()
 {
     // The explicit form requires no operation, as the output
     // would have already been parsed during the input parsing process.
@@ -679,8 +850,6 @@ void LayerEinsumImpl::createOutputSubsctipt()
         {
             CV_Error(Error::StsError,
             "Provided output subscript does not include ellipsis while Inputs subscrits constain ellipsis");
-        } else {
-            CV_Error(Error::StsNotImplemented, "Ellipsis are not yet supported");
         }
     }
 }
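The '.'-counting scheme used in both calculateOutputShape above and processEquation below boils down to one rule: every '.' must be part of a complete three-character ellipsis. A standalone sketch of that rule (the helper name is mine, and it is slightly looser than the layer, which also limits each subscript to a single ellipsis):

```cpp
#include <string>

// Every '.' must belong to a run of exactly three, i.e. a complete "...".
static bool hasValidEllipses(const std::string& subscript)
{
    int run = 0;
    for (char c : subscript)
    {
        if (c == '.') { if (++run > 3) return false; }  // more than "..." is malformed
        else if (run != 0 && run != 3) return false;    // a '.'-run ended too early
        else run = 0;
    }
    return run == 0 || run == 3;
}
// hasValidEllipses("...ij") == true; hasValidEllipses("..ij") == false
```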
@@ -689,9 +858,84 @@ void LayerEinsumImpl::processBroadcastedDims()
 {
     // Only compute this function if ellipsis "..." was found in the equation
-    if (numOfEllipsisDims > 0){
-        // add assert inplace of return bool
-        CV_Error(Error::StsError, "Ellipsis are not supperted currenly");
+    if (numOfEllipsisDims > 0)
+    {
+        // Extend the number of subscript labels to include each ellipsis dim, as
+        // theoretically each ellipsis dim does correspond to a "virtual" subscript label
+        numLetterIndices += numOfEllipsisDims;
+
+        // We are going to assign the broadcasted dims the outermost subscript indices (i.e.) 0 -> numOfEllipsisDims - 1,
+        // as broadcasted dims will most likely be batch dimensions (i.e. outermost dimensions) and hence we don't have
+        // to pay for transposing while "homogenizing" the input
+
+        // Hence offset all subscript indices by numOfEllipsisDims
+        for (size_t i = 0; i < numOfLetters; ++i){
+            if (letter2count[i] != -1){
+                letter2index[i] += numOfEllipsisDims;
+            }
+        }
+
+        std::vector<int> tempIndex2LastInput(numLetterIndices, -1);
+        for (int i = 0; i < subscriptIndicesToLastInput.size(); ++i){
+            tempIndex2LastInput[i + numOfEllipsisDims] = subscriptIndicesToLastInput[i];
+        }
+        subscriptIndicesToLastInput = std::move(tempIndex2LastInput);
+
+        std::vector<int> tempIndexToDimValue(numLetterIndices, -1);
+        for (int i = 0; i < subscriptIndicesToDimValue.size(); ++i){
+            tempIndexToDimValue[i + numOfEllipsisDims] = subscriptIndicesToDimValue[i];
+        }
+        subscriptIndicesToDimValue = std::move(tempIndexToDimValue);
+
+        for (size_t i = 0; i < inputSubscriptIndices.size(); ++i)
+        {
+            auto& currentInputDimIndicesToSubscriptIndices = inputSubscriptIndices[i];
+            std::vector<int> tempCurrentInputDimIndicesToSubscriptIndices;
+            tempCurrentInputDimIndicesToSubscriptIndices.reserve(currentInputDimIndicesToSubscriptIndices.size());
+
+            // make sure it is correct
+            const auto& dims = einsumInpShapes[i];
+            auto rank = dims.size();
+
+            size_t dimIter = 0;
+            size_t numBroadcastedIndices = 0;
+            while (dimIter < currentInputDimIndicesToSubscriptIndices.size())
+            {
+                auto value = currentInputDimIndicesToSubscriptIndices[dimIter];
+                if (value == numOfLetters)
+                { // This is a broadcasted dim
+                    // Shouldn't hit this error - just a sanity check
+                    CV_Assert(numBroadcastedIndices < numOfEllipsisDims);
+                    tempCurrentInputDimIndicesToSubscriptIndices.push_back(static_cast<int>(numBroadcastedIndices));
+                    subscriptIndicesToLastInput[numBroadcastedIndices] = i;
+
+                    // This is the first time we are seeing this broadcasted dim
+                    if (subscriptIndicesToDimValue[numBroadcastedIndices] == -1)
+                    {
+                        subscriptIndicesToDimValue[numBroadcastedIndices] = dims[dimIter];
+                    } else { // We have seen this broadcasted dim before
+                        // Check if the previous value is equal to the current value
+                        if (subscriptIndicesToDimValue[numBroadcastedIndices] != dims[dimIter])
+                        {
+                            // If they are not equal, one of them needs to be 1
+                            if (subscriptIndicesToDimValue[numBroadcastedIndices] == 1)
+                            {
+                                subscriptIndicesToDimValue[numBroadcastedIndices] = dims[dimIter];
+                            } else {
+                                CV_CheckEQ(dims[dimIter], 1, "The broadcasted dimensions of the inputs are incompatible");
+                            }
+                        }
+                    }
+                    ++numBroadcastedIndices;
+                } else { // This is a regular dim - offset it by the number of broadcasted dims
+                    tempCurrentInputDimIndicesToSubscriptIndices.push_back(value + static_cast<int>(numOfEllipsisDims));
+                }
+                ++dimIter;
+            }
+            // Shouldn't hit this error - just a sanity check
+            CV_Assert(dimIter == rank);
+            currentInputDimIndicesToSubscriptIndices = std::move(tempCurrentInputDimIndicesToSubscriptIndices);
+        }
     }
 }
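The pairwise rule encoded in the nested checks above is standard NumPy-style broadcasting on a single axis: equal dims pass through, a dim of 1 yields to the other value, and anything else is an error. A standalone sketch (hypothetical helper, using plain assert where the layer uses CV_CheckEQ):

```cpp
#include <cassert>

// seen == -1 means this broadcast axis has not been observed yet.
static int broadcastDim(int seen, int cur)
{
    if (seen == -1) return cur;   // first occurrence: record the dim
    if (seen == cur) return seen; // matching dims are fine
    if (seen == 1)   return cur;  // a previous 1 broadcasts up to cur
    assert(cur == 1);             // otherwise cur must be 1, else incompatible
    return seen;
}
// broadcastDim(-1, 4) == 4; broadcastDim(1, 4) == 4; broadcastDim(4, 1) == 4
```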
@@ -718,18 +962,58 @@ void LayerEinsumImpl::processEquation(const std::vector<MatShape>& inputs)
 
         // Variable to deal with "ellipsis" - '...' in the input
         bool middleOfellipsis = false;
+        int ellipsisCharCount = 0;
         for (auto letter : token)
         {
-            // Broadcasting based tokens are not implemented yet
             if (letter == '.')
             {
-                CV_Error(Error::StsNotImplemented,
-                    "Broad casting based indices are not supported currently");
-            } else
-            {
+                middleOfellipsis = true;
+
+                // There should not be more than 3 '.'s in the current subscript
+                if (++ellipsisCharCount > 3)
+                {
+                    CV_Error(Error::StsError, cv::format("Found a '.' not part of an ellipsis in input: %d", inputIdx));
+                }
 
-                if (middleOfellipsis)
+                // We have seen all 3 '.'s. We can safely process the ellipsis now.
+                if (ellipsisCharCount == 3)
                 {
+                    middleOfellipsis = false;
+
+                    // Example for the following line of code:
+                    // subscript "...ij" for an input of rank 6 gives
+                    // numOfEllipsisDims = 6 - 5 + 3 = 4
+                    int currentNumOfEllipsisDims = static_cast<int>(rank) - token.length() + 3;
+                    CV_CheckGE(currentNumOfEllipsisDims, 0,
+                               "Einsum subscripts string contains too many subscript labels when compared to the rank of the input");
+
+                    // Theoretically, currentNumOfEllipsisDims could be 0
+                    // Example: for an input of rank 2 paired with a subscript "...ij"
+                    if (currentNumOfEllipsisDims != 0)
+                    {
+                        // We have seen an ellipsis before - make sure ranks align as per the ONNX spec:
+                        // "Ellipsis must indicate a fixed number of dimensions."
+                        if (numOfEllipsisDims != 0){
+                            CV_CheckEQ(numOfEllipsisDims, static_cast<size_t>(currentNumOfEllipsisDims),
+                                       "Ellipsis must indicate a fixed number of dimensions across all inputs");
+                        } else {
+                            numOfEllipsisDims = static_cast<size_t>(currentNumOfEllipsisDims);
+                        }
+
+                        // We reserve 'numOfLetters' for broadcasted dims as we only allow 'a' - 'z'
+                        // and 'A' - 'Z' (0 - 51) for non-broadcasted dims.
+                        // We will assign appropriate indices (based on the number of dimensions the ellipsis corresponds to)
+                        // during broadcasting-related post-processing.
+                        for (size_t i = 0; i < numOfEllipsisDims; ++i){
+                            currTokenIndices.push_back(numOfLetters);
+                        }
+
+                        // Offset 'dim_count' by the number of dimensions the ellipsis corresponds to
+                        dim_count += numOfEllipsisDims;
+                    }
+                }
+            } else {
+                if (middleOfellipsis){
                     CV_Error(Error::StsAssert,
                              cv::format(
                                  "Encountered '.' character that is not part of an ellipsis in the input: [%d]",
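To make the bookkeeping above concrete, here is roughly what the per-token expansion produces for a leading ellipsis, as a simplified sketch (hypothetical helper; the real code handles an ellipsis at any position, upper-case labels, and full validation):

```cpp
#include <string>
#include <vector>

// Ellipsis dims get the sentinel index numLetters (52); real labels get their
// letter index ('a' -> 0, 'b' -> 1, ...). Assumes the token starts with "..."
// and uses lower-case labels only, for brevity.
static std::vector<int> expandToken(const std::string& token, int rank, int numLetters = 52)
{
    std::vector<int> indices;
    int labels = static_cast<int>(token.size()) - 3;  // subscript labels after "..."
    for (int i = 0; i < rank - labels; ++i)
        indices.push_back(numLetters);                // one sentinel per broadcast dim
    for (size_t i = 3; i < token.size(); ++i)
        indices.push_back(token[i] - 'a');
    return indices;
}
// expandToken("...ij", 6) -> {52, 52, 52, 52, 8, 9}
```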
+ "Input shape of operand [%d]", inputIdx) + + cv::format(" is incompatible in the dimention [%zu].", static_cast(dim_count))); } } } diff --git a/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp index 34a77a9a2e..8c461b699f 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp @@ -103,7 +103,6 @@ "test_dynamicquantizelinear_min_adjusted", "test_dynamicquantizelinear_min_adjusted_expanded", "test_edge_pad", -"test_einsum_batch_diagonal", "test_einsum_inner_prod", "test_equal", "test_equal_bcast", diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 7fe9f8ccda..b7e4e73cbc 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -1456,6 +1456,11 @@ TEST_P(Test_ONNX_layers, Einsum_2D) testONNXModels("einsum_2d", npy, 0, 0, false, false, 2); } +TEST_P(Test_ONNX_layers, Einsum_2D_Ellipses) +{ + testONNXModels("einsum_2d_ellipses", npy, 0, 0, false, false, 2); +} + TEST_P(Test_ONNX_layers, Einsum_3D) { testONNXModels("einsum_3d", npy, 0, 0, false, false, 2); @@ -1481,7 +1486,7 @@ TEST_P(Test_ONNX_layers, DISABLED_Einsum_HadamardProduct) testONNXModels("einsum_hadamard", npy, 0, 0, false, false, 2); } -TEST_P(Test_ONNX_layers, DISABLED_Einsum_Batch_Diagonal) +TEST_P(Test_ONNX_layers, Einsum_Batch_Diagonal) { testONNXModels("einsum_batch_diagonal", npy, 0, 0, false, false, 1); }