From 4d0f13544db8ea0fd1e439992a7ffd215ce15114 Mon Sep 17 00:00:00 2001 From: Alexander Alekhin Date: Mon, 2 Mar 2020 17:13:02 +0300 Subject: [PATCH] Merge pull request #16700 from alalek:fix_core_matexpr_size_gemm core: fix MatExpr::size() for gemm() * core(test): MatExpr::size() test for gemm() * core: fix MatExpr::size() for gemm() --- modules/core/src/matrix_expressions.cpp | 13 ++++++++++--- modules/core/test/test_mat.cpp | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/modules/core/src/matrix_expressions.cpp b/modules/core/src/matrix_expressions.cpp index 5ac1fafbd6..d7fb6e4228 100644 --- a/modules/core/src/matrix_expressions.cpp +++ b/modules/core/src/matrix_expressions.cpp @@ -117,8 +117,17 @@ public: void transpose(const MatExpr& expr, MatExpr& res) const CV_OVERRIDE; + Size size(const MatExpr& expr) const CV_OVERRIDE + { + return Size( + (expr.flags & GEMM_2_T) ? expr.b.rows : expr.b.cols, + (expr.flags & GEMM_1_T) ? expr.a.cols : expr.a.rows + ); + } + static void makeExpr(MatExpr& res, int flags, const Mat& a, const Mat& b, double alpha=1, const Mat& c=Mat(), double beta=1); + }; static MatOp_GEMM g_MatOp_GEMM; @@ -199,7 +208,7 @@ static inline bool isReciprocal(const MatExpr& e) { return isBin(e,'/') && (!e.b static inline bool isT(const MatExpr& e) { return e.op == &g_MatOp_T; } static inline bool isInv(const MatExpr& e) { return e.op == &g_MatOp_Invert; } static inline bool isSolve(const MatExpr& e) { return e.op == &g_MatOp_Solve; } -static inline bool isGEMM(const MatExpr& e) { return e.op == &g_MatOp_GEMM; } +//static inline bool isGEMM(const MatExpr& e) { return e.op == &g_MatOp_GEMM; } static inline bool isMatProd(const MatExpr& e) { return e.op == &g_MatOp_GEMM && (!e.c.data || e.beta == 0); } static inline bool isInitializer(const MatExpr& e) { return e.op == getGlobalMatOpInitializer(); } @@ -1240,8 +1249,6 @@ Size MatExpr::size() const { if( isT(*this) || isInv(*this) ) return Size(a.rows, a.cols); - if( isGEMM(*this) ) - return Size(b.cols, a.rows); if( isSolve(*this) ) return Size(b.cols, a.cols); if( isInitializer(*this) ) diff --git a/modules/core/test/test_mat.cpp b/modules/core/test/test_mat.cpp index 3fa8442d69..f4f3597034 100644 --- a/modules/core/test/test_mat.cpp +++ b/modules/core/test/test_mat.cpp @@ -2029,6 +2029,29 @@ TEST(Core_MatExpr, issue_16655) << "Mat: CV_8UC3 != " << typeToString(ab_mat.type()); } +TEST(Core_MatExpr, issue_16689) +{ + Mat a(Size(10, 5), CV_32FC1, 5); + Mat b(Size(10, 5), CV_32FC1, 2); + Mat bt(Size(5, 10), CV_32FC1, 3); + { + MatExpr r = a * bt; // gemm + EXPECT_EQ(Mat(r).size(), r.size()) << "[10x5] x [5x10] => [5x5]"; + } + { + MatExpr r = a * b.t(); // gemm + EXPECT_EQ(Mat(r).size(), r.size()) << "[10x5] x [10x5].t() => [5x5]"; + } + { + MatExpr r = a.t() * b; // gemm + EXPECT_EQ(Mat(r).size(), r.size()) << "[10x5].t() x [10x5] => [10x10]"; + } + { + MatExpr r = a.t() * bt.t(); // gemm + EXPECT_EQ(Mat(r).size(), r.size()) << "[10x5].t() x [5x10].t() => [10x10]"; + } +} + #ifdef HAVE_EIGEN TEST(Core_Eigen, eigen2cv_check_Mat_type) {