Merge pull request #22828 from WanliZhong:improve_matmul

DNN: make MatMul support 3D or 4D with broadcast
Branch: pull/22880/head
Author: Alexander Smorkalov (2 years ago), committed by GitHub
Commit: ac6fb17784
Files changed (number of changed lines):
  1. modules/dnn/src/cuda4dnn/primitives/matmul.hpp (43)
  2. modules/dnn/src/layers/fully_connected_layer.cpp (68)
  3. modules/dnn/src/onnx/onnx_importer.cpp (33)
  4. modules/dnn/test/test_onnx_importer.cpp (1)
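
For context, a minimal usage sketch of what the change enables (illustrative only; "matmul_bcast.onnx" stands for a hypothetical single-node MatMul model whose second input is a constant 3D initializer of shape [1, 4, 5]):

    // Hedged sketch: run a 3D input through a MatMul whose constant second
    // input has a leading batch dimension of 1, which now broadcasts.
    #include <opencv2/core.hpp>
    #include <opencv2/dnn.hpp>

    int main()
    {
        cv::dnn::Net net = cv::dnn::readNetFromONNX("matmul_bcast.onnx"); // hypothetical model file
        int sz[] = {2, 3, 4};
        cv::Mat inp(3, sz, CV_32F, cv::Scalar(1));  // A: [2, 3, 4]
        net.setInput(inp);
        cv::Mat out = net.forward();                // expected shape: [2, 3, 5]
        CV_Assert(out.size[0] == 2 && out.size[1] == 3 && out.size[2] == 5);
        return 0;
    }

Before this change, the importer only handled a constant 2D second input on this path; 3D/4D constants were routed through a separate Const layer instead.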

@@ -23,9 +23,14 @@ namespace cv { namespace dnn { namespace cuda4dnn {
     public:
         using wrapper_type = GetCUDABackendWrapperType<T>;

-        MatMulOp(csl::Stream stream_, csl::cublas::Handle handle)
+        MatMulOp(csl::Stream stream_, csl::cublas::Handle handle, const Mat& constInp)
             : stream(std::move(stream_)), cublasHandle(std::move(handle))
         {
+            if (!constInp.empty())
+            {
+                constTensor = csl::makeTensorHeader<T>(constInp);
+                csl::copyMatToTensor<T>(constInp, constTensor, stream);
+            }
         }

         void forward(
@@ -33,13 +38,20 @@ namespace cv { namespace dnn { namespace cuda4dnn {
             const std::vector<cv::Ptr<BackendWrapper>>& outputs,
             csl::Workspace& workspace) override
         {
-            CV_Assert(inputs.size() == 2 && outputs.size() == 1);
+            CV_Assert((inputs.size() == 2 && constTensor.empty() ||
+                       inputs.size() == 1 && !constTensor.empty()) && outputs.size() == 1);

             auto input1_wrapper = inputs[0].dynamicCast<wrapper_type>();
             auto input1 = input1_wrapper->getView();

-            auto input2_wrapper = inputs[1].dynamicCast<wrapper_type>();
-            auto input2 = input2_wrapper->getView();
+            csl::TensorView<T> input2;
+            if (constTensor.empty())
+            {
+                auto input2_wrapper = inputs[1].dynamicCast<wrapper_type>();
+                input2 = input2_wrapper->getView();
+            }
+            else
+                input2 = csl::TensorView<T>(constTensor);

             auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
             auto output = output_wrapper->getSpan();
@@ -59,9 +71,18 @@ namespace cv { namespace dnn { namespace cuda4dnn {
             auto m = input1.get_axis_size(-2);
             auto n = input1.get_axis_size(-1);
-            auto k = input2.get_axis_size(-1);
             auto b = input1.size() / m / n;
-            CV_Assert(input2.get_axis_size(-2) == n);
+
+            int k;
+            if (constTensor.empty())
+            {
+                k = input2.get_axis_size(-1);
+                CV_Assert(input2.get_axis_size(-2) == n);
+            }
+            else
+            {
+                k = input2.get_axis_size(-2);
+                CV_Assert(input2.get_axis_size(-1) == n);
+            }
             CV_Assert(output.get_axis_size(-2) == m);
             CV_Assert(output.get_axis_size(-1) == k);
@@ -70,24 +91,28 @@ namespace cv { namespace dnn { namespace cuda4dnn {
                 CV_Assert(b == 1);
                 CV_Assert(get_effective_rank(input1) <= 2);
                 CV_Assert(get_effective_rank(input2) <= 2);
-                csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
+                csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, !constTensor.empty(), input2);
             }
             else
             {
                 CV_Assert(rank >= 3);
                 input1.reshape(b, m, n);
-                input2.reshape(b, n, k);
+                if (constTensor.empty())
+                    input2.reshape(b, n, k);
+                else
+                    input2.reshape(b, k, n);
                 output.reshape(b, m, k);

                 input1.squeeze_to(3);
                 input2.squeeze_to(3);
                 output.squeeze_to(3);

-                csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
+                csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, !constTensor.empty(), input2);
             }
         }

     private:
         csl::Stream stream;
         csl::cublas::Handle cublasHandle;
+        csl::Tensor<T> constTensor;
     };

}}} /* namespace cv::dnn::cuda4dnn */
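
A note on the shape bookkeeping above: when the second operand arrives as a constant blob, the importer has already transposed it (it is stored as [K, N] rather than [N, K]), so k is read from axis -2 and the GEMM is issued with the transpose flag set. A standalone sketch of that logic (plain C++, not cuda4dnn code; all names are illustrative):

    #include <cassert>
    #include <vector>

    struct MatMulShapes { int b, m, n, k; };

    // shapeA: [..., M, N]; shapeB: [N, K] for a runtime input,
    // or [K, N] when it is a constant that was pre-transposed.
    MatMulShapes deduceShapes(const std::vector<int>& shapeA,
                              const std::vector<int>& shapeB,
                              bool bIsConstTransposed)
    {
        MatMulShapes s;
        s.m = shapeA[shapeA.size() - 2];
        s.n = shapeA[shapeA.size() - 1];
        int total = 1;
        for (int d : shapeA) total *= d;
        s.b = total / s.m / s.n;                      // batch = product of leading dims
        if (!bIsConstTransposed) {
            s.k = shapeB[shapeB.size() - 1];
            assert(shapeB[shapeB.size() - 2] == s.n); // [N, K], multiplied as-is
        } else {
            s.k = shapeB[shapeB.size() - 2];
            assert(shapeB[shapeB.size() - 1] == s.n); // [K, N], multiplied with transB = true
        }
        return s;
    }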

@@ -85,6 +85,7 @@ public:
        bias = params.get<bool>("bias_term", true);
        axis = params.get<int>("axis", 1);
+       isMatMul = params.get<bool>("is_matmul", false);
        if (!blobs.empty())
        {
            CV_Assert(1 <= blobs.size() && blobs.size() <= 2);
@@ -94,6 +95,7 @@ public:
            CV_Assert(blobs[0].dims >= 2 && (size_t)(innerSize * numOutput) == blobs[0].total());
            CV_Assert(!bias || (blobs.size() == 2 && (size_t)numOutput == blobs[1].total()));

+           blobs[0].copyTo(oriMat);
            weightsMat = blobs[0] = blobs[0].reshape(1, numOutput);
            int vecsize = weightsMat.cols;
            if (vecsize % VEC_ALIGN != 0)
@@ -108,6 +110,8 @@ public:
            if (bias)
                biasMat = blobs[1] = blobs[1].reshape(1, 1);
+           else if(isMatMul)
+               biasMat = Mat::zeros(1, oriMat.size[oriMat.dims - 2], weightsMat.type());
            else
                biasMat = Mat::zeros(1, numOutput, weightsMat.type());
        }
@@ -153,7 +157,10 @@ public:
            CV_Assert(!transA && !transB);
            CV_CheckEQ(inputsTmp.size(), (size_t)1, "");
            CV_CheckEQ(blobs[0].dims, 2, "");
-           numOutput = blobs[0].size[0];
+           if(isMatMul)
+               numOutput = oriMat.size[oriMat.dims - 2];
+           else
+               numOutput = blobs[0].size[0];
            CV_Assert(!bias || (size_t)numOutput == blobs[1].total());
            cAxis = normalize_axis(axis, inputsTmp[0]);
        }
@@ -519,16 +526,40 @@ public:
        if (!blobs.empty())
        {
            CV_Assert(!transA && !transB);
-           int axisCan = normalize_axis(axis, input[0].dims);
-           int outerSize = input[0].total(0, axisCan);
-
-           for (size_t i = 0; i < input.size(); i++)
-           {
-               Mat srcMat = input[i].reshape(1, outerSize);
-               Mat dstMat = output[i].reshape(1, outerSize);
-
-               const int nstripes = getNumThreads();
-               FullyConnected::run(srcMat, weightsMat, biasMat, dstMat, activ.get(), nstripes);
-           }
+           int inp1Dim = input[0].dims;
+           if (isMatMul)
+           {
+               int matNum = input[0].total(0, inp1Dim - 2);
+               int rowMatMul = oriMat.size[oriMat.dims - 2];
+               Mat srcMatTmp = input[0].reshape(1, matNum);
+               Mat dstMatTmp = output[0].reshape(1, matNum);
+               int outerSize = input[0].size[inp1Dim - 2];
+               int rowStart = -rowMatMul;
+               for (int n = 0; n < matNum; ++n)
+               {
+                   Mat srcMat = srcMatTmp.row(n).reshape(1, outerSize);
+                   Mat dstMat = dstMatTmp.row(n).reshape(1, outerSize);
+                   rowStart = (rowStart + rowMatMul) % weightsMat.rows;
+                   Mat weiMat = weightsMat.rowRange(rowStart, rowStart + rowMatMul);
+
+                   const int nstripes = getNumThreads();
+                   FullyConnected::run(srcMat, weiMat, biasMat, dstMat, activ.get(), nstripes);
+               }
+           }
+           else
+           {
+               int axisCan = normalize_axis(axis, inp1Dim);
+               int outerSize = input[0].total(0, axisCan);
+               for (size_t i = 0; i < input.size(); i++)
+               {
+                   Mat srcMat = input[i].reshape(1, outerSize);
+                   Mat dstMat = output[i].reshape(1, outerSize);
+
+                   const int nstripes = getNumThreads();
+                   FullyConnected::run(srcMat, weightsMat, biasMat, dstMat, activ.get(), nstripes);
+               }
+           }
        }
        else
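
A standalone sketch of the broadcast loop introduced above (plain OpenCV, not layer code): the weight blob has been flattened to 2D, each 2D slice of the input is multiplied by the matching block of weight rows, and the row offset wraps around so a weight batch of 1 broadcasts over every input slice. All names are illustrative; dst2d is assumed preallocated as CV_32F with matNum*M rows.

    #include <opencv2/core.hpp>

    void batchedMatMulSketch(const cv::Mat& src2d,     // [matNum*M, N] flattened CV_32F input
                             const cv::Mat& weights2d, // [weightBatch*K, N] flattened, pre-transposed weights
                             cv::Mat& dst2d,           // [matNum*M, K] preallocated output
                             int matNum, int M, int K)
    {
        int rowStart = -K;
        for (int n = 0; n < matNum; ++n)
        {
            rowStart = (rowStart + K) % weights2d.rows;               // wrap: broadcasts a smaller weight batch
            cv::Mat w = weights2d.rowRange(rowStart, rowStart + K);   // [K, N]
            cv::Mat x = src2d.rowRange(n * M, (n + 1) * M);           // [M, N]
            cv::Mat y = dst2d.rowRange(n * M, (n + 1) * M);           // [M, K] view into dst2d
            cv::gemm(x, w, 1.0, cv::noArray(), 0.0, y, cv::GEMM_2_T); // y = x * w^T
        }
    }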
@@ -579,14 +610,26 @@
    ) override
    {
        auto context = reinterpret_cast<csl::CSLContext*>(context_);

-       if (weightsMat.empty())
+       auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
+       if (weightsMat.empty() || isMatMul)
        {
            CV_Assert(!bias);
-           return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle));
+           int inp2Dim;
+           // broadcast is not supported with CUDA
+           if(weightsMat.empty())
+           {
+               auto input_wrapper2 = inputs[1].dynamicCast<CUDABackendWrapper>();
+               inp2Dim = input_wrapper2->getRank();
+           }else
+               inp2Dim = oriMat.dims;
+
+           if(input_wrapper->getRank() == inp2Dim)
+               return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), oriMat);
+           else
+               return Ptr<BackendNode>();
        }

-       auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
        auto flatten_start_axis = normalize_axis(axis, input_wrapper->getRank());
        auto biasMat_ = bias ? biasMat : Mat();
        return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_);
@@ -752,8 +795,9 @@
    }

    bool bias;
-   Mat weightsMat, biasMat;
+   Mat weightsMat, biasMat, oriMat;
    bool transA, transB;
+   bool isMatMul = false;
    Ptr<ActivationLayer> activ;
};

@@ -2088,30 +2088,21 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
    if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
    {
        Mat blob = getBlob(node_proto, 1);
+       Mat transBlob;
        secondInpDims = blob.dims;
-       if (secondInpDims == 2)
-       {
-           layerParams.blobs.push_back(blob.t());
-           layerParams.set("num_output", layerParams.blobs[0].size[0]);
-       }
-       else
-       {
-           LayerParams constParams;
-           constParams.name = layerParams.name + "/const_1";
-           constParams.type = "Const";
-           constParams.blobs.push_back(blob);
-
-           opencv_onnx::NodeProto tmpProto;
-           tmpProto.add_output(constParams.name);
-           addLayer(constParams, tmpProto);
-
-           node_proto.set_input(1, constParams.name);
-       }
-   }
-   else
+       // create order transposing last 2 dimensions
+       std::vector<int> order(secondInpDims);
+       std::iota(order.begin(), order.end(), 0);
+       std::swap(order[secondInpDims - 2], order[secondInpDims - 1]);
+       transposeND(blob, order, transBlob);
+       layerParams.blobs.push_back(transBlob);
+       int numOutput = layerParams.blobs[0].total(0, secondInpDims - 1);
+       layerParams.set("num_output", numOutput);
+       layerParams.set("is_matmul", true);
+   } else
        secondInpDims = outShapes[node_proto.input(1)].size();
-   layerParams.set("axis", firstInpDims - secondInpDims + 1);
+   layerParams.set("axis", firstInpDims - 1);
    addLayer(layerParams, node_proto);
 }
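
To make the importer change concrete, a small standalone example of the transpose performed above on a constant second input (shapes are illustrative): a [2, 3, 4] blob with order {0, 2, 1} becomes [2, 4, 3], i.e. only the last two axes are swapped.

    #include <numeric>
    #include <utility>
    #include <vector>
    #include <opencv2/core.hpp>

    int main()
    {
        int sz[] = {2, 3, 4};
        cv::Mat blob(3, sz, CV_32F, cv::Scalar(0));            // constant MatMul input [2, 3, 4]

        std::vector<int> order(blob.dims);
        std::iota(order.begin(), order.end(), 0);              // {0, 1, 2}
        std::swap(order[blob.dims - 2], order[blob.dims - 1]); // {0, 2, 1}

        cv::Mat transBlob;
        cv::transposeND(blob, order, transBlob);               // shape becomes [2, 4, 3]
        return 0;
    }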

@@ -921,6 +921,7 @@ TEST_P(Test_ONNX_layers, MatMul_init)
    testONNXModels("matmul_4d_init");
    testONNXModels("matmul_init_2");
+   testONNXModels("matmul_init_bcast");
 }

 TEST_P(Test_ONNX_layers, MatMulAdd)
