Merge pull request #24552 from fengyuentau:layernorm_backends

dnn: add openvino, opencl and cuda backends for layer normalization layer #24552

Merge after https://github.com/opencv/opencv/pull/24544.

Todo:

- [x] openvino
- [x] opencl
- [x] cuda

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
pull/24426/merge
Yuantao Feng 12 months ago committed by GitHub
parent fba3c947ef
commit d05fb709f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 69
      modules/dnn/src/cuda/mvn.cu
  2. 8
      modules/dnn/src/cuda4dnn/kernels/mvn.hpp
  3. 93
      modules/dnn/src/cuda4dnn/primitives/layer_norm.hpp
  4. 172
      modules/dnn/src/layers/layer_norm.cpp
  5. 6
      modules/dnn/src/opencl/mvn.cl
  6. 76
      modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp

@ -68,15 +68,36 @@ namespace raw {
}
template <class T>
__global__ void normalize_mean_variance_channelwise(Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> stdev, size_type inner_size, size_type C) {
__global__ void normalize_mean_variance_channelwise(Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, size_type inner_size, size_type C) {
for (auto idx : grid_stride_range(output.size())) {
const index_type outer_idx = idx / inner_size;
const index_type c = outer_idx % C;
auto s = static_cast<float>(scale[c]) * stdev[outer_idx];
auto s = static_cast<float>(scale[c]) * inv_stddev[outer_idx];
auto b = static_cast<float>(bias[c]);
output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * s + b;
}
}
template <class T>
__global__ void normalize_mean_variance_layernorm(Span<T> output, View<T> input, View<T> scale, View<float> means, View<float> inv_stddev, size_type inner_size) {
for (auto idx : grid_stride_range(output.size())) {
const index_type outer_idx = idx / inner_size;
const index_type inner_idx = idx % inner_size;
auto s = static_cast<float>(scale[inner_idx]) * inv_stddev[outer_idx];
output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * s;
}
}
template <class T>
__global__ void normalize_mean_variance_layernorm_with_bias(Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, size_type inner_size) {
for (auto idx : grid_stride_range(output.size())) {
const index_type outer_idx = idx / inner_size;
const index_type inner_idx = idx % inner_size;
auto s = static_cast<float>(scale[inner_idx]) * inv_stddev[outer_idx];
auto b = static_cast<float>(bias[inner_idx]);
output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * s + b;
}
}
}
template <class T>
@ -154,20 +175,54 @@ template void normalize_mean_variance(const Stream&, Span<__half>, View<__half>,
template void normalize_mean_variance(const Stream&, Span<float>, View<float>, View<float>, View<float>, std::size_t);
template <class T>
void normalize_mean_variance_channelwise(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> stdev, std::size_t inner_size, std::size_t C)
void normalize_mean_variance_channelwise(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, std::size_t inner_size, std::size_t C)
{
CV_Assert(input.size() == output.size());
CV_Assert(input.size() / inner_size == means.size());
CV_Assert(means.size() == stdev.size());
CV_Assert(means.size() == inv_stddev.size());
auto kernel = raw::normalize_mean_variance_channelwise<T>;
auto policy = make_policy(kernel, output.size(), 0, stream);
launch_kernel(kernel, policy, output, input, scale, bias, means, stdev, inner_size, C);
launch_kernel(kernel, policy, output, input, scale, bias, means, inv_stddev, inner_size, C);
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void normalize_mean_variance_channelwise(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<__half> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t, std::size_t);
#endif
template void normalize_mean_variance_channelwise(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t, std::size_t);
template <class T>
void normalize_mean_variance_layernorm(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<float> means, View<float> inv_stddev, std::size_t inner_size)
{
CV_Assert(input.size() == output.size());
CV_Assert(input.size() / inner_size == means.size());
CV_Assert(means.size() == inv_stddev.size());
auto kernel = raw::normalize_mean_variance_layernorm<T>;
auto policy = make_policy(kernel, output.size(), 0, stream);
launch_kernel(kernel, policy, output, input, scale, means, inv_stddev, inner_size);
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void normalize_mean_variance_layernorm(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t);
#endif
template void normalize_mean_variance_layernorm(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t);
template <class T>
void normalize_mean_variance_layernorm(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, std::size_t inner_size)
{
CV_Assert(input.size() == output.size());
CV_Assert(input.size() / inner_size == means.size());
CV_Assert(means.size() == inv_stddev.size());
auto kernel = raw::normalize_mean_variance_layernorm_with_bias<T>;
auto policy = make_policy(kernel, output.size(), 0, stream);
launch_kernel(kernel, policy, output, input, scale, bias, means, inv_stddev, inner_size);
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void normalize_mean_variance_channelwise(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<__half> /*bias*/, View<float> /*means*/, View<float> /*stdev*/, std::size_t, std::size_t);
template void normalize_mean_variance_layernorm(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<__half> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t);
#endif
template void normalize_mean_variance_channelwise(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*stdev*/, std::size_t, std::size_t);
template void normalize_mean_variance_layernorm(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */

@ -27,7 +27,13 @@ template <class T>
void normalize_mean_variance(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, csl::View<float> means, csl::View<float> scale, std::size_t inner_size);
template <class T>
void normalize_mean_variance_channelwise(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> stdev, std::size_t inner_size, std::size_t C);
void normalize_mean_variance_channelwise(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> inv_stddev, std::size_t inner_size, std::size_t C);
template <class T>
void normalize_mean_variance_layernorm(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<float> means, csl::View<float> inv_stddev, std::size_t inner_size);
template <class T>
void normalize_mean_variance_layernorm(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> inv_stddev, std::size_t inner_size);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */

@ -0,0 +1,93 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_LAYER_NORM_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_LAYER_NORM_HPP
#include "../../op_cuda.hpp"
#include "../csl/stream.hpp"
#include "../csl/span.hpp"
#include "../csl/tensor.hpp"
#include "../csl/workspace.hpp"
#include "../kernels/fill_copy.hpp"
#include "../kernels/mvn.hpp"
#include <opencv2/core.hpp>
#include <cstddef>
#include <vector>
#include <utility>
namespace cv { namespace dnn { namespace cuda4dnn {
template <class T>
class LayerNormOp final : public CUDABackendNode {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
LayerNormOp(csl::Stream stream_, int normalized_axis, float epsilon_, size_t loops)
: stream(std::move(stream_)), epsilon(epsilon_) {
CV_CheckGE(normalized_axis, 0, "LayerNorm/CUDA: axis needs to be normalized");
axis = static_cast<size_t>(normalized_axis);
csl::WorkspaceBuilder builder;
builder.require<float>(loops);
builder.require<float>(loops);
scratch_mem_in_bytes = builder.required_workspace_size();
}
void forward(const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override {
auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
auto scale_wrapper = inputs[1].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
auto scale = scale_wrapper->getView();
auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
auto loops = input.size_range(0, axis);
auto norm_size = input.size_range(axis, input.rank());
if (norm_size == 1) {
kernels::fill<T>(stream, output, 0.f);
return;
} else {
auto ws_allocator = csl::WorkspaceAllocator(workspace);
auto mean = ws_allocator.get_span<float>(loops);
kernels::fill<float>(stream, mean, 0.f);
auto inv_stddev = ws_allocator.get_span<float>(loops);
kernels::fill<float>(stream, inv_stddev, 0.f);
kernels::reduce_mean_sqr_sum<T>(stream, mean, inv_stddev, input, norm_size);
kernels::compute_normalization_scale(stream, inv_stddev, mean, inv_stddev, norm_size, epsilon);
if (inputs.size() == 3) {
auto bias_wrapper = inputs[2].dynamicCast<wrapper_type>();
auto bias = bias_wrapper->getView();
kernels::normalize_mean_variance_layernorm<T>(stream, output, input, scale, bias, mean, inv_stddev, norm_size);
} else {
kernels::normalize_mean_variance_layernorm<T>(stream, output, input, scale, mean, inv_stddev, norm_size);
}
}
}
std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; }
private:
csl::Stream stream;
float epsilon;
size_t axis;
std::size_t scratch_mem_in_bytes;
};
}}} // cv::dnn::cuda4dnn
#endif // OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_LAYER_NORM_HPP

@ -9,8 +9,26 @@
// CANN backend
#include "../op_cann.hpp"
// OpenVINO backend
#include "../op_inf_engine.hpp"
#include "../ie_ngraph.hpp"
// CUDA backend
#include "../op_cuda.hpp"
#ifdef HAVE_CUDA
#include "../cuda4dnn/primitives/layer_norm.hpp"
using namespace cv::dnn::cuda4dnn;
#endif
// OpenCL backend
#ifdef HAVE_OPENCL
#include "../ocl4dnn/include/math_functions.hpp"
#include "opencl_kernels_dnn.hpp"
#endif
namespace cv { namespace dnn {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#LayerNormalization
class LayerNormLayerImpl CV_FINAL : public LayerNormLayer
{
public:
@ -25,7 +43,12 @@ public:
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
#ifdef HAVE_INF_ENGINE
if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
return true;
#endif
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_CUDA ||
(backendId == DNN_BACKEND_CANN && axis != -1); // axis=-1 not supported due to 1d mat shape problem
}
@ -73,6 +96,9 @@ public:
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
forward_ocl(inputs_arr, outputs_arr, internals_arr))
if (inputs_arr.depth() == CV_16S)
{
forward_fallback(inputs_arr, outputs_arr, internals_arr);
@ -95,6 +121,91 @@ public:
}
}
#ifdef HAVE_OPENCL
bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_) {
std::vector<UMat> inputs;
std::vector<UMat> outputs;
inputs_.getUMatVector(inputs);
outputs_.getUMatVector(outputs);
const auto &input = inputs[0], &scale = inputs[1]; // &bias = inputs[2]; // bias is optional
auto &output = outputs[0];
const auto input_shape = shape(input);
size_t loops = static_cast<size_t>(total(input_shape, 0, axis)),
norm_size = static_cast<size_t>(total(input_shape, axis));
float inv_norm_size = 1.f / norm_size;
const auto &bias = inputs.size() == 3 ? inputs[2] : UMat::zeros(norm_size, 1, CV_32F);
// no fp16 support
if (input.depth() == CV_16S) {
return false;
}
String base_opts = format(" -DT=float -DT4=float4 -Dconvert_T=convert_float4");
// Calculate mean
UMat one = UMat::ones(norm_size, 1, CV_32F);
UMat mean = UMat(loops, 1, CV_32F);
UMat mean_square = UMat(loops, 1, CV_32F);
UMat tmp = UMat(loops, norm_size, CV_32F);
bool ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, loops, norm_size, inv_norm_size,
input, 0, one, 0, 0.f, mean, 0);
if (!ret) {
return false;
}
// Calculate mean_square
int num_vector = (norm_size % 8 == 0) ? 8 : ((norm_size % 4 == 0) ? 4 : 1);
size_t global[] = {loops, static_cast<size_t>(norm_size / num_vector)};
String build_opt = format(" -DNUM=%d", num_vector) + base_opts;
String mean_square_kernel_name = format("calc_mean%d", num_vector);
ocl::Kernel mean_square_kernel(mean_square_kernel_name.c_str(), ocl::dnn::mvn_oclsrc, build_opt + " -DKERNEL_MEAN");
if (mean_square_kernel.empty()) {
return false;
}
mean_square_kernel.set(0, ocl::KernelArg::PtrReadOnly(input));
mean_square_kernel.set(1, (int)loops);
mean_square_kernel.set(2, (int)norm_size);
mean_square_kernel.set(3, ocl::KernelArg::PtrReadOnly(mean));
mean_square_kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmp));
ret = mean_square_kernel.run(2, global, NULL, false);
if (!ret) {
return false;
}
ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, loops, norm_size, inv_norm_size,
tmp, 0, one, 0, 0.f, mean_square, 0);
if (!ret) {
return false;
}
// Calculate instance norm: output = scale * (x - mean) / sqrt(var + eps) + bias
String mvn_kernel_name = format("mvn%d", num_vector);
build_opt += " -DNORM_VARIANCE -DLAYER_NORM -DKERNEL_MVN";
ocl::Kernel mvn_kernel(mvn_kernel_name.c_str(), ocl::dnn::mvn_oclsrc, build_opt);
if (mvn_kernel.empty()) {
return false;
}
mvn_kernel.set(0, ocl::KernelArg::PtrReadOnly(input));
mvn_kernel.set(1, (int)loops);
mvn_kernel.set(2, (int)norm_size);
mvn_kernel.set(3, (float)epsilon);
mvn_kernel.set(4, ocl::KernelArg::PtrReadOnly(mean));
mvn_kernel.set(5, ocl::KernelArg::PtrReadOnly(mean_square));
mvn_kernel.set(6, ocl::KernelArg::PtrReadOnly(scale));
mvn_kernel.set(7, ocl::KernelArg::PtrReadOnly(bias));
mvn_kernel.set(8, (int)1);
mvn_kernel.set(9, (float)0.f);
mvn_kernel.set(10, ocl::KernelArg::PtrWriteOnly(output));
ret = mvn_kernel.run(2, global, NULL, false);
if (!ret) {
return false;
}
return true;
}
#endif
#ifdef HAVE_CANN
virtual Ptr<BackendNode> initCann(const std::vector<Ptr<BackendWrapper> > &inputs,
const std::vector<Ptr<BackendWrapper> > &outputs,
@ -147,6 +258,67 @@ public:
}
#endif // HAVE_CANN
#ifdef HAVE_DNN_NGRAPH
virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE {
auto ieInpNode = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
const auto &input_shape = ieInpNode.get_shape();
std::shared_ptr<ngraph::Node> mvn, result;
// mvn
#if INF_ENGINE_VER_MAJOR_LE(INF_ENGINE_RELEASE_2021_2)
// https://docs.openvino.ai/2021.4/api/ngraph_python_api/_autosummary/ngraph.opset3.mvn.html?highlight=mvn#ngraph.opset3.mvn
bool across_channels = false;
bool normalize_variance = true;
mvn = std::make_shared<ngraph::op::MVN>(ieInpNode, across_channels, normalize_variance, epsilon);
#else
// https://docs.openvino.ai/2023.1/openvino_docs_ops_normalization_MVN_6.html
std::vector<int64_t> axes_v(input_shape.size() - axis);
std::iota(axes_v.begin(), axes_v.end(), axis);
auto axes = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{axes_v.size()}, axes_v.data());
bool normalize_variance = true;
mvn = std::make_shared<ngraph::op::v6::MVN>(ieInpNode, axes, normalize_variance, epsilon, ngraph::op::MVNEpsMode::INSIDE_SQRT);
#endif
// layer norm = scale * mvn + bias
auto scale = nodes[1].dynamicCast<InfEngineNgraphNode>()->node;
ngraph::Output<ngraph::Node> bias;
if (nodes.size() == 3) {
bias = nodes[2].dynamicCast<InfEngineNgraphNode>()->node;
}
if (axis == -1 || axis == input_shape.size() - 1) { // special case for 1D tensor (2D mat)
std::vector<int64_t> shared_shape_v(input_shape.size(), 1);
shared_shape_v.back() = -1;
auto shared_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{shared_shape_v.size()}, shared_shape_v.data());
scale = std::make_shared<ngraph::op::v1::Reshape>(scale, shared_shape, true);
if (nodes.size() == 3) {
bias = std::make_shared<ngraph::op::v1::Reshape>(bias, shared_shape, true);
}
}
result = std::make_shared<ngraph::op::v1::Multiply>(mvn, scale);
if (nodes.size() == 3) {
result = std::make_shared<ngraph::op::v1::Add>(result, bias);
}
return Ptr<BackendNode>(new InfEngineNgraphNode(result));
}
#endif // HAVE_DNN_NGRAPH
#ifdef HAVE_CUDA
Ptr<BackendNode> initCUDA(void *context_,
const std::vector<Ptr<BackendWrapper>>& inputs,
const std::vector<Ptr<BackendWrapper>>& outputs) override {
auto context = reinterpret_cast<csl::CSLContext*>(context_);
auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
auto input_shape = input_wrapper->getShape();
size_t loops = static_cast<size_t>(total(input_shape, 0, axis));
return make_cuda_node<cuda4dnn::LayerNormOp>(preferableTarget, std::move(context->stream), axis, epsilon, loops);
}
#endif // HAVE_CUDA
};
Ptr<LayerNormLayer> LayerNormLayer::create(const LayerParams& params)

@ -126,12 +126,18 @@ __kernel void MVN(__global const Dtype* src,
alpha = 1;
#endif
#ifdef LAYER_NORM
vec_type w = load(bnorm_weight, y), b = load(bnorm_bias, y);
#else
Dtype w = 1.f, b = 0.f;
#ifdef FUSE_BATCH_NORM
w = bnorm_weight[x % channels];
b = bnorm_bias[x % channels];
#endif
#endif // LAYER_NORM
vec_type src_vec = load(src, index) - (vec_type)mean_val;
vec_type dst_vec = src_vec * alpha;
dst_vec = dst_vec * w + (vec_type)b;

@ -793,81 +793,43 @@ CASE(test_isinf_positive)
CASE(test_isnan)
// no filter
CASE(test_layer_normalization_2d_axis0)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_2d_axis1)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_2d_axis_negative_1)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_2d_axis_negative_2)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_3d_axis0_epsilon)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_3d_axis1_epsilon)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_3d_axis2_epsilon)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_3d_axis_negative_1_epsilon)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_3d_axis_negative_2_epsilon)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_3d_axis_negative_3_epsilon)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_4d_axis0)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_4d_axis1)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_4d_axis2)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_4d_axis3)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_4d_axis_negative_1)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_4d_axis_negative_2)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_4d_axis_negative_3)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_4d_axis_negative_4)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_layer_normalization_default_axis)
#if SKIP_SET_1
SKIP_NON_CPU;
#endif
// no filter
CASE(test_leakyrelu)
// no filter
CASE(test_leakyrelu_default)

Loading…
Cancel
Save