mirror of https://github.com/opencv/opencv.git
Merge pull request #24476 from fengyuentau:attention_layer
dnn: add attention layer #24476

Resolves #24609

Merge with: https://github.com/opencv/opencv_extra/pull/1128.

Attention operator spec from onnxruntime: https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention.

TODO:
- [x] Benchmark (before this PR vs. with this PR vs. ORT).
- [x] Layer fusion: take care of Slice with end=INT64_MAX.
- [x] Layer fusion: match more potential attention (ViT) patterns.
- [x] Single-head attention is supported.
- [x] Test AttentionSubgraph fusion.
- [x] Add accuracy tests for VIT_B_32 and VitTrack.
- [x] Add performance tests for VIT_B_32 and VitTrack.

## Benchmarks

Platform: Macbook Air M1.

### Attention Subgraph

Input scale: [1, 197, 768].

|                        | mean (ms) | median (ms) | min (ms) |
| ---------------------- | --------- | ----------- | -------- |
| w/ Attention (this PR) | 3.75      | 3.68        | 3.22     |
| w/o Attention          | 9.06      | 9.01        | 8.24     |
| ORT (python)           | 4.32      | 2.63        | 2.50     |

### ViTs

All data in milliseconds (ms).

| ViTs     | With Attention | Without Attention | ORT    |
| -------- | -------------- | ----------------- | ------ |
| vit_b_16 | 302.77         | 365.35            | 109.70 |
| vit_b_32 | 89.92          | 116.22            | 30.36  |
| vit_l_16 | 1593.32        | 1730.74           | 419.92 |
| vit_l_32 | 468.11         | 577.41            | 134.12 |
| VitTrack | 3.80           | 3.87              | 2.25   |

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV.
- [x] The PR is proposed to the proper branch.
- [x] There is a reference to the original bug report and related work.
- [x] There are accuracy tests, performance tests and test data in the opencv_extra repository, if applicable. The patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake.

pull/24728/head
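As a rough usage sketch (not part of the patch): once a ViT-style ONNX model is loaded through the ONNX importer, a supported attention subgraph is fused into this layer automatically, so exercising it only takes the usual `cv::dnn` calls. The model path, input image, input size and preprocessing constants below are hypothetical placeholders, not the exact values used in the tests:

```cpp
#include <opencv2/dnn.hpp>
#include <opencv2/imgcodecs.hpp>

int main() {
    // Hypothetical model file; any ONNX graph containing a supported attention pattern works.
    cv::dnn::Net net = cv::dnn::readNetFromONNX("vit_b_32.onnx");
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV);  // the new layer only has a CPU path

    cv::Mat img = cv::imread("input.jpg");  // hypothetical input image
    cv::Mat blob = cv::dnn::blobFromImage(img, 1.0 / 255.0, cv::Size(224, 224));
    net.setInput(blob);
    cv::Mat out = net.forward();  // attention subgraphs run through the fused layer
    return 0;
}
```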
parent e64c5dc4c6
commit 0521a3a384
13 changed files with 891 additions and 66 deletions
@@ -0,0 +1,272 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "cpu_kernels/fast_gemm.hpp"
#include "cpu_kernels/softmax.hpp"

#include <opencv2/dnn/shape_utils.hpp>

namespace cv { namespace dnn {

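// Shape conventions used throughout this file:
//   B = batch_size, S = seq_len, N = num_heads, H = per-head size (an entry of qkv_head_sizes),
//   input_hidden_size = width of the input embedding,
//   hidden_size = q_hidden_size + k_hidden_size + v_hidden_size (second weight dimension).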
static void packWeight(size_t num_heads, size_t head_size, size_t input_hidden_size,
                       const float *weight_data, size_t hidden_size, std::vector<float> &packed_weight, const FastGemmOpt &opt) {
    // num_heads * pack(head_size, input_hidden_size)
    size_t pack_size = fastGemmPackBSize(head_size, input_hidden_size, opt);
    size_t packed_weight_size = num_heads * pack_size;
    packed_weight.resize(packed_weight_size, 0.f);
    auto *packed_weight_data = packed_weight.data();
    for (size_t i = 0; i < num_heads; i++) {
        fastGemmPackB(false, head_size, input_hidden_size, weight_data, hidden_size, packed_weight_data, opt);
        packed_weight_data += pack_size;
        weight_data += head_size;
    }
}

// Operator spec: https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention
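// In terms of the spec above, this layer computes, per head:
//   [Q; K; V] = input * weight + bias         (input:  [B, S, input_hidden_size],
//                                               weight: [input_hidden_size, hidden_size],
//                                               bias:   [hidden_size])
//   output    = softmax(scale * Q * K^T) * V  (heads concatenated back to [B, S, v_hidden_size])
// with exactly three inputs: no attention mask and no past/present state, which is the
// subset the ONNX importer's subgraph fusion targets.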
class AttentionLayerImpl CV_FINAL : public AttentionLayer {
 public:
    AttentionLayerImpl(const LayerParams &params) {
        setParamsFrom(params);

        CV_CheckTrue(params.has("num_heads"), "DNN/Attention: num_heads is required but missing");
        num_heads = params.get<int>("num_heads"); // required, no default value

        CV_CheckTrue(params.has("qkv_hidden_sizes"), "DNN/Attention: qkv_hidden_sizes is required but missing");
        auto param_qkv_hidden_sizes = params.get("qkv_hidden_sizes");
        CV_CheckEQ(param_qkv_hidden_sizes.size(), 3, "DNN/Attention: qkv_hidden_sizes must and only have three elements");

        qkv_hidden_sizes.clear();
        qkv_hidden_sizes.resize(3);
        qkv_hidden_sizes[0] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(0));
        qkv_hidden_sizes[1] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(1));
        /* v_hidden_size needs to be initialized in finalize in case v_slice_end=INT_MAX */

        qkv_head_sizes.clear();
        qkv_head_sizes.resize(3);
        qkv_head_sizes[0] = static_cast<size_t>(qkv_hidden_sizes[0] / num_heads);
        qkv_head_sizes[1] = static_cast<size_t>(qkv_hidden_sizes[1] / num_heads);

        scale = 1.f / params.get<float>("scale", sqrt(qkv_head_sizes[0]));

        output_ndims = params.get<int>("output_ndims", 3);

        is_prepacked = false;
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE {
        return backendId == DNN_BACKEND_OPENCV;
    }

    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
                                 const int requiredOutputs,
                                 std::vector<MatShape> &outputs,
                                 std::vector<MatShape> &internals) const CV_OVERRIDE {
        CV_CheckEQ(inputs.size(), static_cast<size_t>(3), "DNN/Attention: three inputs are required");
        const auto &input_shape = inputs[0];
        const auto &weight_shape = inputs[1];
        const auto &bias_shape = inputs[2];

        CV_CheckEQ(input_shape.size(), static_cast<size_t>(3), "DNN/Attention: invalid input dimension");
        CV_CheckEQ(weight_shape.size(), static_cast<size_t>(2), "DNN/Attention: invalid weight dimension");

        CV_CheckEQ(input_shape[2], weight_shape[0], "DNN/Attention: invalid input shape");
        CV_CheckEQ(weight_shape[1], bias_shape[0], "DNN/Attention: invalid weight or bias shape");

        if (output_ndims == 3) {
            outputs.assign(1, inputs[0]);
        } else if (output_ndims == 2) {
            int batch = input_shape[0], seq_len = input_shape[1], input_hidden_size = input_shape[2];
            MatShape output_shape{batch * seq_len, input_hidden_size};
            outputs.assign(1, output_shape);
        } else {
            CV_Error(Error::StsBadArg, format("DNN/Attention: invalid output dimension %zu, valid value is 2 or 3", output_ndims));
        }
        return false;
    }

    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
        opt.init();

        std::vector<Mat> inputs;
        inputs_arr.getMatVector(inputs);
        const auto input_shape = shape(inputs[0]);
        batch_size = static_cast<size_t>(input_shape[0]);
        seq_len = static_cast<size_t>(input_shape[1]);
        input_hidden_size = static_cast<size_t>(input_shape[2]);

        const auto weight_shape = shape(inputs[1]);
        hidden_size = weight_shape[1];
        qkv_hidden_sizes[2] = hidden_size - qkv_hidden_sizes[0] - qkv_hidden_sizes[1];
        qkv_head_sizes[2] = static_cast<size_t>(qkv_hidden_sizes[2] / num_heads);
    }

    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

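        // FP16 blobs are carried as CV_16S in this module; route them through the generic
        // fallback path, which runs the computation in FP32.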
        if (inputs_arr.depth() == CV_16S)
        {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> inputs, outputs;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);

        // prepack weights
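        // The weight input (inputs[1]) is treated as constant, so each of the Q/K/V
        // column slices is packed for fastGemm only once, on the first forward call.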
        if (!is_prepacked) {
            // prepack
            const auto &weight = inputs[1];
            const auto *weight_data = weight.ptr<const float>();
            packWeight(num_heads, qkv_head_sizes[0], input_hidden_size, weight_data, hidden_size, packed_weight_q, opt);
            packWeight(num_heads, qkv_head_sizes[1], input_hidden_size, weight_data + qkv_hidden_sizes[0], hidden_size, packed_weight_k, opt);
            packWeight(num_heads, qkv_head_sizes[2], input_hidden_size, weight_data + qkv_hidden_sizes[0] + qkv_hidden_sizes[1], hidden_size, packed_weight_v, opt);

            is_prepacked = true;
        }

        float *packed_weights[3] = {packed_weight_q.data(), packed_weight_k.data(), packed_weight_v.data()};
        size_t packed_weights_size[3] = {packed_weight_q.size() / num_heads, packed_weight_k.size() / num_heads, packed_weight_v.size() / num_heads};

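        // gemm_buffer lays the three projections out back to back:
        //   Q: B*S*q_hidden floats, then K: B*S*k_hidden, then V: B*S*v_hidden,
        // which together make up the batch_size * seq_len * hidden_size allocation below.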
        Mat gemm_buffer = Mat::zeros(1, int(batch_size * seq_len * hidden_size), CV_32F);
        auto *Q = gemm_buffer.ptr<float>();
        auto *K = Q + batch_size * seq_len * qkv_hidden_sizes[0];
        auto *V = K + batch_size * seq_len * qkv_hidden_sizes[1];
        float *QKV[3] = {Q, K, V}; // Q, K, V: [B, N, S, H]
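        // Fused input projection: each parallel task handles one (batch, head, {Q|K|V}) triple.
        // The bias row is first broadcast-copied into dst, then fastGemm accumulates the
        // projection on top of it (the trailing 1.f before dst is the accumulate factor).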
        {
            const auto &input = inputs[0];
            const auto &bias = inputs[2];
            const auto *input_data = input.ptr<const float>();
            const auto *bias_data = bias.ptr<const float>();

            opt.multi_thread = false;
            auto fn = [&](const Range &r) {
                for (int i = r.start; i < r.end; i++) {
                    const int batch_index = static_cast<int>((i / 3) / num_heads);
                    const int head_index = static_cast<int>((i / 3) % num_heads);
                    const int qkv_index = static_cast<int>(i % 3);

                    auto *dst = QKV[qkv_index];
                    size_t head_size = qkv_head_sizes[qkv_index];

                    int input_offset = batch_index * seq_len * input_hidden_size;
                    int bias_offset = qkv_index * qkv_hidden_sizes[0] + head_index * head_size;
                    int dst_offset = (batch_index * num_heads + head_index) * (seq_len * head_size);

                    // broadcast bias ([NH] -> [BN, SH]) and make copy to dst
                    const auto *bias_data_src = bias_data + bias_offset;
                    auto *dst_data = dst + dst_offset;
                    for (size_t seq_len_idx = 0; seq_len_idx < seq_len; seq_len_idx++) {
                        std::memcpy(dst_data, bias_data_src, head_size * sizeof(float));
                        dst_data += head_size;
                    }

                    auto *packed_weight = packed_weights[qkv_index] + packed_weights_size[qkv_index] * head_index;
                    // single-thread gemm kernel
                    fastGemm(false, seq_len, head_size, input_hidden_size,
                             1.f, input_data + input_offset, input_hidden_size,
                             packed_weight, 1.f, dst + dst_offset, head_size, opt);
                }
            };

            size_t loops = 3 * batch_size * num_heads;
            double nstripes = loops * seq_len * qkv_head_sizes[0] * input_hidden_size * (1 / 1024.0);
            parallel_for_(Range(0, loops), fn, nstripes);
        }

        // Compute softmax(scale * matmul(Q, K))
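        // attention_prob is [B*N, S, S]; the softmax below is taken over the last axis,
        // i.e. across the key positions for each query row.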
        std::vector<int> attention_prob_shape{int(batch_size * num_heads), int(seq_len), int(seq_len)};
        Mat attention_prob = Mat::zeros(attention_prob_shape.size(), attention_prob_shape.data(), CV_32F);
        {
            auto *output = attention_prob.ptr<float>();

            auto loops = batch_size * num_heads;
            auto seq_len_square = seq_len * seq_len;
            auto qk_head_size = qkv_head_sizes[0];
            auto qk_inner_size = seq_len * qk_head_size;

            // Compute scale * matmul(Q, K)
            opt.multi_thread = false;
            parallel_for_(Range(0, loops), [&] (const Range r) {
                for (int i = r.start; i < r.end; i++) {
                    const int output_offset = i * seq_len_square;

                    const auto *q = Q + qk_inner_size * i, *k = K + qk_inner_size * i;
                    fastGemm(false, true, seq_len, qk_head_size, seq_len, qk_head_size,
                             scale, q, qk_head_size, 1,
                             k, qk_head_size, 1, 0.f,
                             output + output_offset, seq_len, opt);
                }
            }, loops * seq_len * qk_head_size * seq_len * (1 / 1024.0));

            // Compute softmax
            softmax(attention_prob, attention_prob, attention_prob_shape.size() - 1);
        }

        // Compute np.matmul(attention_prob, V)
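        // Each (batch, head) slice multiplies its [S, S] probabilities with its [S, H_v]
        // value block; the copy loop below then interleaves the per-head results back into
        // the output as [B, S, N * H_v] (the "transpose on the fly").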
        Mat output_buffer = Mat::zeros(1, int(batch_size * num_heads * seq_len * qkv_head_sizes[2]), CV_32F);
        {
            auto *output = outputs[0].ptr<float>();
            auto *output_buff = output_buffer.ptr<float>();
            const auto *prob = attention_prob.ptr<const float>();

            auto loops = batch_size * num_heads;
            auto prob_inner_size = seq_len * seq_len;
            auto v_head_size = qkv_head_sizes[2];
            auto v_inner_size = seq_len * v_head_size;

            opt.multi_thread = false;
            parallel_for_(Range(0, loops), [&] (const Range &r) {
                for (int i = r.start; i < r.end; i++) {
                    const int output_offset = i * v_inner_size;

                    const auto *p = prob + i * prob_inner_size, *v = V + i * v_inner_size;
                    fastGemm(false, false, seq_len, seq_len, seq_len, v_head_size,
                             1.f, p, seq_len, 1,
                             v, v_head_size, 1, 0.f,
                             output_buff + output_offset, v_head_size, opt);

                    // transpose on the fly
                    const int batch_index = static_cast<int>(i / num_heads);
                    const int head_index = static_cast<int>(i % num_heads);
                    auto *src = output_buff + output_offset;
                    auto *dst = output + (batch_index * seq_len * num_heads + head_index) * v_head_size;
                    for (int j = 0; j < seq_len; j++) {
                        std::memcpy(dst, src, v_head_size * sizeof(float));
                        src += v_head_size;
                        dst += qkv_hidden_sizes[2];
                    }
                }
            }, loops * seq_len * seq_len * v_head_size * (1 / 1024.0));
        }
    }

 private:
    size_t num_heads;
    std::vector<size_t> qkv_hidden_sizes; // order: {qk_hidden_size, qk_hidden_size, v_hidden_size}
    float scale;
    size_t output_ndims;

    std::vector<size_t> qkv_head_sizes; // order: {qk_head_size, qk_head_size, v_head_size}

    size_t batch_size;
    size_t seq_len;
    size_t input_hidden_size;
    size_t hidden_size;

    bool is_prepacked;
    std::vector<float> packed_weight_q;
    std::vector<float> packed_weight_k;
    std::vector<float> packed_weight_v;

    FastGemmOpt opt;
};

Ptr<AttentionLayer> AttentionLayer::create(const LayerParams &params) {
    return makePtr<AttentionLayerImpl>(params);
}

}} // cv::dnn