Merge pull request #24812 from Abdurrahheem:ash/einsum_bachedGemm

Replace iterative batched Matrix Multiply. #24812

This PR replaces the iterative batch matrix multiplication in the Einsum layer with a single `fastGemmBatch` call (sketched below).
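
For context, here is a minimal sketch of the new call pattern, under stated assumptions: it relies on the internal `modules/dnn/src/layers/cpu_kernels/fast_gemm.hpp` header (not public OpenCV API), and the helper name `batchedMatMulSketch` is hypothetical; the shapes follow the diff below.

```cpp
// Sketch only, not the exact layer code. Assumes the internal
// cpu_kernels header; fastGemmBatch/FastGemmOpt are not public API.
#include "layers/cpu_kernels/fast_gemm.hpp"

using namespace cv;
using namespace cv::dnn;

// Hypothetical helper: C{batches, M, N} = A{batches, M, K} * B{batches, K, N}
static void batchedMatMulSketch(const Mat& A, const Mat& B, Mat& C)
{
    FastGemmOpt opt;
    opt.init();  // pick the best CPU kernel available at runtime

    // Old approach (removed by this PR): slice each batch out of A and B,
    // reshape the slices to 2-D, and call fastGemm once per batch.
    // New approach: a single fastGemmBatch call over the 3-D tensors.
    fastGemmBatch(false, false, 1.0, A, B, 0.0, C, opt);
}
```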

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
Commit c923c59833 (parent 1e190b3094), authored by Abduragim Shtanchaev and committed via GitHub.
Files changed:

- `modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp` (2 lines changed)
- `modules/dnn/src/layers/einsum_layer.cpp` (78 lines changed)

`modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp`:

```diff
@@ -385,7 +385,7 @@ void fastGemmBatch(bool trans_a, bool trans_b,
     const auto shape_a = shape(A);
     const auto shape_b = shape(B);
     const auto shape_c = shape(C);
     CV_CheckGE(shape_a.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: A must be n-dimensional (n >= 2)");
-    CV_CheckEQ(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: B must be n-dimensional (n >= 2)");
+    CV_CheckGE(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: B must be n-dimensional (n >= 2)");
     const float *a = A.ptr<const float>();
     const float *b = B.ptr<const float>();
```
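
Note: the only functional change in this file is relaxing the dimensionality check on `B` from `CV_CheckEQ` (exactly 2-D) to `CV_CheckGE` (at least 2-D). This matches the accompanying error message and admits the 3-D batched inputs that the Einsum layer now passes directly.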

`modules/dnn/src/layers/einsum_layer.cpp`:

```diff
@@ -1299,7 +1299,6 @@ Mat LayerEinsumImpl::batchwiseMatMul(
     const Mat& input2,
     const MatShape& input2ShapeOverride)
 {
-
     // Sanity checks before the actual MatMul
     CV_CheckType(input1.type(), input2.type(), "Data types of the inputs must match for MatMul");
     CV_CheckEQ(input1ShapeOverride.size(), (size_t) 3, "Only 1 batch dimension is allowed for MatMul");
@@ -1312,61 +1311,22 @@ Mat LayerEinsumImpl::batchwiseMatMul(
     int K = input1ShapeOverride[2];
     int N = input2ShapeOverride[2];
 
-    std::vector<Mat> output;
+    Mat reshapedInput1 = input1;
+    Mat reshapedInput2 = input2;
+    Mat output;
     if (batches > 1)
     {
-        Mat reshapedInput1 = input1;
-        Mat reshapedInput2 = input2;
-
-        // input1 should of size MxK
-        // check if input1 needs reshape, if need reshape
-        if (input1.size[0] != M || input1.size[1] != K)
-        {
-            int shape[] = {batches, M, K};
-            reshapedInput1 = input1.reshape(1, 3, shape);
-        }
-
-        // input2 should be of size KxN
-        // check if input2 needs reshape, if needs reshape
-        if (input2.size[0] != K || input2.size[1] != N)
-        {
-            int shape[] = {batches, K, N};
-            reshapedInput2 = input2.reshape(1, 3, shape);
-        }
-
-        for (size_t i=0; i < batches; i++)
-        {
-            std::vector<Range> ranges1 = {cv::Range(i, i+1)};
-            for (int j = 1; j < reshapedInput1.dims; j++)
-                ranges1.emplace_back(cv::Range::all());
-            Mat part1 = reshapedInput1(ranges1);
-            int shape[] = {M, K};
-            part1 = part1.reshape(1, sizeof(shape)/sizeof(shape[0]), shape);
-
-            std::vector<Range> ranges2 = {cv::Range(i, i+1)};
-            for (int j = 1; j < reshapedInput2.dims; j++)
-                ranges2.emplace_back(cv::Range::all());
-            Mat part2 = reshapedInput2(ranges2);
-            int shape2[] = {K, N};
-            part2 = part2.reshape(1, sizeof(shape2)/sizeof(shape2[0]), shape2);
-
-            Mat tmp_output(M, N, part1.type());
-            fastGemm(false, false, 1.0, part1, part2, 0.0, tmp_output, opt);
-            int newShape[] = {1, M, N};
-            tmp_output = tmp_output.reshape(1, sizeof(newShape)/sizeof(newShape[0]), newShape);
-            output.emplace_back(tmp_output);
-        }
+        // create tmpout with type like input1
+        output = Mat({batches, M, N}, input1.type());
+
+        reshapedInput2 = reshapedInput2.reshape(1, input2ShapeOverride);
+        reshapedInput1 = reshapedInput1.reshape(1, input1ShapeOverride);
+
+        fastGemmBatch(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, output, opt);
     } else {
-        Mat reshapedInput1 = input1;
-        Mat reshapedInput2 = input2;
 
         // input1 should of size MxK
-        // check if input1 needs reshape, if need reshape
         if (input1.dims > 2 || input1.size[0] != M || input1.size[1] != K)
         {
             int shape[] = {M, K};
@@ -1374,30 +1334,18 @@ Mat LayerEinsumImpl::batchwiseMatMul(
         }
 
         // input2 should be of size KxN
-        // check if input2 needs reshape, if needs reshape
         if (input2.dims > 2 || input2.size[0] != K || input2.size[1] != N)
         {
             int shape2[] = {K, N};
             reshapedInput2 = input2.reshape(1, 2, shape2);
         }
 
-        Mat tmp_output(M, N, reshapedInput1.type());
-        fastGemm(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, tmp_output, opt);
-
-        int newShape[] = {1, M, N};
-        tmp_output = tmp_output.reshape(1, sizeof(newShape)/sizeof(newShape[0]), newShape);
-        output.emplace_back(tmp_output);
-    }
-
-    int outputDim[] = {static_cast<int>(output.size()), M, N};
-    Mat output_buffer = Mat::zeros(3, outputDim, CV_32F);
-
-    for (size_t i = 0; i < output.size(); i++) {
-        Mat output_slice = output_buffer.row(i);
-        output[i].copyTo(output_slice);
+        output = Mat(M, N, reshapedInput1.type());
+        fastGemm(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, output, opt);
+        output = output.reshape(1, {1, M, N});
     }
-    return output_buffer;
+    return output;
 };
 
 Ptr<EinsumLayer> EinsumLayer::create(const LayerParams& params)
 {
```
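
In short, the batched path now reshapes both inputs to their 3-D override shapes `{batches, M, K}` and `{batches, K, N}` and issues one `fastGemmBatch` call into a preallocated `{batches, M, N}` output, while the single-batch path keeps the 2-D `fastGemm` call and reshapes its result to `{1, M, N}`. Callers therefore observe the same output shape as before, without the per-batch slicing, reshaping, and copying into an intermediate buffer.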
