Merge pull request #9114 from pengli:dnn_rebase
add libdnn acceleration to dnn module (#9114)

* import libdnn code. Signed-off-by: Li Peng <peng.li@intel.com>
* add convolution layer ocl acceleration. Signed-off-by: Li Peng <peng.li@intel.com>
* add pooling layer ocl acceleration. Signed-off-by: Li Peng <peng.li@intel.com>
* add softmax layer ocl acceleration. Signed-off-by: Li Peng <peng.li@intel.com>
* add lrn layer ocl acceleration. Signed-off-by: Li Peng <peng.li@intel.com>
* add innerproduct layer ocl acceleration. Signed-off-by: Li Peng <peng.li@intel.com>
* add HAVE_OPENCL macro. Signed-off-by: Li Peng <peng.li@intel.com>
* fix for convolution ocl. Signed-off-by: Li Peng <peng.li@intel.com>
* enable getUMat() for multi-dimension Mat. Signed-off-by: Li Peng <peng.li@intel.com>
* use getUMat for ocl acceleration. Signed-off-by: Li Peng <peng.li@intel.com>
* use CV_OCL_RUN macro. Signed-off-by: Li Peng <peng.li@intel.com>
* set OPENCL target when it is available, and disable fuseLayer for the OCL target for the time being. Signed-off-by: Li Peng <peng.li@intel.com>
* fix innerproduct accuracy test. Signed-off-by: Li Peng <peng.li@intel.com>
* remove trailing space. Signed-off-by: Li Peng <peng.li@intel.com>
* Fixed TensorFlow demo bug. Root cause: TensorFlow uses a different algorithm than libdnn to calculate the convolution output dimension. libdnn no longer calculates the output dimension and just uses the one passed in by the config.
* split gemm ocl file: split it into gemm_buffer.cl and gemm_image.cl. Signed-off-by: Li Peng <peng.li@intel.com>
* Fix compile failure. Signed-off-by: Li Peng <peng.li@intel.com>
* check env flag for auto tuning. Signed-off-by: Li Peng <peng.li@intel.com>
* switch to new ocl kernels for softmax layer. Signed-off-by: Li Peng <peng.li@intel.com>
* update softmax layer: on some platforms the subgroup extension may not work well; fall back to non-subgroup ocl acceleration. Signed-off-by: Li Peng <peng.li@intel.com>
* fall back to cpu path for fc layer with multiple outputs. Signed-off-by: Li Peng <peng.li@intel.com>
* update output message. Signed-off-by: Li Peng <peng.li@intel.com>
* update fully connected layer: fall back to the gemm API if libdnn returns false. Signed-off-by: Li Peng <peng.li@intel.com>
* Add ReLU OCL implementation
* disable layer fusion for now. Signed-off-by: Li Peng <peng.li@intel.com>
* Add OCL implementation for concat layer. Signed-off-by: Wu Zhiwen <zhiwen.wu@intel.com>
* libdnn: update license and copyrights; also refine libdnn coding style. Signed-off-by: Wu Zhiwen <zhiwen.wu@intel.com> Signed-off-by: Li Peng <peng.li@intel.com>
* DNN: Don't link OpenCL library explicitly
* DNN: Make default preferableTarget DNN_TARGET_CPU. Users should set it to DNN_TARGET_OPENCL explicitly if they want OpenCL acceleration. Also don't fuse layers when using DNN_TARGET_OPENCL.
* DNN: refine coding style
* Add getOpenCLErrorString
* DNN: Use int32_t/uint32_t instead of aliases
* Use namespace ocl4dnn to include libdnn things
* remove extra copyTo in softmax ocl path. Signed-off-by: Li Peng <peng.li@intel.com>
* update ReLU layer ocl path. Signed-off-by: Li Peng <peng.li@intel.com>
* Add prefer target property for layer class: it is used to indicate the target for layer forwarding, either the default CPU target or the OCL target. Signed-off-by: Li Peng <peng.li@intel.com>
* Add cl_event based timer for cv::ocl
* Rename libdnn to ocl4dnn. Signed-off-by: Li Peng <peng.li@intel.com> Signed-off-by: wzw <zhiwen.wu@intel.com>
* use UMat for ocl4dnn internal buffer; remove allocateMemory, which used clCreateBuffer directly. Signed-off-by: Li Peng <peng.li@intel.com> Signed-off-by: wzw <zhiwen.wu@intel.com>
* enable buffer gemm in ocl4dnn innerproduct. Signed-off-by: Li Peng <peng.li@intel.com>
* replace int_tp globally for ocl4dnn kernels. Signed-off-by: wzw <zhiwen.wu@intel.com> Signed-off-by: Li Peng <peng.li@intel.com>
* create UMat for layer params. Signed-off-by: Li Peng <peng.li@intel.com>
* update sign ocl kernel. Signed-off-by: Li Peng <peng.li@intel.com>
* update image based gemm of inner product layer. Signed-off-by: Li Peng <peng.li@intel.com>
* remove buffer gemm of inner product layer; call the cv::gemm API instead. Signed-off-by: Li Peng <peng.li@intel.com>
* change ocl4dnn forward parameter to UMat. Signed-off-by: Li Peng <peng.li@intel.com>
* Refine auto-tuning mechanism.
  - Use OPENCV_OCL4DNN_KERNEL_CONFIG_PATH to set the cache directory for fine-tuned kernel configurations, e.g. export OPENCV_OCL4DNN_KERNEL_CONFIG_PATH=/home/tmp; the cache directory will be /home/tmp/spatialkernels/ on Linux.
  - Define the environment variable OPENCV_OCL4DNN_ENABLE_AUTO_TUNING to enable auto-tuning.
  - OPENCV_OPENCL_ENABLE_PROFILING is only used to enable profiling for the OpenCL command queue. This fixes the basic kernel reporting a wrong running time, i.e. 0 ms.
  - If creating the cache directory fails, disable auto-tuning.
* Detect and create cache dir on windows. Signed-off-by: Li Peng <peng.li@intel.com>
* Refine gemm like convolution kernel. Signed-off-by: Li Peng <peng.li@intel.com>
* Fix redundant swizzleWeights calls when using a cached kernel config.
* Fix "out of resource" bug when auto-tuning too many kernels.
* replace cl_mem with UMat in ocl4dnnConvSpatial class
* OCL4DNN: reduce the tuning kernel candidates. This patch could drop 75% of the tuning candidates with less than 2% performance impact on the final result. Signed-off-by: Zhigang Gong <zhigang.gong@intel.com>
* replace cl_mem with UMat in ocl4dnn convolution. Signed-off-by: Li Peng <peng.li@intel.com>
* remove weight_image_ of ocl4dnn inner product; it is actually unused in the computation. Signed-off-by: Li Peng <peng.li@intel.com>
* Various fixes for ocl4dnn:
  1. OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel())
  2. Ptr<OCL4DNNInnerProduct<float> > innerProductOp
  3. Code comments cleanup
  4. ignore check on OCL cpu device
  Signed-off-by: Li Peng <peng.li@intel.com>
* add build option for log softmax. Signed-off-by: Li Peng <peng.li@intel.com>
* remove unused ocl kernels in ocl4dnn. Signed-off-by: Li Peng <peng.li@intel.com>
* replace ocl4dnnSet with opencv setTo. Signed-off-by: Li Peng <peng.li@intel.com>
* replace ALIGN with cv::alignSize. Signed-off-by: Li Peng <peng.li@intel.com>
* check kernel build options. Signed-off-by: Li Peng <peng.li@intel.com>
* Handle program compilation failure properly.
* Use std::numeric_limits<float>::infinity() for large float numbers
* check ocl4dnn kernel compilation result. Signed-off-by: Li Peng <peng.li@intel.com>
* remove unused ctx_id. Signed-off-by: Li Peng <peng.li@intel.com>
* change clEnqueueNDRangeKernel to kernel.run(). Signed-off-by: Li Peng <peng.li@intel.com>
* change cl_mem to UMat in image based gemm. Signed-off-by: Li Peng <peng.li@intel.com>
* check intel subgroup support for lrn and pooling layer. Signed-off-by: Li Peng <peng.li@intel.com>
* Fix convolution bug when group is greater than 1. Signed-off-by: Li Peng <peng.li@intel.com>
* Set default layer preferableTarget to DNN_TARGET_CPU. Signed-off-by: Li Peng <peng.li@intel.com>
* Add ocl perf test for convolution. Signed-off-by: Li Peng <peng.li@intel.com>
* Add more ocl accuracy tests. Signed-off-by: Li Peng <peng.li@intel.com>
* replace cl_image with ocl::Image2D. Signed-off-by: Li Peng <peng.li@intel.com>
* Fix build failure in elementwise layer. Signed-off-by: Li Peng <peng.li@intel.com>
* use getUMat() to get blob data. Signed-off-by: Li Peng <peng.li@intel.com>
* replace cl_mem handle with ocl::KernelArg. Signed-off-by: Li Peng <peng.li@intel.com>
* dnn(build): don't use C++11, OPENCL_LIBRARIES fix
* dnn(ocl4dnn): remove unused OpenCL kernels
* dnn(ocl4dnn): extract OpenCL code into .cl files
* dnn(ocl4dnn): refine auto-tuning. Auto-tuning is disabled by default; set the OPENCV_OCL4DNN_ENABLE_AUTO_TUNING environment variable to enable it. When auto-tuning is disabled, a set of pre-tuned configs is used as the default config. These configs are tuned for Intel GPUs with 48/72 EUs, and for GoogLeNet, AlexNet and ResNet-50. If the default config is not suitable, the first available kernel config from the candidates is used. Candidate priority, from high to low, is gemm-like kernel, IDLF kernel, basic kernel.
* dnn(ocl4dnn): pooling doesn't use OpenCL subgroups
* dnn(ocl4dnn): fix perf test. OpenCV has a default 3 sec time limit for each performance test; warm up the OpenCL backend outside of the perf measurement loop.
* use ocl::KernelArg as much as possible. Signed-off-by: Li Peng <peng.li@intel.com>
* dnn(ocl4dnn): fix bias bug for gemm like kernel
* dnn(ocl4dnn): wrap cl_mem into UMat. Signed-off-by: Li Peng <peng.li@intel.com>
* dnn(ocl4dnn): Refine signature of kernel config:
  - Use a more readable string as the signature of a kernel config
  - Don't count device name and vendor in the signature string
  - Default kernel configurations are tuned for Intel GPUs with 24/48/72 EUs, and for the GoogLeNet, AlexNet and ResNet-50 net models.
* dnn(ocl4dnn): swap width/height in configuration
* dnn(ocl4dnn): enable configs for Intel OpenCL runtime only
* core: make configuration helper functions accessible from non-core modules
* dnn(ocl4dnn): update kernel auto-tuning behavior; avoid unwanted creation of directories
* dnn(ocl4dnn): simplify kernel to work around an OpenCL compiler crash
* dnn(ocl4dnn): remove redundant code
* dnn(ocl4dnn): Add a clearer message for SIMD size mismatch.
* dnn(ocl4dnn): add const to const argument. Signed-off-by: Li Peng <peng.li@intel.com>
* dnn(ocl4dnn): force the compiler to use a specific SIMD size for the IDLF kernel
* dnn(ocl4dnn): drop unused tuneLocalSize()
* dnn(ocl4dnn): specify OpenCL queue for Timer and convolve() method
* dnn(ocl4dnn): sanitize file names used for cache
* dnn(perf): enable Network tests with OpenCL
* dnn(ocl4dnn/conv): drop computeGlobalSize()
* dnn(ocl4dnn/conv): drop unused fields
* dnn(ocl4dnn/conv): simplify ctor
* dnn(ocl4dnn/conv): refactor kernelConfig localSize=NULL
* dnn(ocl4dnn/conv): drop unsupported double / untested half types
* dnn(ocl4dnn/conv): drop unused variable
* dnn(ocl4dnn/conv): alignSize/divUp
* dnn(ocl4dnn/conv): use enum values
* dnn(ocl4dnn): drop unused innerproduct variable. Signed-off-by: Li Peng <peng.li@intel.com>
* dnn(ocl4dnn): add a generic function to check cl option support
* dnn(ocl4dnn): run softmax subgroup version kernel first. Signed-off-by: Li Peng <peng.li@intel.com>
parent f646f61dad
commit e340ff9c3a
50 changed files with 8788 additions and 63 deletions
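As the squashed commit message notes, the default preferable target stays DNN_TARGET_CPU and OpenCL acceleration is strictly opt-in. A minimal usage sketch follows; the model, prototxt and image file names are hypothetical placeholders, and the environment variables are the ones quoted in the commit message:

    // Hypothetical example of opting in to the new ocl4dnn path.
    // Auto-tuning is off by default; it can be enabled with
    //   OPENCV_OCL4DNN_ENABLE_AUTO_TUNING=1
    // and the kernel-config cache directory can be set with
    //   OPENCV_OCL4DNN_KERNEL_CONFIG_PATH=/home/tmp  (cache goes to /home/tmp/spatialkernels/)
    #include <opencv2/dnn.hpp>
    #include <opencv2/imgcodecs.hpp>

    int main()
    {
        cv::dnn::Net net = cv::dnn::readNetFromCaffe("bvlc_googlenet.prototxt",    // placeholder paths
                                                     "bvlc_googlenet.caffemodel");
        net.setPreferableTarget(cv::dnn::DNN_TARGET_OPENCL);  // explicit opt-in; the default is DNN_TARGET_CPU

        cv::Mat img = cv::imread("input.jpg");                                     // placeholder input
        cv::Mat blob = cv::dnn::blobFromImage(img, 1.0, cv::Size(224, 224));
        net.setInput(blob);
        cv::Mat prob = net.forward();  // dispatches to the ocl4dnn kernels when an OpenCL device is available
        return 0;
    }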
@@ -0,0 +1,16 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_CONFIGURATION_PRIVATE_HPP
#define OPENCV_CONFIGURATION_PRIVATE_HPP

namespace cv { namespace utils {

CV_EXPORTS bool getConfigurationParameterBool(const char* name, bool defaultValue);
CV_EXPORTS size_t getConfigurationParameterSizeT(const char* name, size_t defaultValue);
CV_EXPORTS cv::String getConfigurationParameterString(const char* name, const char* defaultValue);

}} // namespace

#endif // OPENCV_CONFIGURATION_PRIVATE_HPP
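These private helpers wrap OpenCV's environment-variable lookup and, per the commit message entry "core: make configuration helper functions accessible from non-core modules", are what lets ocl4dnn read its tuning flags. A sketch of such a call site, using the flag named in the commit message; the include path is assumed from this hunk and the wrapper function itself is hypothetical:

    #include "opencv2/core/utils/configuration.private.hpp"  // path assumed

    // Hypothetical helper: auto-tuning stays off unless the user exports
    // OPENCV_OCL4DNN_ENABLE_AUTO_TUNING (see the commit message above).
    static bool isAutoTuningRequested()
    {
        return cv::utils::getConfigurationParameterBool("OPENCV_OCL4DNN_ENABLE_AUTO_TUNING", false);
    }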
@@ -0,0 +1,45 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

__kernel void null_kernel_float(float arg) {
  float out = arg;
}
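The only kernel in this file just copies its argument into a local variable; judging from the commit message entry "add a generic function to check cl option support", its apparent role is to give the driver a trivial program to compile when probing whether a build option is accepted. A sketch of that idea through the public cv::ocl API; the function name below is hypothetical, and the real helper, clOptionSupport, is declared in common.hpp further down:

    #include <opencv2/core/ocl.hpp>

    // Hypothetical probe: returns true if the current OpenCL device/driver accepts
    // `option` as a program build flag, by compiling the trivial kernel with it.
    static bool buildOptionSupported(const cv::String& option)
    {
        static const char* src =
            "__kernel void null_kernel_float(float arg) { float out = arg; }";
        cv::String errmsg;
        cv::ocl::ProgramSource programSource(src);
        cv::ocl::Kernel kernel("null_kernel_float", programSource, option, &errmsg);
        return !kernel.empty();  // true only if the program built with `option`
    }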
@@ -0,0 +1,118 @@
#include "../perf_precomp.hpp"
#include "opencv2/ts/ocl_perf.hpp"
#include <opencv2/dnn/shape_utils.hpp>

#ifdef HAVE_OPENCL

namespace cvtest
{
namespace ocl
{

using std::tr1::tuple;
using std::tr1::get;
using std::tr1::make_tuple;
using std::make_pair;
using namespace perf;
using namespace testing;
using namespace cv;
using namespace cv::dnn;

enum {STRIDE_OFF = 1, STRIDE_ON = 2};
CV_ENUM(StrideSize, STRIDE_OFF, STRIDE_ON);

enum {GROUP_OFF = 1, GROUP_2 = 2};
CV_ENUM(GroupSize, GROUP_OFF, GROUP_2);

// Squared Size
#define SSZ(n) cv::Size(n, n)

typedef std::pair<MatShape, int> InpShapeNumOut;
typedef tuple<Size, InpShapeNumOut, GroupSize, StrideSize> ConvParam; // kernel_size, inp shape, groups, stride
typedef TestBaseWithParam<ConvParam> ConvolutionPerfTest;

static inline MatShape blobShape(int count, int nplanes, int height, int width)
{
    int data[] = {count, nplanes, height, width};
    return MatShape(data, data+4);
}

OCL_PERF_TEST_P( ConvolutionPerfTest, perf, Combine(
    Values(Size(1, 1), Size(3, 3), Size(5, 5), Size(11, 11)),
    Values(make_pair(blobShape(1,   4, 224, 224),  64),
           make_pair(blobShape(1,  64, 112, 122), 128),
           make_pair(blobShape(1, 256,  28,  28), 512)),
    GroupSize::all(),
    StrideSize::all())
)
{
    RNG rng(0);

    ConvParam params = GetParam();
    int ksz = get<0>(params).width;
    MatShape inpShape = get<1>(params).first;
    int outCn = get<1>(params).second;
    int groups = get<2>(params);
    int stride = (ksz >= 11) ? 4 : (int)get<3>(params);

    int inpCn = inpShape[1];
    int wgtSize[] = { outCn, inpCn/groups, ksz, ksz };
    int biasSize[] = { outCn, 1, 1, 1 };
    const int wtype = CV_32F;
    Mat wgtBlob(4, wgtSize, wtype), biasBlob(4, biasSize, wtype);
    Mat inpBlob(4, &inpShape[0], wtype);
    rng.fill(biasBlob, RNG::UNIFORM, -1, +1);
    rng.fill(wgtBlob, RNG::UNIFORM, -1, +1);
    rng.fill(inpBlob, RNG::UNIFORM, -1, +1);

    LayerParams lp;
    lp.set("num_output", outCn);
    lp.set("group", groups);
    lp.set("stride", stride);
    lp.set("kernel_size", ksz);
    lp.blobs.reserve(2);
    lp.blobs.push_back(wgtBlob);
    lp.blobs.push_back(biasBlob);

    std::vector<Mat*> inpBlobs(1, &inpBlob);
    std::vector<Mat> outBlobs, internalBlobs;

    cv::setNumThreads(cv::getNumberOfCPUs());

    Ptr<Layer> layer = cv::dnn::LayerFactory::createLayerInstance("Convolution", lp);
    std::vector<MatShape> inputShapes(1, shape(inpBlob)), outShapes, internals;
    layer->getMemoryShapes(inputShapes, 0, outShapes, internals);
    for (int i = 0; i < outShapes.size(); i++)
    {
        outBlobs.push_back(Mat(outShapes[i], CV_32F));
    }
    for (int i = 0; i < internals.size(); i++)
    {
        internalBlobs.push_back(Mat());
        if (total(internals[i]))
            internalBlobs.back().create(internals[i], CV_32F);
    }

    layer->finalize(inpBlobs, outBlobs);
    layer->preferableTarget = DNN_TARGET_OPENCL;

    Mat inpBlob2D = inpBlob.reshape(1, outCn);
    Mat wgtBlob2D = wgtBlob.reshape(1, outCn*(inpCn/groups));
    Mat outBlob2D = outBlobs[0].reshape(1, outBlobs[0].size[0]);
    declare.in(inpBlob2D, wgtBlob2D, WARMUP_RNG).out(outBlob2D).tbb_threads(cv::getNumThreads());

    // warmup
    layer->forward(inpBlobs, outBlobs, internalBlobs);

    TEST_CYCLE()
    {
        layer->forward(inpBlobs, outBlobs, internalBlobs);
    }

    SANITY_CHECK_NOTHING();
}

}
}

#endif
@@ -0,0 +1,62 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#ifndef _OPENCV_LIBDNN_COMMON_HPP_
#define _OPENCV_LIBDNN_COMMON_HPP_
#include "../../precomp.hpp"
#include "../../caffe/glog_emulator.hpp"
#include <opencv2/core/opencl/runtime/opencl_core.hpp>

#ifdef HAVE_OPENCL

// Macro to select the single (_float) or double (_double) precision kernel
#define CL_KERNEL_SELECT(kernel) kernel "_float"

#define OCL_CHECK(condition) \
    do { \
        cl_int error = (condition); \
        CHECK_EQ(error, CL_SUCCESS) << " " << cv::ocl::getOpenCLErrorString(error); \
    } while (0)

bool clOptionSupport(cv::String option);

#endif // HAVE_OPENCL
#endif
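OCL_CHECK evaluates a raw OpenCL call once and then, through the glog-emulation CHECK_EQ, asserts that it returned CL_SUCCESS, appending cv::ocl::getOpenCLErrorString(error) to the failure message. A call-site sketch, assuming this header is included and an OpenCL device is available; the function below is hypothetical:

    #include <opencv2/core/ocl.hpp>

    static void flushDefaultQueue()
    {
        // Hypothetical call site: wrap a raw OpenCL call on OpenCV's default command queue.
        cl_command_queue queue = (cl_command_queue)cv::ocl::Queue::getDefault().ptr();
        OCL_CHECK(clFinish(queue));  // aborts with a readable error string unless CL_SUCCESS
    }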
@@ -0,0 +1,854 @@
#ifndef _OPENCV_OCL4DNN_DEFAULT_KERNEL_CONFIG_HPP_
#define _OPENCV_OCL4DNN_DEFAULT_KERNEL_CONFIG_HPP_
const char *default_kernel_config_intel[] = {
// Below is the information for the OpenCL device on which these configurations were tuned
/*******************************************************************************
Number of platforms 1 |
||||
Platform Name Intel(R) OpenCL |
||||
Platform Vendor Intel(R) Corporation |
||||
Platform Version OpenCL 2.0 |
||||
Platform Profile FULL_PROFILE |
||||
Platform Extensions cl_intel_accelerator cl_intel_advanced_motion_estimation cl_intel_device_side_avc_motion_estimation cl_intel_driver_diagnostics cl_intel_media_block_io cl_intel_motion_estimation cl_intel_planar_yuv cl_intel_packed_yuv cl_intel_required_subgroup_size cl_intel_subgroups cl_intel_subgroups_short cl_intel_va_api_media_sharing cl_khr_3d_image_writes cl_khr_byte_addressable_store cl_khr_depth_images cl_khr_fp16 cl_khr_fp64 cl_khr_global_int32_base_atomics cl_khr_global_int32_extended_atomics cl_khr_icd cl_khr_image2d_from_buffer cl_khr_local_int32_base_atomics cl_khr_local_int32_extended_atomics cl_khr_mipmap_image cl_khr_mipmap_image_writes cl_khr_spir cl_khr_subgroups |
||||
Platform Extensions function suffix INTEL |
||||
|
||||
Platform Name Intel(R) OpenCL |
||||
Number of devices 1 |
||||
Device Name Intel(R) HD Graphics |
||||
Device Vendor Intel(R) Corporation |
||||
Device Vendor ID 0x8086 |
||||
Device Version OpenCL 2.0 |
||||
Driver Version r4.1.61547 |
||||
Device OpenCL C Version OpenCL C 2.0 |
||||
Device Type GPU |
||||
Device Profile FULL_PROFILE |
||||
Max compute units 72 |
||||
Max clock frequency 950MHz |
||||
Device Partition (core) |
||||
Max number of sub-devices 0 |
||||
Supported partition types by <unknown> (0x7FE000000000) |
||||
Max work item dimensions 3 |
||||
Max work item sizes 256x256x256 |
||||
Max work group size 256 |
||||
Preferred work group size multiple 32 |
||||
Preferred / native vector sizes |
||||
char 16 / 16 |
||||
short 8 / 8 |
||||
int 4 / 4 |
||||
long 1 / 1 |
||||
half 8 / 8 (cl_khr_fp16) |
||||
float 1 / 1 |
||||
double 1 / 1 (cl_khr_fp64) |
||||
Half-precision Floating-point support (cl_khr_fp16) |
||||
Denormals Yes |
||||
Infinity and NANs Yes |
||||
Round to nearest Yes |
||||
Round to zero Yes |
||||
Round to infinity Yes |
||||
IEEE754-2008 fused multiply-add Yes |
||||
Support is emulated in software No |
||||
Correctly-rounded divide and sqrt operations No |
||||
Single-precision Floating-point support (core) |
||||
Denormals Yes |
||||
Infinity and NANs Yes |
||||
Round to nearest Yes |
||||
Round to zero Yes |
||||
Round to infinity Yes |
||||
IEEE754-2008 fused multiply-add Yes |
||||
Support is emulated in software No |
||||
Correctly-rounded divide and sqrt operations Yes |
||||
Double-precision Floating-point support (cl_khr_fp64) |
||||
Denormals Yes |
||||
Infinity and NANs Yes |
||||
Round to nearest Yes |
||||
Round to zero Yes |
||||
Round to infinity Yes |
||||
IEEE754-2008 fused multiply-add Yes |
||||
Support is emulated in software No |
||||
Correctly-rounded divide and sqrt operations No |
||||
Address bits 64, Little-Endian |
||||
Global memory size 26887677543 (25.04GiB) |
||||
Error Correction support No |
||||
Max memory allocation 4294959103 (4GiB) |
||||
Unified memory for Host and Device Yes |
||||
Shared Virtual Memory (SVM) capabilities (core) |
||||
Coarse-grained buffer sharing Yes |
||||
Fine-grained buffer sharing No |
||||
Fine-grained system sharing No |
||||
Atomics No |
||||
Minimum alignment for any data type 128 bytes |
||||
Alignment of base address 1024 bits (128 bytes) |
||||
Preferred alignment for atomics |
||||
SVM 64 bytes |
||||
Global 64 bytes |
||||
Local 64 bytes |
||||
Max size for global variable 65536 (64KiB) |
||||
Preferred total size of global vars 4294959103 (4GiB) |
||||
Global Memory cache type Read/Write |
||||
Global Memory cache size 1572864 |
||||
Global Memory cache line 64 bytes |
||||
Image support Yes |
||||
Max number of samplers per kernel 16 |
||||
Max size for 1D images from buffer 268434943 pixels |
||||
Max 1D or 2D image array size 2048 images |
||||
Base address alignment for 2D image buffers 4 bytes |
||||
Pitch alignment for 2D image buffers 4 bytes |
||||
Max 2D image size 16384x16384 pixels |
||||
Max 3D image size 16384x16384x2048 pixels |
||||
Max number of read image args 128 |
||||
Max number of write image args 128 |
||||
Max number of read/write image args 128 |
||||
Max number of pipe args 16 |
||||
Max active pipe reservations 1 |
||||
Max pipe packet size 1024 |
||||
Local memory type Local |
||||
Local memory size 65536 (64KiB) |
||||
Max constant buffer size 4294959103 (4GiB) |
||||
Max number of constant args 8 |
||||
Max size of kernel argument 1024 |
||||
Queue properties (on host) |
||||
Out-of-order execution Yes |
||||
Profiling Yes |
||||
Queue properties (on device) |
||||
Out-of-order execution Yes |
||||
Profiling Yes |
||||
Preferred size 131072 (128KiB) |
||||
Max size 67108864 (64MiB) |
||||
Max queues on device 1 |
||||
Max events on device 1024 |
||||
Prefer user sync for interop Yes |
||||
Profiling timer resolution 83ns |
||||
Execution capabilities |
||||
Run OpenCL kernels Yes |
||||
Run native kernels No |
||||
SPIR versions 1.2 |
||||
printf() buffer size 4194304 (4MiB) |
||||
Built-in kernels block_motion_estimate_intel;block_advanced_motion_estimate_check_intel;block_advanced_motion_estimate_bidirectional_check_intel |
||||
Motion Estimation accelerator version (Intel) 2 |
||||
Device Available Yes |
||||
Compiler Available Yes |
||||
Linker Available Yes |
||||
Device Extensions cl_intel_accelerator cl_intel_advanced_motion_estimation cl_intel_device_side_avc_motion_estimation cl_intel_driver_diagnostics cl_intel_media_block_io cl_intel_motion_estimation cl_intel_planar_yuv cl_intel_packed_yuv cl_intel_required_subgroup_size cl_intel_subgroups cl_intel_subgroups_short cl_intel_va_api_media_sharing cl_khr_3d_image_writes cl_khr_byte_addressable_store cl_khr_depth_images cl_khr_fp16 cl_khr_fp64 cl_khr_global_int32_base_atomics cl_khr_global_int32_extended_atomics cl_khr_icd cl_khr_image2d_from_buffer cl_khr_local_int32_base_atomics cl_khr_local_int32_extended_atomics cl_khr_mipmap_image cl_khr_mipmap_image_writes cl_khr_spir cl_khr_subgroups |
||||
|
||||
NULL platform behavior |
||||
clGetPlatformInfo(NULL, CL_PLATFORM_NAME, ...) No platform |
||||
clGetDeviceIDs(NULL, CL_DEVICE_TYPE_ALL, ...) No platform |
||||
clCreateContext(NULL, ...) [default] No platform |
||||
clCreateContext(NULL, ...) [other] Success [INTEL] |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_CPU) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_GPU) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_ACCELERATOR) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_CUSTOM) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_ALL) No platform |
||||
********************************************************************************/ |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M128","4 6 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M32","12 2 8 2 1 1 8 1 0 ", |
||||
"EU72_k7x7_cn3_g1_s2x2_d1x1_b1_in224x224_p3x3_num1_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k5x5_cn48_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M128","4 2 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn128_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn64_g1_s1x1_d1x1_b1_in64x64_p1x1_num2_M192","2 7 16 2 1 1 16 1 0 ", |
||||
"EU72_k5x5_cn16_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M48","4 3 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M32","4 6 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M96","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k11x7_cn3_g1_s3x4_d1x1_b1_in64x64_p3x2_num1_M64","4 1 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M64","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M32","4 6 8 2 1 1 8 1 0 ", |
||||
"EU72_k3x3_cn4_g1_s1x1_d1x1_b1_in256x256_p1x1_num1_M4","14 1 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn3_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M4","4 4 8 2 1 1 8 1 0 ", |
||||
"EU72_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M128","4 2 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M192","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M192","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn96_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M208","2 6 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M32","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M384","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn160_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M320","2 5 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M160","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M256","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k5x1_cn32_g1_s1x1_d1x1_b0_in64x64_p2x0_num1_M32","4 6 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn16_g1_s1x1_d1x1_b0_in256x256_p0x0_num1_M4","12 2 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn64_g1_s1x1_d1x1_b1_in64x64_p0x0_num1_M64","2 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M16","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn32_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M128","1 16 32 5 1 16 1 1 0 ", |
||||
"EU72_k3x3_cn32_g1_s1x1_d2x2_b1_in64x64_p2x2_num1_M32","3 6 16 2 1 1 16 1 0 ", |
||||
"EU72_k3x3_cn32_g1_s1x1_d16x16_b1_in64x64_p16x16_num1_M32","1 16 32 5 1 16 1 1 0 ", |
||||
"EU72_k1x1_cn128_g1_s1x1_d1x1_b0_in32x32_p0x0_num1_M512","2 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn192_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M384","2 7 16 2 1 1 16 1 0 ", |
||||
"EU72_k5x4_cn6_g3_s3x2_d1x1_b1_in128x80_p1x0_num2_M4","1 1 1 4 1 1 1 0 1 ", |
||||
"EU72_k5x5_cn32_g1_s1x1_d1x1_b1_in32x32_p2x2_num2_M96","4 5 16 2 1 1 16 1 0 ", |
||||
"EU72_k3x3_cn64_g1_s1x1_d1x1_b1_in64x64_p1x1_num1_M192","10 2 16 2 1 1 16 1 0 ", |
||||
"EU72_k3x3_cn128_g1_s1x1_d1x1_b1_in32x32_p1x1_num1_M192","6 4 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn4_g1_s1x1_d1x1_b0_in256x256_p0x0_num1_M16","2 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M96","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k5x5_cn16_g1_s1x1_d1x1_b1_in32x32_p2x2_num1_M32","8 1 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M384","4 7 8 2 1 1 8 1 0 ", |
||||
"EU72_k3x3_cn128_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M256","2 6 16 2 1 1 16 1 0 ", |
||||
"EU72_k3x3_cn96_g1_s1x1_d1x1_b1_in32x32_p1x1_num1_M128","6 4 16 2 1 1 16 1 0 ", |
||||
"EU72_k5x5_cn24_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M64","4 4 16 2 1 1 16 1 0 ", |
||||
"EU72_k5x5_cn16_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M48","4 3 16 2 1 1 16 1 0 ", |
||||
"EU72_k3x3_cn3_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M5","2 3 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M24","8 2 8 2 1 1 8 1 0 ", |
||||
"EU72_k3x3_cn128_g1_s1x1_d1x1_b0_in32x32_p1x1_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn96_g1_s1x1_d1x1_b1_in32x32_p1x1_num2_M128","2 7 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn128_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M32","1 16 32 5 1 16 1 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M112","8 2 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M160","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M128","4 3 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn64_g1_s1x1_d1x1_b1_in64x64_p0x0_num2_M64","1 16 32 5 1 16 1 1 0 ", |
||||
"EU72_k1x1_cn64_g1_s1x1_d1x1_b0_in128x128_p0x0_num1_M16","2 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M144","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M128","8 2 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn16_g1_s1x1_d1x1_b0_in128x128_p0x0_num1_M64","1 16 32 5 1 16 1 1 0 ", |
||||
"EU72_k3x3_cn112_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M224","2 7 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M256","4 6 8 2 1 1 8 1 0 ", |
||||
"EU72_k5x5_cn32_g1_s1x1_d1x1_b1_in32x32_p2x2_num1_M96","4 3 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s2x2_d1x1_b0_in32x32_p0x0_num1_M256","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn128_g1_s1x1_d1x1_b1_in32x32_p1x1_num2_M192","10 2 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M64","12 2 8 2 1 1 8 1 0 ", |
||||
"EU72_k3x3_cn384_g2_s1x1_d1x1_b1_in16x16_p1x1_num1_M128","2 5 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M48","4 6 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M48","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M256","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn256_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M64","2 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn144_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M288","2 5 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn1024_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M256","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M96","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s2x2_d1x1_b0_in32x32_p0x0_num1_M1024","1 16 32 5 1 16 1 1 0 ", |
||||
"EU72_k1x1_cn2048_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M512","4 6 8 2 1 1 8 1 0 ", |
||||
"EU72_k3x3_cn512_g1_s1x1_d1x1_b0_in16x16_p1x1_num1_M512","2 5 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M16","8 2 8 2 1 1 8 1 0 ", |
||||
"EU72_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M64","4 2 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M128","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k3x3_cn144_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M288","2 7 16 2 1 1 16 1 0 ", |
||||
"EU72_k3x3_cn16_g1_s1x1_d1x1_b1_in128x128_p1x1_num1_M16","2 5 16 2 1 1 16 1 0 ", |
||||
"EU72_k3x3_cn32_g1_s1x1_d8x8_b1_in64x64_p8x8_num1_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn64_g1_s1x1_d1x1_b0_in128x128_p0x0_num1_M4","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k3x3_cn128_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M256","2 7 16 2 1 1 16 1 0 ", |
||||
"EU72_k3x3_cn256_g1_s1x1_d1x1_b0_in16x16_p1x1_num1_M256","2 5 16 2 1 1 16 1 0 ", |
||||
"EU72_k3x3_cn112_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M224","2 5 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k2x2_cn16_g1_s2x2_d1x1_b0_in256x256_p0x0_num1_M16","6 4 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M192","4 6 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn1024_g1_s2x2_d1x1_b0_in16x16_p0x0_num1_M512","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M160","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn384_g2_s1x1_d1x1_b1_in16x16_p1x1_num1_M192","2 5 16 2 1 1 16 1 0 ", |
||||
"EU72_k5x5_cn96_g2_s1x1_d1x1_b1_in32x32_p2x2_num1_M128","4 3 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M32","8 2 8 2 1 1 8 1 0 ", |
||||
"EU72_k2x2_cn64_g1_s2x2_d1x1_b0_in128x128_p0x0_num1_M32","8 3 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn64_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M256","1 16 32 5 1 16 1 1 0 ", |
||||
"EU72_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M32","12 2 8 2 1 1 8 1 0 ", |
||||
"EU72_k5x5_cn16_g1_s1x1_d1x1_b1_in32x32_p2x2_num2_M32","4 2 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M16","12 1 8 2 1 1 8 1 0 ", |
||||
"EU72_k11x11_cn3_g1_s4x4_d1x1_b1_in224x224_p0x0_num1_M96","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M256","4 7 8 2 1 1 8 1 0 ", |
||||
"EU72_k3x3_cn192_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M384","2 5 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M16","12 1 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M160","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn256_g1_s2x2_d1x1_b0_in64x64_p0x0_num1_M512","1 16 32 5 1 16 1 1 0 ", |
||||
"EU72_k1x1_cn128_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M16","2 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M192","4 6 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M96","1 16 32 5 1 16 1 1 0 ", |
||||
"EU72_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M32","12 1 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M64","12 2 8 2 1 1 8 1 0 ", |
||||
"EU72_k3x3_cn256_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M384","2 7 16 2 1 1 16 1 0 ", |
||||
"EU72_k5x5_cn24_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M64","4 2 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M24","12 1 8 2 1 1 8 1 0 ", |
||||
"EU72_k5x5_cn48_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M128","4 2 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M160","4 6 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M144","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn32_g1_s1x1_d4x4_b1_in64x64_p4x4_num1_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn3_g1_s2x2_d1x1_b1_in256x256_p1x1_num1_M13","1 1 1 4 1 1 1 0 1 ", |
||||
"EU72_k3x3_cn32_g1_s1x1_d1x1_b1_in64x64_p1x1_num1_M32","6 4 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn64_g1_s1x1_d1x1_b0_in64x64_p1x1_num1_M64","2 7 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn256_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M1024","2 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k3x3_cn160_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M320","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x5_cn32_g1_s1x1_d1x1_b1_in64x64_p0x2_num1_M32","4 6 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn64_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M64","1 16 32 5 1 16 1 1 0 ", |
||||
"EU72_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M160","4 6 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b0_in32x32_p0x0_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M64","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M64","12 2 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M128","2 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M32","8 3 8 2 1 1 8 1 0 ", |
||||
"EU72_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M112","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k4x4_cn3_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M2","1 3 16 2 1 1 16 1 0 ", |
||||
"EU72_k1x1_cn1024_g1_s2x2_d1x1_b0_in16x16_p0x0_num1_M2048","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn256_g1_s2x2_d1x1_b0_in64x64_p0x0_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k7x7_cn3_g1_s2x2_d1x1_b1_in224x224_p3x3_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k1x1_cn512_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M2048","1 8 32 5 1 8 1 1 0 ", |
||||
"EU72_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M64","8 1 16 2 1 1 16 1 0 ", |
||||
"EU72_k3x3_cn96_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M208","2 7 16 2 1 1 16 1 0 ", |
||||
// Below is the information for the OpenCL device on which these configurations were tuned
/*******************************************************************************
Number of platforms 1 |
||||
Platform Name Intel(R) OpenCL |
||||
Platform Vendor Intel(R) Corporation |
||||
Platform Version OpenCL 2.0 |
||||
Platform Profile FULL_PROFILE |
||||
Platform Extensions cl_intel_accelerator cl_intel_advanced_motion_estimation cl_intel_driver_diagnostics cl_intel_motion_estimation cl_intel_packed_yuv cl_intel_required_subgroup_size cl_intel_subgroups cl_intel_subgroups_short cl_intel_va_api_media_sharing cl_khr_3d_image_writes cl_khr_byte_addressable_store cl_khr_depth_images cl_khr_fp16 cl_khr_fp64 cl_khr_global_int32_base_atomics cl_khr_global_int32_extended_atomics cl_khr_icd cl_khr_image2d_from_buffer cl_khr_local_int32_base_atomics cl_khr_local_int32_extended_atomics cl_khr_mipmap_image cl_khr_mipmap_image_writes cl_khr_spir cl_khr_subgroups |
||||
Platform Extensions function suffix INTEL |
||||
|
||||
Platform Name Intel(R) OpenCL |
||||
Number of devices 1 |
||||
Device Name Intel(R) HD Graphics |
||||
Device Vendor Intel(R) Corporation |
||||
Device Vendor ID 0x8086 |
||||
Device Version OpenCL 2.0 |
||||
Driver Version 16.5.56875 |
||||
Device OpenCL C Version OpenCL C 2.0 ( using IGC ) |
||||
Device Type GPU |
||||
Device Profile FULL_PROFILE |
||||
Max compute units 48 |
||||
Max clock frequency 950MHz |
||||
Device Partition (core) |
||||
Max number of sub-devices 0 |
||||
Supported partition types by <unknown> (0x7F4B00000000) |
||||
Max work item dimensions 3 |
||||
Max work item sizes 256x256x256 |
||||
Max work group size 256 |
||||
Preferred work group size multiple 32 |
||||
Preferred / native vector sizes |
||||
char 16 / 16 |
||||
short 8 / 8 |
||||
int 4 / 4 |
||||
long 1 / 1 |
||||
half 8 / 8 (cl_khr_fp16) |
||||
float 1 / 1 |
||||
double 1 / 1 (cl_khr_fp64) |
||||
Half-precision Floating-point support (cl_khr_fp16) |
||||
Denormals Yes |
||||
Infinity and NANs Yes |
||||
Round to nearest Yes |
||||
Round to zero Yes |
||||
Round to infinity Yes |
||||
IEEE754-2008 fused multiply-add Yes |
||||
Support is emulated in software No |
||||
Correctly-rounded divide and sqrt operations No |
||||
Single-precision Floating-point support (core) |
||||
Denormals Yes |
||||
Infinity and NANs Yes |
||||
Round to nearest Yes |
||||
Round to zero Yes |
||||
Round to infinity Yes |
||||
IEEE754-2008 fused multiply-add Yes |
||||
Support is emulated in software No |
||||
Correctly-rounded divide and sqrt operations Yes |
||||
Double-precision Floating-point support (cl_khr_fp64) |
||||
Denormals Yes |
||||
Infinity and NANs Yes |
||||
Round to nearest Yes |
||||
Round to zero Yes |
||||
Round to infinity Yes |
||||
IEEE754-2008 fused multiply-add Yes |
||||
Support is emulated in software No |
||||
Correctly-rounded divide and sqrt operations No |
||||
Address bits 64, Little-Endian |
||||
Global memory size 13361912218 (12.44GiB) |
||||
Error Correction support No |
||||
Max memory allocation 4294959103 (4GiB) |
||||
Unified memory for Host and Device Yes |
||||
Shared Virtual Memory (SVM) capabilities (core) |
||||
Coarse-grained buffer sharing Yes |
||||
Fine-grained buffer sharing No |
||||
Fine-grained system sharing No |
||||
Atomics No |
||||
Minimum alignment for any data type 128 bytes |
||||
Alignment of base address 1024 bits (128 bytes) |
||||
Preferred alignment for atomics |
||||
SVM 64 bytes |
||||
Global 64 bytes |
||||
Local 64 bytes |
||||
Max size for global variable 65536 (64KiB) |
||||
Preferred total size of global vars 4294959103 (4GiB) |
||||
Global Memory cache type Read/Write |
||||
Global Memory cache size 1048576 |
||||
Global Memory cache line 64 bytes |
||||
Image support Yes |
||||
Max number of samplers per kernel 16 |
||||
Max size for 1D images from buffer 268434943 pixels |
||||
Max 1D or 2D image array size 2048 images |
||||
Base address alignment for 2D image buffers 4 bytes |
||||
Pitch alignment for 2D image buffers 4 bytes |
||||
Max 2D image size 16384x16384 pixels |
||||
Max 3D image size 16384x16384x2048 pixels |
||||
Max number of read image args 128 |
||||
Max number of write image args 128 |
||||
Max number of read/write image args 128 |
||||
Max number of pipe args 16 |
||||
Max active pipe reservations 1 |
||||
Max pipe packet size 1024 |
||||
Local memory type Local |
||||
Local memory size 65536 (64KiB) |
||||
Max constant buffer size 4294959103 (4GiB) |
||||
Max number of constant args 8 |
||||
Max size of kernel argument 1024 |
||||
Queue properties (on host) |
||||
Out-of-order execution Yes |
||||
Profiling Yes |
||||
Queue properties (on device) |
||||
Out-of-order execution Yes |
||||
Profiling Yes |
||||
Preferred size 131072 (128KiB) |
||||
Max size 67108864 (64MiB) |
||||
Max queues on device 1 |
||||
Max events on device 1024 |
||||
Prefer user sync for interop Yes |
||||
Profiling timer resolution 83ns |
||||
Execution capabilities |
||||
Run OpenCL kernels Yes |
||||
Run native kernels No |
||||
SPIR versions 1.2 |
||||
printf() buffer size 4194304 (4MiB) |
||||
Built-in kernels block_motion_estimate_intel;block_advanced_motion_estimate_check_intel;block_advanced_motion_estimate_bidirectional_check_intel |
||||
Motion Estimation accelerator version (Intel) 2 |
||||
Device Available Yes |
||||
Compiler Available Yes |
||||
Linker Available Yes |
||||
Device Extensions cl_intel_accelerator cl_intel_advanced_motion_estimation cl_intel_driver_diagnostics cl_intel_motion_estimation cl_intel_packed_yuv cl_intel_required_subgroup_size cl_intel_subgroups cl_intel_subgroups_short cl_intel_va_api_media_sharing cl_khr_3d_image_writes cl_khr_byte_addressable_store cl_khr_depth_images cl_khr_fp16 cl_khr_fp64 cl_khr_global_int32_base_atomics cl_khr_global_int32_extended_atomics cl_khr_icd cl_khr_image2d_from_buffer cl_khr_local_int32_base_atomics cl_khr_local_int32_extended_atomics cl_khr_mipmap_image cl_khr_mipmap_image_writes cl_khr_spir cl_khr_subgroups |
||||
|
||||
NULL platform behavior |
||||
clGetPlatformInfo(NULL, CL_PLATFORM_NAME, ...) No platform |
||||
clGetDeviceIDs(NULL, CL_DEVICE_TYPE_ALL, ...) No platform |
||||
clCreateContext(NULL, ...) [default] No platform |
||||
clCreateContext(NULL, ...) [other] Success [INTEL] |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_CPU) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_GPU) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_ACCELERATOR) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_CUSTOM) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_ALL) No platform |
||||
********************************************************************************/ |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M32","8 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M64","8 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn32_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M128","1 16 32 5 1 16 1 1 0 ", |
||||
"EU48_k5x5_cn16_g1_s1x1_d1x1_b1_in32x32_p2x2_num1_M32","8 1 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M144","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M96","1 16 32 5 1 16 1 1 0 ", |
||||
"EU48_k3x3_cn128_g1_s1x1_d1x1_b0_in32x32_p1x1_num1_M128","6 4 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M128","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M64","8 1 16 2 1 1 16 1 0 ", |
||||
"EU48_k2x2_cn16_g1_s2x2_d1x1_b0_in256x256_p0x0_num1_M16","2 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn4_g1_s1x1_d1x1_b1_in256x256_p1x1_num1_M4","6 4 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn128_g1_s1x1_d1x1_b0_in32x32_p0x0_num1_M512","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M112","8 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k3x3_cn512_g1_s1x1_d1x1_b0_in16x16_p1x1_num1_M512","2 7 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M64","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M384","4 6 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M16","8 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M96","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn256_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M1024","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M192","4 7 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn128_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k3x3_cn160_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M320","2 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k7x7_cn3_g1_s2x2_d1x1_b1_in224x224_p3x3_num1_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k5x5_cn16_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M48","4 2 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M256","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k3x3_cn64_g1_s1x1_d1x1_b1_in64x64_p1x1_num1_M192","2 8 16 2 1 1 16 1 0 ", |
||||
"EU48_k11x11_cn3_g1_s4x4_d1x1_b1_in224x224_p0x0_num1_M96","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M112","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M32","12 1 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s2x2_d1x1_b0_in32x32_p0x0_num1_M256","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M128","12 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M64","8 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k3x3_cn128_g1_s1x1_d1x1_b1_in32x32_p1x1_num2_M192","2 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn128_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M256","2 5 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn16_g1_s1x1_d1x1_b0_in256x256_p0x0_num1_M4","8 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x5_cn32_g1_s1x1_d1x1_b1_in64x64_p0x2_num1_M32","4 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M256","4 7 8 2 1 1 8 1 0 ", |
||||
"EU48_k3x3_cn3_g1_s2x2_d1x1_b1_in256x256_p1x1_num1_M13","1 1 1 4 1 1 1 0 1 ", |
||||
"EU48_k11x7_cn3_g1_s3x4_d1x1_b1_in64x64_p3x2_num1_M64","4 1 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M96","8 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn128_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M16","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k3x3_cn32_g1_s1x1_d2x2_b1_in64x64_p2x2_num1_M32","3 3 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn32_g1_s1x1_d8x8_b1_in64x64_p8x8_num1_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M96","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k2x2_cn64_g1_s2x2_d1x1_b0_in128x128_p0x0_num1_M32","4 4 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M128","4 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b0_in32x32_p0x0_num1_M128","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k3x3_cn16_g1_s1x1_d1x1_b1_in128x128_p1x1_num1_M16","2 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn4_g1_s1x1_d1x1_b0_in256x256_p0x0_num1_M16","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k3x3_cn384_g2_s1x1_d1x1_b1_in16x16_p1x1_num1_M128","6 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn3_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M4","4 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M144","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M160","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M384","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn256_g1_s2x2_d1x1_b0_in64x64_p0x0_num1_M128","1 16 32 5 1 16 1 1 0 ", |
||||
"EU48_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M192","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn1024_g1_s2x2_d1x1_b0_in16x16_p0x0_num1_M2048","1 16 32 5 1 16 1 1 0 ", |
||||
"EU48_k3x3_cn192_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M384","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn16_g1_s1x1_d1x1_b0_in128x128_p0x0_num1_M64","1 16 32 5 1 16 1 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M32","4 7 8 2 1 1 8 1 0 ", |
||||
"EU48_k3x3_cn384_g2_s1x1_d1x1_b1_in16x16_p1x1_num1_M192","2 5 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn96_g1_s1x1_d1x1_b1_in32x32_p1x1_num1_M128","6 4 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M32","8 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M64","12 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M64","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn2048_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M512","4 7 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M64","12 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k3x3_cn112_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M224","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k3x3_cn256_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M384","2 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn32_g1_s1x1_d4x4_b1_in64x64_p4x4_num1_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M256","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k3x3_cn192_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M384","2 4 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn144_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M288","2 4 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M48","4 6 8 2 1 1 8 1 0 ", |
||||
"EU48_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M64","8 1 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M160","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M160","12 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn256_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M64","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M192","4 6 8 2 1 1 8 1 0 ", |
||||
"EU48_k5x5_cn96_g2_s1x1_d1x1_b1_in32x32_p2x2_num1_M128","4 5 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn256_g1_s1x1_d1x1_b0_in16x16_p1x1_num1_M256","2 6 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M32","8 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k5x5_cn16_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M48","4 2 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn64_g1_s1x1_d1x1_b0_in64x64_p1x1_num1_M64","10 2 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M160","4 6 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M32","4 5 8 2 1 1 8 1 0 ", |
||||
"EU48_k3x3_cn96_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M208","2 5 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M256","4 6 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M2048","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M48","4 6 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn64_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M256","1 16 32 5 1 16 1 1 0 ", |
||||
"EU48_k3x3_cn112_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M224","2 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k5x1_cn32_g1_s1x1_d1x1_b0_in64x64_p2x0_num1_M32","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn64_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k3x3_cn144_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M288","2 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn128_g1_s1x1_d1x1_b1_in32x32_p1x1_num1_M192","2 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k5x5_cn16_g1_s1x1_d1x1_b1_in32x32_p2x2_num2_M32","4 3 16 2 1 1 16 1 0 ", |
||||
"EU48_k5x5_cn32_g1_s1x1_d1x1_b1_in32x32_p2x2_num2_M96","4 2 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn96_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M208","2 5 16 2 1 1 16 1 0 ", |
||||
"EU48_k5x5_cn32_g1_s1x1_d1x1_b1_in32x32_p2x2_num1_M96","4 2 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M24","12 1 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn64_g1_s1x1_d1x1_b0_in128x128_p0x0_num1_M16","4 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn256_g1_s2x2_d1x1_b0_in64x64_p0x0_num1_M512","2 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn1024_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M256","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k3x3_cn160_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M320","2 8 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn64_g1_s1x1_d1x1_b1_in64x64_p1x1_num2_M192","6 4 16 2 1 1 16 1 0 ", |
||||
"EU48_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M128","4 3 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M160","4 6 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k3x3_cn3_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M5","2 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M192","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M32","8 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn64_g1_s1x1_d1x1_b1_in64x64_p0x0_num2_M64","1 16 32 5 1 16 1 1 0 ", |
||||
"EU48_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M128","8 2 16 2 1 1 16 1 0 ", |
||||
"EU48_k7x7_cn3_g1_s2x2_d1x1_b1_in224x224_p3x3_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M128","4 6 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn128_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M32","1 16 32 5 1 16 1 1 0 ", |
||||
"EU48_k1x1_cn1024_g1_s2x2_d1x1_b0_in16x16_p0x0_num1_M512","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k5x5_cn24_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M64","4 4 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M32","8 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M160","12 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M64","8 3 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s2x2_d1x1_b0_in32x32_p0x0_num1_M1024","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k5x4_cn6_g3_s3x2_d1x1_b1_in128x80_p1x0_num2_M4","1 1 1 4 1 1 1 0 1 ", |
||||
"EU48_k3x3_cn128_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M256","2 7 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M24","8 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M16","12 1 8 2 1 1 8 1 0 ", |
||||
"EU48_k3x3_cn96_g1_s1x1_d1x1_b1_in32x32_p1x1_num2_M128","10 2 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn64_g1_s1x1_d1x1_b1_in64x64_p0x0_num1_M64","1 16 32 5 1 16 1 1 0 ", |
||||
"EU48_k3x3_cn32_g1_s1x1_d16x16_b1_in64x64_p16x16_num1_M32","1 16 32 5 1 16 1 1 0 ", |
||||
"EU48_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M128","4 7 8 2 1 1 8 1 0 ", |
||||
"EU48_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M16","12 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k4x4_cn3_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M2","1 4 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn64_g1_s1x1_d1x1_b0_in128x128_p0x0_num1_M4","8 2 8 2 1 1 8 1 0 ", |
||||
"EU48_k5x5_cn24_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M64","4 2 16 2 1 1 16 1 0 ", |
||||
"EU48_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M16","1 8 32 5 1 8 1 1 0 ", |
||||
"EU48_k5x5_cn48_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M128","4 2 16 2 1 1 16 1 0 ", |
||||
"EU48_k3x3_cn32_g1_s1x1_d1x1_b1_in64x64_p1x1_num1_M32","2 8 16 2 1 1 16 1 0 ", |
||||
"EU48_k5x5_cn48_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M128","4 2 16 2 1 1 16 1 0 ", |
||||
// Below is the OpenCL platform/device information on which these configurations were tuned
|
||||
/*******************************************************************************
|
||||
Number of platforms 1 |
||||
Platform Name Intel(R) OpenCL |
||||
Platform Vendor Intel(R) Corporation |
||||
Platform Version OpenCL 2.0 |
||||
Platform Profile FULL_PROFILE |
||||
Platform Extensions cl_intel_accelerator cl_intel_advanced_motion_estimation cl_intel_device_side_avc_motion_estimation cl_intel_driver_diagnostics cl_intel_media_block_io cl_intel_motion_estimation cl_intel_planar_yuv cl_intel_packed_yuv cl_intel_required_subgroup_size cl_intel_subgroups cl_intel_subgroups_short cl_intel_va_api_media_sharing cl_khr_3d_image_writes cl_khr_byte_addressable_store cl_khr_depth_images cl_khr_fp16 cl_khr_fp64 cl_khr_global_int32_base_atomics cl_khr_global_int32_extended_atomics cl_khr_icd cl_khr_image2d_from_buffer cl_khr_local_int32_base_atomics cl_khr_local_int32_extended_atomics cl_khr_mipmap_image cl_khr_mipmap_image_writes cl_khr_spir cl_khr_subgroups |
||||
Platform Extensions function suffix INTEL |
||||
|
||||
Platform Name Intel(R) OpenCL |
||||
Number of devices 1 |
||||
Device Name Intel(R) HD Graphics |
||||
Device Vendor Intel(R) Corporation |
||||
Device Vendor ID 0x8086 |
||||
Device Version OpenCL 2.0 |
||||
Driver Version 16.5.59288 |
||||
Device OpenCL C Version OpenCL C 2.0 |
||||
Device Type GPU |
||||
Device Profile FULL_PROFILE |
||||
Max compute units 24 |
||||
Max clock frequency 1050MHz |
||||
Device Partition (core) |
||||
Max number of sub-devices 0 |
||||
Supported partition types by <unknown> (0x7F5100000000) |
||||
Max work item dimensions 3 |
||||
Max work item sizes 256x256x256 |
||||
Max work group size 256 |
||||
Preferred work group size multiple 32 |
||||
Preferred / native vector sizes |
||||
char 16 / 16 |
||||
short 8 / 8 |
||||
int 4 / 4 |
||||
long 1 / 1 |
||||
half 8 / 8 (cl_khr_fp16) |
||||
float 1 / 1 |
||||
double 1 / 1 (cl_khr_fp64) |
||||
Half-precision Floating-point support (cl_khr_fp16) |
||||
Denormals Yes |
||||
Infinity and NANs Yes |
||||
Round to nearest Yes |
||||
Round to zero Yes |
||||
Round to infinity Yes |
||||
IEEE754-2008 fused multiply-add Yes |
||||
Support is emulated in software No |
||||
Correctly-rounded divide and sqrt operations No |
||||
Single-precision Floating-point support (core) |
||||
Denormals Yes |
||||
Infinity and NANs Yes |
||||
Round to nearest Yes |
||||
Round to zero Yes |
||||
Round to infinity Yes |
||||
IEEE754-2008 fused multiply-add Yes |
||||
Support is emulated in software No |
||||
Correctly-rounded divide and sqrt operations Yes |
||||
Double-precision Floating-point support (cl_khr_fp64) |
||||
Denormals Yes |
||||
Infinity and NANs Yes |
||||
Round to nearest Yes |
||||
Round to zero Yes |
||||
Round to infinity Yes |
||||
IEEE754-2008 fused multiply-add Yes |
||||
Support is emulated in software No |
||||
Correctly-rounded divide and sqrt operations No |
||||
Address bits 64, Little-Endian |
||||
Global memory size 6588802663 (6.136GiB) |
||||
Error Correction support No |
||||
Max memory allocation 3294401331 (3.068GiB) |
||||
Unified memory for Host and Device Yes |
||||
Shared Virtual Memory (SVM) capabilities (core) |
||||
Coarse-grained buffer sharing Yes |
||||
Fine-grained buffer sharing No |
||||
Fine-grained system sharing No |
||||
Atomics No |
||||
Minimum alignment for any data type 128 bytes |
||||
Alignment of base address 1024 bits (128 bytes) |
||||
Preferred alignment for atomics |
||||
SVM 64 bytes |
||||
Global 64 bytes |
||||
Local 64 bytes |
||||
Max size for global variable 65536 (64KiB) |
||||
Preferred total size of global vars 3294401331 (3.068GiB) |
||||
Global Memory cache type Read/Write |
||||
Global Memory cache size 524288 |
||||
Global Memory cache line 64 bytes |
||||
Image support Yes |
||||
Max number of samplers per kernel 16 |
||||
Max size for 1D images from buffer 205900083 pixels |
||||
Max 1D or 2D image array size 2048 images |
||||
Base address alignment for 2D image buffers 4 bytes |
||||
Pitch alignment for 2D image buffers 4 bytes |
||||
Max 2D image size 16384x16384 pixels |
||||
Max 3D image size 16384x16384x2048 pixels |
||||
Max number of read image args 128 |
||||
Max number of write image args 128 |
||||
Max number of read/write image args 128 |
||||
Max number of pipe args 16 |
||||
Max active pipe reservations 1 |
||||
Max pipe packet size 1024 |
||||
Local memory type Local |
||||
Local memory size 65536 (64KiB) |
||||
Max constant buffer size 3294401331 (3.068GiB) |
||||
Max number of constant args 8 |
||||
Max size of kernel argument 1024 |
||||
Queue properties (on host) |
||||
Out-of-order execution Yes |
||||
Profiling Yes |
||||
Queue properties (on device) |
||||
Out-of-order execution Yes |
||||
Profiling Yes |
||||
Preferred size 131072 (128KiB) |
||||
Max size 67108864 (64MiB) |
||||
Max queues on device 1 |
||||
Max events on device 1024 |
||||
Prefer user sync for interop Yes |
||||
Profiling timer resolution 83ns |
||||
Execution capabilities |
||||
Run OpenCL kernels Yes |
||||
Run native kernels No |
||||
SPIR versions 1.2 |
||||
printf() buffer size 4194304 (4MiB) |
||||
Built-in kernels block_motion_estimate_intel;block_advanced_motion_estimate_check_intel;block_advanced_motion_estimate_bidirectional_check_intel |
||||
Motion Estimation accelerator version (Intel) 2 |
||||
Device Available Yes |
||||
Compiler Available Yes |
||||
Linker Available Yes |
||||
Device Extensions cl_intel_accelerator cl_intel_advanced_motion_estimation cl_intel_device_side_avc_motion_estimation cl_intel_driver_diagnostics cl_intel_media_block_io cl_intel_motion_estimation cl_intel_planar_yuv cl_intel_packed_yuv cl_intel_required_subgroup_size cl_intel_subgroups cl_intel_subgroups_short cl_intel_va_api_media_sharing cl_khr_3d_image_writes cl_khr_byte_addressable_store cl_khr_depth_images cl_khr_fp16 cl_khr_fp64 cl_khr_global_int32_base_atomics cl_khr_global_int32_extended_atomics cl_khr_icd cl_khr_image2d_from_buffer cl_khr_local_int32_base_atomics cl_khr_local_int32_extended_atomics cl_khr_mipmap_image cl_khr_mipmap_image_writes cl_khr_spir cl_khr_subgroups |
||||
|
||||
NULL platform behavior |
||||
clGetPlatformInfo(NULL, CL_PLATFORM_NAME, ...) No platform |
||||
clGetDeviceIDs(NULL, CL_DEVICE_TYPE_ALL, ...) No platform |
||||
clCreateContext(NULL, ...) [default] No platform |
||||
clCreateContext(NULL, ...) [other] Success [INTEL] |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_CPU) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_GPU) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_ACCELERATOR) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_CUSTOM) No platform |
||||
clCreateContextFromType(NULL, CL_DEVICE_TYPE_ALL) No platform |
||||
********************************************************************************/ |
||||
"EU24_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M64","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k5x1_cn32_g1_s1x1_d1x1_b0_in64x64_p2x0_num1_M32","4 6 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k5x5_cn48_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M128","4 2 16 2 1 1 16 1 0 ", |
||||
"EU24_k3x3_cn112_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M224","2 5 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k2x2_cn16_g1_s2x2_d1x1_b0_in256x256_p0x0_num1_M16","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M128","4 3 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn256_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M256","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn192_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M384","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn256_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M384","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn2048_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M512","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M16","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn384_g2_s1x1_d1x1_b1_in16x16_p1x1_num1_M128","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M192","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn112_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M224","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k3x3_cn32_g1_s1x1_d8x8_b1_in64x64_p8x8_num1_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn96_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M208","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k11x11_cn3_g1_s4x4_d1x1_b1_in224x224_p0x0_num1_M96","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M160","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k7x7_cn3_g1_s2x2_d1x1_b1_in224x224_p3x3_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn32_g1_s1x1_d2x2_b1_in64x64_p2x2_num1_M32","3 3 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn128_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M24","8 3 8 2 1 1 8 1 0 ", |
||||
"EU24_k3x3_cn128_g1_s1x1_d1x1_b0_in32x32_p1x1_num1_M128","6 4 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M144","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn1024_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M256","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn96_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M208","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M128","4 3 16 2 1 1 16 1 0 ", |
||||
"EU24_k5x5_cn16_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M48","4 2 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M2048","4 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn128_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M16","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn64_g1_s1x1_d1x1_b1_in64x64_p1x1_num1_M192","6 4 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn256_g1_s1x1_d1x1_b0_in16x16_p0x0_num1_M1024","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn32_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M128","1 16 32 5 1 16 1 1 0 ", |
||||
"EU24_k1x1_cn4_g1_s1x1_d1x1_b0_in256x256_p0x0_num1_M16","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn192_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M384","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M256","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn128_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M32","4 6 8 2 1 1 8 1 0 ", |
||||
"EU24_k5x5_cn48_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M128","4 4 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M128","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M32","8 2 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn1024_g1_s2x2_d1x1_b0_in16x16_p0x0_num1_M2048","1 16 32 5 1 16 1 1 0 ", |
||||
"EU24_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M64","4 3 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M384","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k5x4_cn6_g3_s3x2_d1x1_b1_in128x80_p1x0_num2_M4","1 1 1 4 1 1 1 0 1 ", |
||||
"EU24_k3x3_cn128_g1_s1x1_d1x1_b1_in32x32_p1x1_num2_M192","6 4 16 2 1 1 16 1 0 ", |
||||
"EU24_k3x3_cn256_g1_s1x1_d1x1_b0_in16x16_p1x1_num1_M256","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k3x3_cn160_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M320","2 8 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M160","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s2x2_d1x1_b0_in32x32_p0x0_num1_M256","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn64_g1_s1x1_d1x1_b1_in64x64_p0x0_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M192","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M256","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn128_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M256","2 5 16 2 1 1 16 1 0 ", |
||||
"EU24_k5x5_cn24_g1_s1x1_d1x1_b1_in16x16_p2x2_num2_M64","4 3 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M16","8 3 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M128","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M112","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn64_g1_s1x1_d1x1_b0_in128x128_p0x0_num1_M16","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M96","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn64_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M256","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k5x5_cn16_g1_s1x1_d1x1_b1_in32x32_p2x2_num2_M32","4 2 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M96","8 3 8 2 1 1 8 1 0 ", |
||||
"EU24_k3x3_cn16_g1_s1x1_d1x1_b1_in128x128_p1x1_num1_M16","6 3 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M112","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k5x5_cn32_g1_s1x1_d1x1_b1_in32x32_p2x2_num2_M96","4 3 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M32","8 2 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M192","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M160","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn144_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M288","2 8 16 2 1 1 16 1 0 ", |
||||
"EU24_k3x3_cn144_g1_s1x1_d1x1_b1_in16x16_p1x1_num1_M288","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k7x7_cn3_g1_s2x2_d1x1_b1_in224x224_p3x3_num1_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b0_in32x32_p0x0_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn4_g1_s1x1_d1x1_b1_in256x256_p1x1_num1_M4","10 2 8 2 1 1 8 1 0 ", |
||||
"EU24_k3x3_cn32_g1_s1x1_d16x16_b1_in64x64_p16x16_num1_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M16","8 2 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M64","1 16 32 5 1 16 1 1 0 ", |
||||
"EU24_k1x5_cn32_g1_s1x1_d1x1_b1_in64x64_p0x2_num1_M32","4 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M32","8 3 8 2 1 1 8 1 0 ", |
||||
"EU24_k3x3_cn384_g2_s1x1_d1x1_b1_in16x16_p1x1_num1_M192","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M32","4 6 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M128","4 6 8 2 1 1 8 1 0 ", |
||||
"EU24_k3x3_cn32_g1_s1x1_d4x4_b1_in64x64_p4x4_num1_M32","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k2x2_cn64_g1_s2x2_d1x1_b0_in128x128_p0x0_num1_M32","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k5x5_cn96_g2_s1x1_d1x1_b1_in32x32_p2x2_num1_M128","4 3 16 2 1 1 16 1 0 ", |
||||
"EU24_k5x5_cn16_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M48","8 1 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn16_g1_s1x1_d1x1_b0_in256x256_p0x0_num1_M4","8 3 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M256","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M144","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn96_g1_s1x1_d1x1_b1_in32x32_p1x1_num1_M128","6 4 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn256_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M32","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn128_g1_s1x1_d1x1_b1_in32x32_p1x1_num1_M192","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k5x5_cn32_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M64","4 2 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M160","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k5x5_cn32_g1_s1x1_d1x1_b1_in32x32_p2x2_num1_M96","4 4 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M160","4 6 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M32","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn32_g1_s1x1_d1x1_b1_in64x64_p1x1_num1_M32","2 8 16 2 1 1 16 1 0 ", |
||||
"EU24_k3x3_cn96_g1_s1x1_d1x1_b1_in32x32_p1x1_num2_M128","10 2 16 2 1 1 16 1 0 ", |
||||
"EU24_k3x3_cn160_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M320","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M32","8 3 8 2 1 1 8 1 0 ", |
||||
"EU24_k3x3_cn64_g1_s1x1_d1x1_b0_in64x64_p1x1_num1_M64","2 8 16 2 1 1 16 1 0 ", |
||||
"EU24_k3x3_cn3_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M5","2 3 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn16_g1_s1x1_d1x1_b0_in128x128_p0x0_num1_M64","1 16 32 5 1 16 1 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M48","4 6 8 2 1 1 8 1 0 ", |
||||
"EU24_k5x5_cn24_g1_s1x1_d1x1_b1_in16x16_p2x2_num1_M64","4 2 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn64_g1_s1x1_d1x1_b0_in128x128_p0x0_num1_M4","8 2 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M64","8 2 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M96","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn64_g1_s1x1_d1x1_b0_in64x64_p0x0_num1_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M192","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M48","4 6 8 2 1 1 8 1 0 ", |
||||
"EU24_k3x3_cn128_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M256","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn3_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M4","4 4 16 2 1 1 16 1 0 ", |
||||
"EU24_k4x4_cn3_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M2","1 3 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M96","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k3x3_cn512_g1_s1x1_d1x1_b0_in16x16_p1x1_num1_M512","2 7 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn256_g1_s2x2_d1x1_b0_in64x64_p0x0_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s2x2_d1x1_b0_in32x32_p0x0_num1_M1024","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k11x7_cn3_g1_s3x4_d1x1_b1_in64x64_p3x2_num1_M64","4 1 16 2 1 1 16 1 0 ", |
||||
"EU24_k3x3_cn64_g1_s1x1_d1x1_b1_in64x64_p1x1_num2_M192","6 4 16 2 1 1 16 1 0 ", |
||||
"EU24_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M64","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn64_g1_s1x1_d1x1_b1_in64x64_p0x0_num1_M64","1 16 32 5 1 16 1 1 0 ", |
||||
"EU24_k1x1_cn192_g1_s1x1_d1x1_b1_in32x32_p0x0_num1_M16","8 3 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn128_g1_s1x1_d1x1_b0_in32x32_p0x0_num1_M512","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn1024_g1_s2x2_d1x1_b0_in16x16_p0x0_num1_M512","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M128","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn832_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M384","4 7 8 2 1 1 8 1 0 ", |
||||
"EU24_k1x1_cn528_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M160","1 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn480_g1_s1x1_d1x1_b1_in16x16_p0x0_num1_M64","8 3 8 2 1 1 8 1 0 ", |
||||
"EU24_k3x3_cn3_g1_s2x2_d1x1_b1_in256x256_p1x1_num1_M13","1 1 1 4 1 1 1 0 1 ", |
||||
"EU24_k1x1_cn256_g1_s2x2_d1x1_b0_in64x64_p0x0_num1_M512","2 8 32 5 1 8 1 1 0 ", |
||||
"EU24_k1x1_cn512_g1_s1x1_d1x1_b1_in16x16_p0x0_num2_M24","8 3 8 2 1 1 8 1 0 ", |
||||
"EU24_k5x5_cn16_g1_s1x1_d1x1_b1_in32x32_p2x2_num1_M32","4 3 16 2 1 1 16 1 0 ", |
||||
}; |
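// Each entry above pairs a key describing a convolution shape with the tuned
// kernel configuration cached for it. Reading the key from left to right, it
// appears to encode the EU count, kernel size (k), input channels (cn), group
// (g), stride (s), dilation (d), bias flag (b), input spatial size (in), pad
// (p), batch size (num) and output channels (M), e.g.
// "EU48_k3x3_cn144_g1_s1x1_d1x1_b1_in16x16_p1x1_num2_M288". The nine numbers
// in the value string appear to match the arguments of the
// setupKernelByConfig(x, y, z, type, lx, ly, lz, swizzle, nullLocal)
// declaration further below: work-item block sizes, kernel type, local sizes
// and two flags. This is an inferred reading of the format, not a documented
// specification.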
||||
#endif |
@ -0,0 +1,90 @@ |
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved.
|
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#ifndef _OPENCV_GREENTEA_MATH_FUNCTIONS_HPP_
#define _OPENCV_GREENTEA_MATH_FUNCTIONS_HPP_
#include "../../precomp.hpp"
#include "common.hpp"

namespace cv
{
namespace dnn
{
namespace ocl4dnn
{

#ifdef HAVE_OPENCL
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};

template<typename Dtype>
bool ocl4dnnGEMMCommon(const CBLAS_TRANSPOSE TransB,
                       const int32_t M, const int32_t N, const int32_t K,
                       const UMat A, const UMat B,
                       const UMat B_image, UMat C,
                       const size_t max_image_size);

template<typename Dtype>
ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset,
                                          bool is_matrix_a, bool transpose,
                                          bool padding, int padded_height,
                                          int padded_width, int height,
                                          int width, int ld);

template<typename Dtype>
bool ocl4dnnGEMV(const CBLAS_TRANSPOSE TransA,
                 const int32_t M, const int32_t N, const Dtype alpha,
                 const UMat A, const int32_t offA, const UMat x,
                 const int32_t offx, const Dtype beta, UMat y,
                 const int32_t offy);

template<typename Dtype>
bool ocl4dnnAXPY(const int32_t N, const Dtype alpha,
                 const UMat x, const int32_t offx, UMat y,
                 const int32_t offy);

#endif // HAVE_OPENCL

} // namespace ocl4dnn
} // namespace dnn
} // namespace cv

#endif
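// A minimal usage sketch for the declarations above (an illustration under
// assumed shapes, not code from this patch): given CV_32F matrices A (MxK),
// B (KxN) and C (MxN) already held in UMats, an image-based GEMM could be
// issued as
//
//     size_t maxImage = ocl::Device::getDefault().image2DMaxWidth();
//     ocl4dnnGEMMCommon<float>(CblasNoTrans, M, N, K, A, B, UMat(), C, maxImage);
//
// and a matrix-vector product y = alpha*A*x + beta*y as
//
//     ocl4dnnGEMV<float>(CblasNoTrans, M, N, 1.f, A, 0, x, 0, 0.f, y, 0);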
@ -0,0 +1,473 @@ |
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved.
|
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#ifndef _OPENCV_LIBDNN_HPP_ |
||||
#define _OPENCV_LIBDNN_HPP_ |
||||
#include "../../precomp.hpp" |
||||
#include <iomanip> |
||||
#include <map> |
||||
#include <memory> |
||||
#include <string> |
||||
#include <vector> |
||||
#include "common.hpp" |
||||
|
||||
namespace cv { namespace dnn { namespace ocl4dnn { |
||||
#ifdef HAVE_OPENCL |
||||
|
||||
struct OCL4DNNConvConfig
{
    OCL4DNNConvConfig() :
        kernel(1, 1),
        pad(0, 0),
        stride(1, 1),
        dilation(1, 1),
        group(1),
        bias_term(false)
    {}
    MatShape in_shape;
    MatShape out_shape;
    Size kernel;
    Size pad;
    Size stride;
    Size dilation;
    int group; // = 1;
    bool bias_term; // = false;
};
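// A sketch of how a convolution layer might fill this config (the shapes and
// values here are illustrative, not taken from this patch):
//
//     OCL4DNNConvConfig config;
//     config.in_shape  = shape(1, 192, 28, 28);   // NCHW
//     config.out_shape = shape(1, 64, 28, 28);
//     config.kernel    = Size(3, 3);
//     config.pad       = Size(1, 1);
//     config.stride    = Size(1, 1);
//     config.dilation  = Size(1, 1);
//     config.group     = 1;
//     config.bias_term = true;
//     Ptr<OCL4DNNConvSpatial<float> > convOp(new OCL4DNNConvSpatial<float>(config));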
||||
|
||||
|
||||
template<typename Dtype> |
||||
class OCL4DNNConvSpatial |
||||
{ |
||||
public: |
||||
explicit OCL4DNNConvSpatial(OCL4DNNConvConfig config); |
||||
~OCL4DNNConvSpatial(); |
||||
bool Forward(const UMat& bottom_data, const UMat& weight, |
||||
const UMat& bias, |
||||
UMat& top_data, int32_t batch_size); |
||||
|
||||
private: |
||||
struct kernelConfig |
||||
{ |
||||
std::string kernelName; |
||||
float executionTime; |
||||
size_t local_work_size[3]; |
||||
size_t global_work_size[3]; |
||||
int32_t workItem_output[3]; |
||||
bool verified; |
||||
bool tested; |
||||
bool swizzle_weights; |
||||
bool use_null_local; |
||||
int32_t kernelType; |
||||
|
||||
kernelConfig() |
||||
{} |
||||
|
||||
kernelConfig(const std::string& name, const size_t* global_size, const size_t* local_size, |
||||
const int32_t* workItem, |
||||
bool swizzle, |
||||
int32_t type = 0) |
||||
: executionTime(0) |
||||
{ |
||||
kernelName = name; |
||||
for (int32_t x = 0; x < 3; x++) |
||||
{ |
||||
local_work_size[x] = local_size ? local_size[x] : 1; |
||||
global_work_size[x] = global_size[x]; |
||||
workItem_output[x] = workItem[x]; |
||||
} |
||||
swizzle_weights = swizzle; |
||||
use_null_local = local_size == NULL; |
||||
verified = false; |
||||
tested = false; |
||||
kernelType = type; |
||||
} |
||||
}; |
||||
|
||||
struct tunerParam |
||||
{ |
||||
int kernelType; |
||||
int blockWidth; |
||||
int blockHeight; |
||||
int blockDepth; |
||||
|
||||
tunerParam(int type, int w, int h, int d) |
||||
{ |
||||
kernelType = type; |
||||
blockWidth = w; |
||||
blockHeight = h;
||||
blockDepth = d; |
||||
} |
||||
}; |
||||
|
||||
    inline void addDef(const char* name)
    {
        options_ << " -D " << name;
    }

    inline void addDef(const char* name, const int value)
    {
        options_ << " -D " << name << "=" << value;
    }

    inline void addDef(const char* name, const float value)
    {
        options_ << " -D " << name << "=(float)" << value;
    }

    inline void addDef(const char* name, const double value)
    {
        options_ << " -D " << name << "=(double)" << value;
    }

    inline void addDef(const char* name, const char* value)
    {
        options_ << " -D " << name << "=" << value;
    }
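    // For example, addDef("KERNEL_WIDTH", 3) followed by addDef("APPLY_BIAS")
    // appends " -D KERNEL_WIDTH=3 -D APPLY_BIAS" to options_, which is then
    // used as the build-options string when the convolution kernel is
    // compiled (the macro names here are only illustrative).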
||||
|
||||
void useFirstAvailable(const UMat &bottom, |
||||
UMat &top, |
||||
const UMat &weight, |
||||
const UMat &bias, |
||||
int32_t numImages, |
||||
UMat &verifyTop); |
||||
void setupKernel(); |
||||
void collectCommonInformation(); |
||||
void setupKernelDetails(int32_t kernelType, |
||||
int32_t blockM, |
||||
int32_t blockK, |
||||
int32_t blockN); |
||||
|
||||
ocl::Program compileKernel(); |
||||
typedef std::map<std::string, ocl::Program> phash_t; |
||||
phash_t phash; |
||||
void calculateBenchmark(const UMat &bottom, UMat &verifyTop, |
||||
const UMat &weight, const UMat &bias, |
||||
int32_t numImages); |
||||
|
||||
|
||||
void setupConvolution(const UMat &bottom, |
||||
UMat &top, |
||||
const UMat &weight, |
||||
const UMat &bias, |
||||
int32_t numImages,
||||
UMat &verifyTop); |
||||
bool createConvolutionKernel(int32_t kernelType, |
||||
int32_t blockWidth, |
||||
int32_t blockHeight, |
||||
int32_t blockDepth); |
||||
bool setupIDLF(int32_t blockWidth, |
||||
int32_t blockHeight, |
||||
int32_t blockDepth); |
||||
bool createBasicKernel(int32_t blockWidth, |
||||
int32_t blockHeight, |
||||
int32_t blockDepth); |
||||
bool createGEMMLikeConvKernel(int32_t blockWidth, |
||||
int32_t blockHeight, |
||||
int32_t blockDepth); |
||||
void CreateSubBuffer(const UMat& buffer, UMat& sub_buffer, |
||||
int32_t offset, int32_t size, bool write_only); |
||||
bool convolve(const UMat &bottom, UMat &top, |
||||
const UMat &weight, const UMat &bias, |
||||
int32_t numImages, |
||||
kernelConfig* config, |
||||
const cv::ocl::Queue& queue); |
||||
float timedConvolve(const UMat &bottom, UMat &top, |
||||
const UMat &weight, const UMat &bias, |
||||
int32_t numImages, kernelConfig* config); |
||||
|
||||
bool verifyResult(const UMat &bottom, |
||||
UMat &top, |
||||
const UMat &weight, |
||||
const UMat &bias, |
||||
int32_t numImages, |
||||
kernelConfig* config, |
||||
UMat &verifyTop); |
||||
|
||||
bool swizzleWeight(const UMat &weight, |
||||
int32_t swizzled_factor, |
||||
bool interleave = false); |
||||
|
||||
void generateKey(); |
||||
std::string generateSpecificKey(int32_t type, int32_t blockWidth, |
||||
int32_t blockHeight, |
||||
int32_t blockDepth); |
||||
void cacheTunedConfig(); |
||||
bool loadTunedConfig(); |
||||
|
||||
void saveTunedConfig(); |
||||
bool loadCachedConfig(); |
||||
|
||||
void unloadProgram(const std::string& kernelName); |
||||
void prepareKernel(const UMat &bottom, UMat &top, |
||||
const UMat &weight, const UMat &bias, |
||||
int32_t numImages); |
||||
bool setupKernelByConfig(int x, int y, int z, int type, |
||||
int lx, int ly, int lz, |
||||
bool swizzle, bool nullLocal); |
||||
void generateTunerItems(std::vector< cv::Ptr<tunerParam> > &tunerItems); |
||||
|
||||
int32_t group_; |
||||
bool bias_term_; |
||||
UMat swizzled_weights_umat; |
||||
|
||||
int32_t bottom_index_; |
||||
int32_t output_h_; |
||||
int32_t output_w_; |
||||
int32_t kernel_h_; |
||||
int32_t kernel_w_; |
||||
int32_t height_; |
||||
int32_t width_; |
||||
int32_t pad_h_; |
||||
int32_t pad_w_; |
||||
int32_t stride_h_; |
||||
int32_t stride_w_; |
||||
int32_t dilation_h_; |
||||
int32_t dilation_w_; |
||||
|
||||
/// M_ is the channel dimension of the output for a single group, which is the
|
||||
/// leading dimension of the filter matrix.
|
||||
int32_t M_; |
||||
|
||||
bool tuned_; |
||||
std::string key_, key_sanitized_; |
||||
std::string short_key_; |
||||
std::string kernel_name_; |
||||
std::string cache_path_; |
||||
bool use_cache_path_; // true if cache_path_ directory exists
|
||||
bool force_auto_tuning_; |
||||
int32_t kernel_index_; |
||||
std::vector< cv::Ptr<kernelConfig> > kernelQueue; |
||||
cv::Ptr<kernelConfig> bestKernelConfig; |
||||
|
||||
int32_t bottom_dim_; |
||||
int32_t top_dim_; |
||||
int32_t num_; |
||||
int32_t channels_; |
||||
int32_t num_output_; |
||||
|
||||
int32_t kernelType_; |
||||
int32_t blockM_; |
||||
int32_t blockK_; |
||||
int32_t blockN_; |
||||
std::stringstream options_; |
||||
cv::ocl::ProgramSource src_; |
||||
int32_t prev_kernel_type_; |
||||
}; |
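// A usage sketch (illustrative only): after constructing the op from an
// OCL4DNNConvConfig as sketched above, a layer would call
//
//     convOp->Forward(bottom, weights, bias, top, numImages);
//
// where bottom/weights/bias/top are UMats and numImages is the batch size.
// Forward() returns false when the OpenCL path cannot handle the case, so
// the caller can fall back to the CPU implementation.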
||||
|
||||
typedef enum { |
||||
LIBDNN_POOLING_METHOD_MAX = 0, |
||||
LIBDNN_POOLING_METHOD_AVE = 1, |
||||
LIBDNN_POOLING_METHOD_STO = 2 |
||||
} ocl4dnnPoolingMethod_t; |
||||
|
||||
struct OCL4DNNPoolConfig |
||||
{ |
||||
OCL4DNNPoolConfig() : |
||||
kernel(1, 1), |
||||
pad(0, 0), |
||||
stride(1, 1), |
||||
dilation(1, 1), |
||||
channels(0), |
||||
pool_method(LIBDNN_POOLING_METHOD_MAX), |
||||
global_pooling(false) |
||||
{} |
||||
MatShape in_shape; |
||||
MatShape out_shape; |
||||
Size kernel; |
||||
Size pad; |
||||
Size stride; |
||||
Size dilation; |
||||
|
||||
int channels; |
||||
ocl4dnnPoolingMethod_t pool_method; // = LIBDNN_POOLING_METHOD_MAX;
|
||||
bool global_pooling; // = false;
|
||||
}; |
||||
|
||||
template<typename Dtype> |
||||
class OCL4DNNPool |
||||
{ |
||||
public: |
||||
explicit OCL4DNNPool(OCL4DNNPoolConfig config); |
||||
~OCL4DNNPool(); |
||||
bool Forward(const UMat& bottom_data, |
||||
UMat& top_data, |
||||
UMat& top_mask); |
||||
private: |
||||
UMat mask_idx_; |
||||
|
||||
// Pooling parameters
|
||||
std::vector<int32_t> pad_; |
||||
std::vector<int32_t> stride_; |
||||
std::vector<int32_t> kernel_shape_; |
||||
std::vector<int32_t> im_in_shape_; |
||||
std::vector<int32_t> im_out_shape_; |
||||
|
||||
ocl4dnnPoolingMethod_t pool_method_; |
||||
int32_t count_; |
||||
int32_t batch_size_; |
||||
int32_t channels_; |
||||
int32_t kernel_h_; |
||||
int32_t kernel_w_; |
||||
int32_t stride_h_; |
||||
int32_t stride_w_; |
||||
int32_t pad_h_; |
||||
int32_t pad_w_; |
||||
int32_t height_; |
||||
int32_t width_; |
||||
int32_t pooled_height_; |
||||
int32_t pooled_width_; |
||||
}; |
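// A usage sketch (illustrative only): after constructing the op from an
// OCL4DNNPoolConfig, a pooling layer would call
//
//     poolOp->Forward(bottom, top, topMask);
//
// where topMask is the extra output that max pooling can fill with the
// selected indices (see mask_idx_ above) and may otherwise stay empty.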
||||
|
||||
struct OCL4DNNInnerProductConfig |
||||
{ |
||||
OCL4DNNInnerProductConfig() : |
||||
num_output(0), M(0), K(0), |
||||
bias_term(false), transpose(false), phase_test(true) |
||||
{} |
||||
int num_output; |
||||
int M; |
||||
int K; |
||||
bool bias_term; |
||||
bool transpose; // = false;
|
||||
bool phase_test; // = true;
|
||||
}; |
||||
|
||||
template<typename Dtype> |
||||
class OCL4DNNInnerProduct |
||||
{ |
||||
public: |
||||
explicit OCL4DNNInnerProduct(OCL4DNNInnerProductConfig config); |
||||
~OCL4DNNInnerProduct(); |
||||
bool Forward(const UMat& bottom_data, |
||||
const UMat& weight, |
||||
const UMat& bias, |
||||
UMat& top_data); |
||||
private: |
||||
OCL4DNNInnerProductConfig config_; |
||||
int32_t axis_; |
||||
int32_t num_output_; |
||||
int32_t M_; |
||||
int32_t N_; |
||||
int32_t K_; |
||||
bool bias_term_; |
||||
bool transpose_; |
||||
bool image_copied_; |
||||
bool phase_test_; |
||||
}; |
||||
|
||||
typedef enum { |
||||
LRNParameter_NormRegion_ACROSS_CHANNELS = 0, |
||||
LRNParameter_NormRegion_WITHIN_CHANNEL = 1 |
||||
} LRNParameter_NormRegion_WITHIN_CHANNEL_t; |
||||
|
||||
struct OCL4DNNLRNConfig |
||||
{ |
||||
OCL4DNNLRNConfig() : |
||||
phase_test(true) |
||||
{} |
||||
MatShape in_shape; |
||||
LRNParameter_NormRegion_WITHIN_CHANNEL_t lrn_type; |
||||
bool phase_test; // = true;
|
||||
int local_size; |
||||
float alpha; |
||||
float beta; |
||||
float k; |
||||
bool norm_by_size; |
||||
int32_t batch_size; |
||||
int32_t channels; |
||||
int32_t height; |
||||
int32_t width; |
||||
}; |
||||
|
||||
template<typename Dtype> |
||||
class OCL4DNNLRN |
||||
{ |
||||
public: |
||||
explicit OCL4DNNLRN(OCL4DNNLRNConfig config); |
||||
bool Forward(const UMat& bottom_data, UMat& top_data); |
||||
|
||||
private: |
||||
bool crossChannelForward(const UMat& bottom_data, UMat& top_data); |
||||
LRNParameter_NormRegion_WITHIN_CHANNEL_t lrn_type_; |
||||
bool phase_test_; |
||||
int32_t size_; |
||||
Dtype alpha_; |
||||
Dtype beta_; |
||||
Dtype k_; |
||||
int32_t num_; |
||||
int32_t channels_; |
||||
int32_t height_; |
||||
int32_t width_; |
||||
bool norm_by_size_; |
||||
}; |
||||
|
||||
struct OCL4DNNSoftmaxConfig |
||||
{ |
||||
OCL4DNNSoftmaxConfig() |
||||
{} |
||||
MatShape in_shape; |
||||
int axis; |
||||
int channels; |
||||
}; |
||||
|
||||
template<typename Dtype> |
||||
class OCL4DNNSoftmax |
||||
{ |
||||
public: |
||||
explicit OCL4DNNSoftmax(OCL4DNNSoftmaxConfig config); |
||||
~OCL4DNNSoftmax(); |
||||
bool Forward(const UMat& bottom_data, UMat& top_data); |
||||
|
||||
private: |
||||
int32_t softmax_axis_; |
||||
int32_t inner_num_; |
||||
int32_t outer_num_; |
||||
int32_t channels_; |
||||
int32_t count_; |
||||
bool use_slm_; |
||||
UMat scale_data_; |
||||
}; |
||||
#endif // HAVE_OPENCL
|
||||
} // namespace ocl4dnn
|
||||
} // namespace dnn
|
||||
} // namespace cv
|
||||
#endif |
@ -0,0 +1,57 @@ |
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved.
|
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "common.hpp" |
||||
#include "opencl_kernels_dnn.hpp" |
||||
|
||||
using namespace cv; |
||||
|
||||
#ifdef HAVE_OPENCL |
||||
bool clOptionSupport(cv::String option) |
||||
{ |
||||
cv::String errmsg; |
||||
ocl::Program program = ocl::Context::getDefault().getProg(ocl::dnn::dummy_oclsrc, option, errmsg); |
||||
return program.ptr() ? true : false; |
||||
} |
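// A usage sketch (the option string below is only an example): callers can
// probe whether the OpenCL compiler accepts a particular build flag before
// adding it to a kernel's build options, e.g.
//
//     if (clOptionSupport("-cl-no-subgroup-ifp"))
//         options += " -cl-no-subgroup-ifp";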

#endif // HAVE_OPENCL
@ -0,0 +1,538 @@ |
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved.
|
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "common.hpp" |
||||
#include "math_functions.hpp" |
||||
#include <vector> |
||||
#include "opencl_kernels_dnn.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
namespace ocl4dnn |
||||
{ |
||||
|
||||
#ifdef HAVE_OPENCL |
||||
// Create and copy a buffer to an image for GEMM's matrix A and B.
// Returns the created image to the caller when the input image is NULL;
// otherwise the existing image is used directly. It is the caller's
// responsibility to release the created image.
||||
template<typename Dtype> |
||||
ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset, |
||||
bool is_matrix_a, bool transpose, |
||||
bool padding, int padded_height, |
||||
int padded_width, int height, |
||||
int width, int ld) |
||||
{ |
||||
ocl::Context ctx = ocl::Context::getDefault(); |
||||
ocl::Queue queue = ocl::Queue::getDefault(); |
||||
ocl::Image2D image; |
||||
|
||||
if (!is_matrix_a && transpose) |
||||
{ |
||||
if (ld == width) |
||||
{ |
||||
image = ocl::Image2D(buffer); |
||||
} else { |
||||
// For matrix B with transpose we need to handle it differently:
// since we can't use the sub group block read to fetch a row easily,
// we have to use the CL_FLOAT type with read_imagef to get the row.
||||
UMat mat(height, width, CV_32FC1); |
||||
image = ocl::Image2D(mat); |
||||
|
||||
ocl::Kernel oclk_gemm_copy("gemm_buffer_copy_image_transpose_float", ocl::dnn::gemm_image_oclsrc); |
||||
|
||||
size_t global_copy[2]; |
||||
global_copy[0] = width; |
||||
global_copy[1] = height; |
||||
oclk_gemm_copy.set(0, ocl::KernelArg::PtrReadOnly(buffer)); |
||||
oclk_gemm_copy.set(1, image); |
||||
oclk_gemm_copy.set(2, offset); |
||||
oclk_gemm_copy.set(3, width); |
||||
oclk_gemm_copy.set(4, height); |
||||
oclk_gemm_copy.set(5, ld); |
||||
oclk_gemm_copy.run(2, global_copy, NULL, false); |
||||
} |
||||
} else { |
||||
if (!padding) |
||||
{ |
||||
// copy without padding.
|
||||
image = ocl::Image2D(buffer); |
||||
} else { |
||||
UMat mat(padded_height, padded_width, CV_8UC4); |
||||
image = ocl::Image2D(mat); |
||||
|
||||
ocl::Kernel oclk_gemm_copy("gemm_buffer_copy_image_no_transpose_float", |
||||
ocl::dnn::gemm_image_oclsrc); |
||||
|
||||
size_t global_copy[2]; |
||||
global_copy[0] = padded_width; |
||||
global_copy[1] = padded_height; |
||||
|
||||
oclk_gemm_copy.set(0, ocl::KernelArg::PtrReadOnly(buffer)); |
||||
oclk_gemm_copy.set(1, image); |
||||
oclk_gemm_copy.set(2, offset); |
||||
oclk_gemm_copy.set(3, width); |
||||
oclk_gemm_copy.set(4, height); |
||||
oclk_gemm_copy.set(5, ld); |
||||
|
||||
oclk_gemm_copy.run(2, global_copy, NULL, false); |
||||
} |
||||
} |
||||
|
||||
return image; |
||||
} |
||||
|
||||
template |
||||
ocl::Image2D ocl4dnnGEMMCopyBufferToImage<float>(UMat buffer, int offset, |
||||
bool is_matrix_a, bool transpose, |
||||
bool padding, int padded_height, |
||||
int padded_width, int height, |
||||
int width, int ld); |
||||
|
||||
enum gemm_type_t |
||||
{ |
||||
GEMM_TYPE_NONE = 0, |
||||
GEMM_TYPE_FAST_IMAGE_32_1, |
||||
GEMM_TYPE_FAST_IMAGE_32_2, |
||||
GEMM_TYPE_FAST_IMAGE_B_IMAGE, |
||||
GEMM_TYPE_MAX |
||||
}; |
||||
|
||||
template<typename Dtype> |
||||
static bool ocl4dnnFastImageGEMM(const CBLAS_TRANSPOSE TransA, |
||||
const CBLAS_TRANSPOSE TransB, const int32_t M, |
||||
const int32_t N, const int32_t K, const Dtype alpha, |
||||
const UMat A, const int32_t offA, const UMat B, |
||||
const int32_t offB, const Dtype beta, UMat C, |
||||
const int32_t offC, bool is_image_a, bool is_image_b, |
||||
enum gemm_type_t gemm_type, |
||||
const size_t max_image_size) |
||||
{ |
||||
CHECK_EQ(gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || gemm_type == GEMM_TYPE_FAST_IMAGE_32_2 || |
||||
gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE, true) << "Invalid fast image gemm type." << std::endl; |
||||
|
||||
if (is_image_a) |
||||
{ |
||||
CHECK_EQ(offA, 0) << "Invalid input image offset." << std::endl; |
||||
return false; |
||||
} |
||||
|
||||
if (is_image_b) |
||||
{ |
||||
CHECK_EQ(offB, 0) << "Invalid input image offset." << std::endl; |
||||
return false; |
||||
} |
||||
|
||||
int widthA = (TransA == CblasNoTrans) ? K : M; |
||||
int heightA = (TransA == CblasNoTrans) ? M : K; |
||||
int widthB = (TransB == CblasNoTrans) ? N : K; |
||||
int heightB = (TransB == CblasNoTrans) ? K : N; |
||||
|
||||
int ldA = widthA; |
||||
int ldB = widthB; |
||||
int ldC = N; |
||||
|
||||
int A_start_x = 0, A_start_y = 0, B_start_x = 0; |
||||
int B_start_y = 0, C_start_x = 0, C_start_y = 0; |
||||
int blocksize = 1024; |
||||
if (gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE) |
||||
blocksize = max_image_size; |
||||
int blockA_width = blocksize; |
||||
int blockA_height = blocksize; |
||||
int blockB_width = blocksize; |
||||
int blockB_height = blocksize; |
||||
int blockC_width = blocksize; |
||||
int blockC_height = blocksize; |
||||
|
||||
int use_buffer_indicator = 8; |
||||
// To fix the edge problem caused by the sub group block read,
// we have to pad the image if it is not a multiple of the tile.
// Padding one line is enough, as the sub group block read
// will clamp to the edge according to the spec.
||||
|
||||
ocl::Context ctx = ocl::Context::getDefault(); |
||||
ocl::Queue queue = ocl::Queue::getDefault(); |
||||
|
||||
ocl::Image2D ImA; |
||||
ocl::Image2D ImB; |
||||
|
||||
std::string kernel_name("gemm_"); |
||||
if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE) |
||||
kernel_name += "32_1_"; |
||||
else |
||||
kernel_name += "32_2_"; |
||||
|
||||
if (TransA == CblasNoTrans) |
||||
kernel_name += "N"; |
||||
else |
||||
kernel_name += "T"; |
||||
|
||||
if (TransB == CblasNoTrans) |
||||
{ |
||||
kernel_name += "N_"; |
||||
} else { |
||||
kernel_name += "T_"; |
||||
if (is_image_b || (K % use_buffer_indicator != 0)) |
||||
{ |
||||
kernel_name += "SCALAR_"; |
||||
} else { |
||||
kernel_name += "BUFFER_"; |
||||
} |
||||
} |
||||
|
||||
if (alpha == 1) |
||||
kernel_name += "1_"; |
||||
else |
||||
kernel_name += "0_"; |
||||
|
||||
if (beta == 0) |
||||
kernel_name += "0"; |
||||
else |
||||
kernel_name += "1"; |
||||
|
||||
kernel_name += "_float"; |
||||
|
||||
ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_image_oclsrc); |
||||
if (oclk_gemm_float.empty()) |
||||
return false; |
||||
|
||||
while (C_start_y < M) |
||||
{ |
||||
blockC_width = std::min(static_cast<int>(N) - C_start_x, blocksize); |
||||
blockC_height = std::min(static_cast<int>(M) - C_start_y, blocksize); |
||||
|
||||
int isFirstColBlock = 1; |
||||
for (int k = 0; k < K; k += blocksize) |
||||
{ |
||||
blockA_width = std::min(widthA - A_start_x, blocksize); |
||||
blockA_height = std::min(heightA - A_start_y, blocksize); |
||||
blockB_width = std::min(widthB - B_start_x, blocksize); |
||||
blockB_height = std::min(heightB - B_start_y, blocksize); |
||||
int block_Ksize = std::min(static_cast<int>(K) - k, blocksize); |
||||
|
||||
int padded_k = block_Ksize + ((block_Ksize & 7) ? (8 - (block_Ksize & 7)) : 0); |
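// padded_k above rounds block_Ksize up to the next multiple of 8 so the
// sub-group block reads cover a full tile; e.g. block_Ksize == 20 gives
// padded_k == 24, while a block_Ksize that is already a multiple of 8 is
// left unchanged.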
||||
int imageA_w = (TransA == CblasNoTrans) ? padded_k : blockA_width; |
||||
int imageA_h = (TransA == CblasNoTrans) ? blockA_height : padded_k; |
||||
int imageB_w = (TransB == CblasNoTrans) ? blockB_width : padded_k; |
||||
int imageB_h = (TransB == CblasNoTrans) ? padded_k : blockB_height; |
||||
|
||||
int blockA_offset = offA + A_start_y * ldA + A_start_x; |
||||
int blockB_offset = offB + B_start_y * ldB + B_start_x; |
||||
int blockC_offset = offC + C_start_y * ldC + C_start_x; |
||||
if (TransB == CblasNoTrans) |
||||
{ |
||||
bool padding_A = false; |
||||
bool padding_B = false; |
||||
|
||||
if (!is_image_a && !is_image_b) |
||||
{ |
||||
if (M * K < N * K) |
||||
padding_B = true; |
||||
else |
||||
padding_A = true; |
||||
} |
||||
|
||||
if (!is_image_a) |
||||
{ |
||||
ImA = ocl4dnnGEMMCopyBufferToImage<Dtype>(A, blockA_offset, |
||||
true, TransA != CblasNoTrans, |
||||
padding_A, imageA_h, imageA_w, |
||||
blockA_height, blockA_width, ldA); |
||||
} |
||||
if (!is_image_b) |
||||
{ |
||||
ImB = ocl4dnnGEMMCopyBufferToImage<Dtype>(B, blockB_offset, |
||||
false, false, |
||||
padding_B, imageB_h, imageB_w, |
||||
blockB_height, blockB_width, ldB); |
||||
} |
||||
} else { |
||||
// We will use normal read_imagef to read image B when B has transpose.
|
||||
// thus we don't need to pad image A at all.
|
||||
if (!is_image_a) |
||||
{ |
||||
bool padding; |
||||
padding = !is_image_b; |
||||
ImA = ocl4dnnGEMMCopyBufferToImage<Dtype>(A, blockA_offset, |
||||
true, TransA != CblasNoTrans, |
||||
padding, imageA_h, imageA_w, |
||||
blockA_height, blockA_width, ldA); |
||||
} |
||||
|
||||
if (!is_image_b && (K % use_buffer_indicator != 0)) |
||||
{ |
||||
ImB = ocl4dnnGEMMCopyBufferToImage<Dtype>(B, blockB_offset, |
||||
false, true, false, imageB_h, imageB_w, |
||||
blockB_height, blockB_width, ldB); |
||||
} |
||||
} |
||||
|
||||
size_t global[2]; |
||||
if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE) |
||||
{ |
||||
global[0] = (size_t)( blockC_width + 7 ) & ~7; |
||||
} else { |
||||
global[0] = (size_t)( (blockC_width / 2 ) + 7 ) & ~7;
||||
} |
||||
global[1] = (size_t)(blockC_height + 31) / 32; |
||||
|
||||
size_t local[2]; |
||||
local[0] = 8; |
||||
local[1] = 1; |
||||
|
||||
cl_uint arg_idx = 0; |
||||
if (is_image_a) |
||||
oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(A)); |
||||
else |
||||
oclk_gemm_float.set(arg_idx++, ImA); |
||||
|
||||
if (TransB == CblasNoTrans || is_image_b || (K % use_buffer_indicator != 0)) |
||||
{ |
||||
if (is_image_b) |
||||
oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(B)); |
||||
else |
||||
oclk_gemm_float.set(arg_idx++, ImB); |
||||
} else { |
||||
oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(B)); |
||||
oclk_gemm_float.set(arg_idx++, blockB_offset); |
||||
oclk_gemm_float.set(arg_idx++, ldB); |
||||
} |
||||
oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrWriteOnly(C)); |
||||
oclk_gemm_float.set(arg_idx++, blockC_offset); |
||||
oclk_gemm_float.set(arg_idx++, blockC_height); |
||||
oclk_gemm_float.set(arg_idx++, blockC_width); |
||||
oclk_gemm_float.set(arg_idx++, ldC); |
||||
oclk_gemm_float.set(arg_idx++, alpha); |
||||
oclk_gemm_float.set(arg_idx++, beta); |
||||
oclk_gemm_float.set(arg_idx++, padded_k); |
||||
if (TransB != CblasNoTrans) |
||||
oclk_gemm_float.set(arg_idx++, block_Ksize); |
||||
oclk_gemm_float.set(arg_idx++, isFirstColBlock); |
||||
|
||||
if (!oclk_gemm_float.run(2, global, local, false)) |
||||
return false; |
||||
|
||||
if (TransA == CblasNoTrans) |
||||
A_start_x += blockA_width; |
||||
else |
||||
A_start_y += blockA_height; |
||||
|
||||
if (TransB == CblasNoTrans) |
||||
B_start_y += blockB_height; |
||||
else |
||||
B_start_x += blockB_width; |
||||
|
||||
isFirstColBlock = 0; |
||||
} |
||||
|
||||
C_start_x += blockC_width; |
||||
if (TransA == CblasNoTrans) |
||||
A_start_x = 0; |
||||
else |
||||
A_start_y = 0; |
||||
if (TransB == CblasNoTrans) |
||||
{ |
||||
B_start_x += blockB_width; |
||||
B_start_y = 0; |
||||
} else { |
||||
B_start_y += blockB_height; |
||||
B_start_x = 0; |
||||
} |
||||
if (C_start_x >= N) |
||||
{ |
||||
C_start_x = 0; |
||||
B_start_x = 0; |
||||
B_start_y = 0; |
||||
C_start_y += blockC_height; |
||||
if (TransA == CblasNoTrans) |
||||
A_start_y += blockA_height; |
||||
else |
||||
A_start_x += blockA_width; |
||||
} |
||||
} |
||||
|
||||
return true; |
||||
} |
||||
|
||||
template<typename Dtype> |
||||
bool ocl4dnnGEMMCommon(const CBLAS_TRANSPOSE TransB, |
||||
const int32_t M, const int32_t N, const int32_t K, |
||||
const UMat A, const UMat B, |
||||
const UMat B_image, UMat C, |
||||
const size_t max_image_size) |
||||
{ |
||||
gemm_type_t gemm_type = GEMM_TYPE_FAST_IMAGE_32_1; |
||||
|
||||
if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || |
||||
gemm_type == GEMM_TYPE_FAST_IMAGE_32_2) |
||||
{ |
||||
return ocl4dnnFastImageGEMM<Dtype>(CblasNoTrans, TransB, M, N, K, |
||||
(Dtype)1., A, 0, B, 0, (Dtype)0., C, |
||||
0, false, false, gemm_type, max_image_size); |
||||
} |
||||
else if (gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE) |
||||
{ |
||||
return ocl4dnnFastImageGEMM<Dtype>(CblasNoTrans, TransB, M, N, K, |
||||
(Dtype)1., A, 0, B_image, 0, (Dtype)0., C, |
||||
0, false, true, |
||||
GEMM_TYPE_FAST_IMAGE_B_IMAGE, |
||||
max_image_size); |
||||
} |
||||
return false; |
||||
} |
||||
|
||||
template bool ocl4dnnGEMMCommon<float>(const CBLAS_TRANSPOSE TransB, |
||||
const int32_t M, const int32_t N, const int32_t K, |
||||
const UMat A, const UMat B, |
||||
const UMat B_image, UMat C, |
||||
const size_t max_image_size); |
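A hedged usage sketch of ocl4dnnGEMMCommon, mirroring how the inner-product path further down calls it: C(MxN) = A(MxK) * B^T with B stored as NxK. The sizes and UMat names are purely illustrative.

    int M = 16, N = 100, K = 1024;                       // illustrative dimensions
    UMat A(M, K, CV_32F), B(N, K, CV_32F), C(M, N, CV_32F);
    size_t maxImg = std::min(ocl::Device::getDefault().image2DMaxWidth(),
                             ocl::Device::getDefault().image2DMaxHeight());
    // TransB == CblasTrans because B is laid out as N x K here.
    bool ok = ocl4dnnGEMMCommon<float>(CblasTrans, M, N, K, A, B, UMat(), C, maxImg);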
||||
|
||||
template<typename Dtype> |
||||
bool ocl4dnnGEMV(const CBLAS_TRANSPOSE TransA, |
||||
const int32_t M, const int32_t N, const Dtype alpha, |
||||
const UMat A, const int32_t offA, const UMat x, |
||||
const int32_t offx, const Dtype beta, UMat y, |
||||
const int32_t offy) |
||||
{ |
||||
return false; |
||||
} |
||||
|
||||
template<> |
||||
bool ocl4dnnGEMV<float>(const CBLAS_TRANSPOSE TransA, |
||||
const int32_t M, const int32_t N, const float alpha, |
||||
const UMat A, const int32_t offA, const UMat x, |
||||
const int32_t offx, const float beta, UMat y, |
||||
const int32_t offy) |
||||
{ |
||||
ocl::Queue queue = ocl::Queue::getDefault(); |
||||
bool ret = false; |
||||
|
||||
if (TransA == CblasNoTrans) |
||||
{ |
||||
ocl::Kernel k(CL_KERNEL_SELECT("matvec_mul4"), cv::ocl::dnn::matvec_mul_oclsrc); |
||||
if (k.empty()) |
||||
return false; |
||||
|
||||
uint row_size = M; |
||||
uint col_size = N; |
||||
size_t localsize[] = { 128 }; |
||||
size_t globalsize[] = { row_size / 4 * localsize[0] }; |
||||
|
||||
uint argId = 0; |
||||
k.set(argId++, ocl::KernelArg::PtrReadOnly(A)); |
||||
k.set(argId++, offA); |
||||
k.set(argId++, cl_uint(col_size)); |
||||
k.set(argId++, cl_uint(col_size%4)); |
||||
k.set(argId++, ocl::KernelArg::PtrReadOnly(x)); |
||||
k.set(argId++, offx); |
||||
k.set(argId++, alpha); |
||||
k.set(argId++, beta); |
||||
k.set(argId++, ocl::KernelArg::PtrWriteOnly(y)); |
||||
k.set(argId++, offy); |
||||
k.set(argId++, NULL, localsize[0] * sizeof(cl_float4)); |
||||
|
||||
ret = k.run(1, globalsize, localsize, false); |
||||
|
||||
if ((row_size % 4) != 0 && ret) |
||||
{ |
||||
ocl::Kernel k_1(CL_KERNEL_SELECT("matvec_mul1"), cv::ocl::dnn::matvec_mul_oclsrc); |
||||
size_t localsize[] = { 128 }; |
||||
size_t globalsize[] = { row_size % 4 * localsize[0] }; |
||||
uint row_offset = row_size - (row_size % 4); |
||||
|
||||
uint argId = 0; |
||||
k_1.set(argId++, ocl::KernelArg::PtrReadOnly(A)); |
||||
k_1.set(argId++, offA); |
||||
k_1.set(argId++, cl_uint(col_size)); |
||||
k_1.set(argId++, cl_uint(row_offset)); |
||||
k_1.set(argId++, cl_uint(col_size%4)); |
||||
k_1.set(argId++, ocl::KernelArg::PtrReadOnly(x)); |
||||
k_1.set(argId++, offx); |
||||
k_1.set(argId++, alpha); |
||||
k_1.set(argId++, beta); |
||||
k_1.set(argId++, ocl::KernelArg::PtrWriteOnly(y)); |
||||
k_1.set(argId++, offy); |
||||
k_1.set(argId++, NULL, localsize[0] * sizeof(cl_float)); |
||||
|
||||
ret = k_1.run(1, globalsize, localsize, false); |
||||
} |
||||
} |
||||
return ret; |
||||
} |
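The float specialization above splits the matrix-vector product into two launches: matvec_mul4 handles rows in groups of four, and matvec_mul1 handles the remaining (M % 4) rows starting at row_offset. A small arithmetic illustration of that split (the values are examples only):

    int M = 10;                      // example row count
    int rows_vec4  = M / 4 * 4;      // 8 rows go to matvec_mul4
    int rows_tail  = M % 4;          // 2 rows go to matvec_mul1
    int row_offset = M - rows_tail;  // == 8, matching row_offset above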
||||
|
||||
template<typename Dtype> |
||||
bool ocl4dnnAXPY(const int32_t N, const Dtype alpha, |
||||
const UMat X, const int32_t offX, UMat Y, |
||||
const int32_t offY) |
||||
{ |
||||
ocl::Context ctx = ocl::Context::getDefault(); |
||||
|
||||
ocl::Kernel oclk_axpy(CL_KERNEL_SELECT("axpy"), cv::ocl::dnn::math_oclsrc); |
||||
if (oclk_axpy.empty()) |
||||
return false; |
||||
|
||||
size_t global[] = { 128 * 128 }; |
||||
size_t local[] = { 128 }; |
||||
|
||||
cl_uint argIdx = 0; |
||||
oclk_axpy.set(argIdx++, N); |
||||
oclk_axpy.set(argIdx++, alpha); |
||||
oclk_axpy.set(argIdx++, ocl::KernelArg::PtrReadOnly(X)); |
||||
oclk_axpy.set(argIdx++, offX); |
||||
oclk_axpy.set(argIdx++, ocl::KernelArg::PtrWriteOnly(Y)); |
||||
oclk_axpy.set(argIdx++, offY); |
||||
|
||||
return oclk_axpy.run(1, global, local, false); |
||||
} |
||||
|
||||
template bool ocl4dnnAXPY<float>(const int32_t N, const float alpha, |
||||
const UMat X, const int32_t offX, |
||||
UMat Y, const int32_t offY); |
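ocl4dnnAXPY computes Y = alpha * X + Y over N elements; the M_ == 1 path of the inner-product layer below uses it with alpha == 1 to add the bias after the GEMV. A hedged sketch (buffer shapes are illustrative):

    int N = 100;
    UMat bias(1, N, CV_32F), top(1, N, CV_32F);
    bool ok = ocl4dnnAXPY<float>(N, 1.f, bias, 0, top, 0);  // top += bias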
||||
|
||||
#endif // HAVE_OPENCL
|
||||
|
||||
} // namespace ocl4dnn
|
||||
} // namespace dnn
|
||||
} // namespace cv
|
File diff suppressed because it is too large
@ -0,0 +1,108 @@ |
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved.
|
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "common.hpp" |
||||
#include "ocl4dnn.hpp" |
||||
#include "math_functions.hpp" |
||||
|
||||
#ifdef HAVE_OPENCL |
||||
namespace cv { namespace dnn { namespace ocl4dnn { |
||||
template<typename Dtype> |
||||
OCL4DNNInnerProduct<Dtype>::OCL4DNNInnerProduct(OCL4DNNInnerProductConfig config) |
||||
{ |
||||
bias_term_ = config.bias_term; |
||||
transpose_ = config.transpose; |
||||
N_ = num_output_ = config.num_output; |
||||
M_ = config.M; |
||||
K_ = config.K; |
||||
phase_test_ = config.phase_test; |
||||
image_copied_ = false; |
||||
} |
||||
|
||||
template<typename Dtype> |
||||
OCL4DNNInnerProduct<Dtype>::~OCL4DNNInnerProduct() |
||||
{ |
||||
} |
||||
|
||||
template<typename Dtype> |
||||
bool OCL4DNNInnerProduct<Dtype>::Forward(const UMat& bottom, |
||||
const UMat& weight, |
||||
const UMat& bias, |
||||
UMat& top) |
||||
{ |
||||
bool ret; |
||||
|
||||
if (M_ == 1) |
||||
{ |
||||
ret = ocl4dnnGEMV<Dtype>(CblasNoTrans, N_, K_, (Dtype) 1., |
||||
weight, 0, bottom, 0, (Dtype) 0., top, 0); |
||||
|
||||
if (bias_term_ && ret) |
||||
ret = ocl4dnnAXPY<Dtype>(N_, 1, bias, 0, top, 0); |
||||
|
||||
return ret; |
||||
} |
||||
else |
||||
{ |
||||
ret = false; |
||||
size_t max_image_size = std::min(ocl::Device::getDefault().image2DMaxWidth(), |
||||
ocl::Device::getDefault().image2DMaxHeight()); |
||||
if (M_ <= max_image_size && |
||||
N_ <= max_image_size && |
||||
K_ <= max_image_size && |
||||
cv::traits::Depth<Dtype>::value == CV_32F && |
||||
ocl::Device::getDefault().intelSubgroupsSupport()) |
||||
{ |
||||
ret = ocl4dnnGEMMCommon<Dtype>(transpose_ ? CblasNoTrans : CblasTrans, |
||||
M_, N_, K_, bottom, weight, UMat(), top, |
||||
max_image_size); |
||||
} |
||||
return ret; |
||||
} |
||||
} |
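A hedged sketch of how a caller might configure and run this op for a fully connected layer; the field values and buffer shapes are illustrative, and the config fields are exactly those read by the constructor above.

    OCL4DNNInnerProductConfig cfg;
    cfg.num_output = 100;     // N_
    cfg.M = 16;               // batch size
    cfg.K = 1024;             // input features
    cfg.bias_term = true;
    cfg.transpose = false;
    cfg.phase_test = true;
    OCL4DNNInnerProduct<float> fc(cfg);
    // With bottom 16x1024, weight 100x1024, bias 1x100 and top 16x100:
    // bool ok = fc.Forward(bottom, weight, bias, top);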
||||
|
||||
template class OCL4DNNInnerProduct<float>; |
||||
} // namespace ocl4dnn
|
||||
} |
||||
} |
||||
#endif // HAVE_OPENCL
|
@ -0,0 +1,126 @@ |
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved.
|
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "common.hpp" |
||||
#include "ocl4dnn.hpp" |
||||
#include "opencl_kernels_dnn.hpp" |
||||
|
||||
#ifdef HAVE_OPENCL |
||||
namespace cv { namespace dnn { namespace ocl4dnn { |
||||
template<typename Dtype> |
||||
OCL4DNNLRN<Dtype>::OCL4DNNLRN(OCL4DNNLRNConfig config) |
||||
{ |
||||
lrn_type_ = config.lrn_type; |
||||
phase_test_ = config.phase_test; |
||||
size_ = config.local_size; |
||||
CHECK_EQ(size_ % 2, 1)<< "LRN only supports odd values for local_size"; |
||||
alpha_ = config.alpha; |
||||
beta_ = config.beta; |
||||
k_ = config.k; |
||||
norm_by_size_ = config.norm_by_size; |
||||
num_ = config.batch_size; |
||||
channels_ = config.channels; |
||||
height_ = config.height; |
||||
width_ = config.width; |
||||
} |
||||
|
||||
template<typename Dtype> |
||||
bool OCL4DNNLRN<Dtype>::Forward(const UMat& bottom, UMat& top) |
||||
{ |
||||
bool ret = true; |
||||
|
||||
if (!ocl::Device::getDefault().intelSubgroupsSupport()) |
||||
return false; |
||||
|
||||
switch (lrn_type_) |
||||
{ |
||||
case LRNParameter_NormRegion_ACROSS_CHANNELS: |
||||
ret = crossChannelForward(bottom, top); |
||||
break; |
||||
case LRNParameter_NormRegion_WITHIN_CHANNEL: |
||||
//TODO
|
||||
//WithinChannelForward(bottom_data, top_data);
|
||||
ret = false; |
||||
break; |
||||
default: |
||||
ret = false; |
||||
LOG(FATAL)<< "Unknown normalization region."; |
||||
} |
||||
return ret; |
||||
} |
||||
|
||||
template<typename Dtype> |
||||
bool OCL4DNNLRN<Dtype>::crossChannelForward(const UMat& bottom, UMat& top) |
||||
{ |
||||
ocl::Queue queue = ocl::Queue::getDefault(); |
||||
CHECK_EQ(phase_test_, true) << "Only support forward inference."; |
||||
|
||||
cl_uint argIdx = 0; |
||||
int32_t n_threads = num_ * height_ * width_; |
||||
size_t global_work_size_[1] = {(size_t)n_threads}; |
||||
String opts = clOptionSupport("-cl-no-subgroup-ifp") ? " -cl-no-subgroup-ifp " : ""; |
||||
ocl::Kernel oclk_lrn_fill; |
||||
if (!oclk_lrn_fill.create(CL_KERNEL_SELECT("lrn_full_no_scale"), ocl::dnn::ocl4dnn_lrn_oclsrc, opts)) |
||||
return false; |
||||
|
||||
oclk_lrn_fill.set(argIdx++, n_threads); |
||||
oclk_lrn_fill.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom)); |
||||
oclk_lrn_fill.set(argIdx++, num_); |
||||
oclk_lrn_fill.set(argIdx++, channels_); |
||||
oclk_lrn_fill.set(argIdx++, height_); |
||||
oclk_lrn_fill.set(argIdx++, width_); |
||||
oclk_lrn_fill.set(argIdx++, size_); |
||||
int size_norm_factor = norm_by_size_ ? size_ : 1; |
||||
oclk_lrn_fill.set(argIdx++, alpha_ / size_norm_factor); |
||||
oclk_lrn_fill.set(argIdx++, k_); |
||||
oclk_lrn_fill.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top)); |
||||
oclk_lrn_fill.set(argIdx++, -beta_); |
||||
|
||||
return oclk_lrn_fill.run(1, global_work_size_, NULL, false); |
||||
} |
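For reference, the arguments passed above correspond to the usual across-channel LRN transform: the coefficient is alpha_ / size_ when norm_by_size_ is set (alpha_ otherwise), k_ is the additive constant, and the exponent is supplied as -beta_ so the kernel can apply it directly:

    top[n][c][y][x] = bottom[n][c][y][x] *
                      pow(k + coeff * sum over window(c, size_) of bottom[n][c'][y][x]^2, -beta)

where the window spans size_ channels centred on c.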
||||
|
||||
template class OCL4DNNLRN<float>; |
||||
} // namespace ocl4dnn
|
||||
} |
||||
} |
||||
#endif // HAVE_OPENCL
|
@ -0,0 +1,213 @@ |
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include <string> |
||||
#include <vector> |
||||
#include "common.hpp" |
||||
#include "ocl4dnn.hpp" |
||||
#include "opencl_kernels_dnn.hpp" |
||||
|
||||
#ifdef HAVE_OPENCL |
||||
namespace cv { namespace dnn { namespace ocl4dnn { |
||||
template<typename Dtype> |
||||
OCL4DNNPool<Dtype>::OCL4DNNPool(OCL4DNNPoolConfig config) |
||||
{ |
||||
int dims = config.in_shape.size(); |
||||
int spatial_dims = 2; |
||||
|
||||
batch_size_ = config.in_shape[0]; |
||||
channels_ = config.channels; |
||||
pool_method_ = config.pool_method; |
||||
|
||||
for (int i = 0; i < spatial_dims; ++i) |
||||
{ |
||||
kernel_shape_.push_back(i == 0 ? config.kernel.height : config.kernel.width); |
||||
pad_.push_back(i == 0 ? config.pad.height : config.pad.width); |
||||
stride_.push_back(i == 0 ? config.stride.height : config.stride.width); |
||||
im_in_shape_.push_back(config.in_shape[dims - spatial_dims + i]); |
||||
im_out_shape_.push_back(config.out_shape[dims - spatial_dims + i]); |
||||
} |
||||
|
||||
kernel_h_ = kernel_shape_[0]; |
||||
kernel_w_ = kernel_shape_[1]; |
||||
stride_h_ = stride_[0]; |
||||
stride_w_ = stride_[1]; |
||||
pad_h_ = pad_[0]; |
||||
pad_w_ = pad_[1]; |
||||
height_ = im_in_shape_[0]; |
||||
width_ = im_in_shape_[1]; |
||||
pooled_height_ = im_out_shape_[0]; |
||||
pooled_width_ = im_out_shape_[1]; |
||||
|
||||
count_ = 1; |
||||
for (int i = 0; i < config.out_shape.size(); ++i) |
||||
{ |
||||
count_ *= config.out_shape[i]; |
||||
} |
||||
} |
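Note that the pooled output extents are taken directly from config.out_shape rather than recomputed here; for a typical floor-mode pooling they would relate to the input as below (ceil-mode layers differ, so treat this only as a consistency check):

    // pooled_height = (height + 2*pad_h - kernel_h) / stride_h + 1    (integer division)
    // e.g. height = 32, kernel_h = 2, stride_h = 2, pad_h = 0  ->  pooled_height = 16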
||||
|
||||
template<typename Dtype> |
||||
OCL4DNNPool<Dtype>::~OCL4DNNPool() |
||||
{ |
||||
mask_idx_.release(); |
||||
} |
||||
|
||||
template<typename Dtype> |
||||
bool OCL4DNNPool<Dtype>::Forward(const UMat& bottom, |
||||
UMat& top, |
||||
UMat& top_mask) |
||||
{ |
||||
bool ret = true; |
||||
ocl::Queue queue = ocl::Queue::getDefault(); |
||||
size_t global[] = { 128 * 128 }; |
||||
size_t local[] = { 128 }; |
||||
cl_uint argIdx = 0; |
||||
|
||||
// Only the 2D spatial case is supported here.
||||
switch (pool_method_) |
||||
{ |
||||
case LIBDNN_POOLING_METHOD_MAX: |
||||
{ |
||||
if (top_mask.empty() && mask_idx_.empty()) |
||||
{ |
||||
mask_idx_.create(1, count_, CV_32FC1); |
||||
} |
||||
ocl::Kernel oclk_max_pool_forward(CL_KERNEL_SELECT("max_pool_forward"), |
||||
cv::ocl::dnn::ocl4dnn_pooling_oclsrc); |
||||
|
||||
if (oclk_max_pool_forward.empty()) |
||||
return false; |
||||
|
||||
argIdx = 0; |
||||
oclk_max_pool_forward.set(argIdx++, count_); |
||||
oclk_max_pool_forward.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom)); |
||||
oclk_max_pool_forward.set(argIdx++, batch_size_); |
||||
oclk_max_pool_forward.set(argIdx++, channels_); |
||||
oclk_max_pool_forward.set(argIdx++, height_); |
||||
oclk_max_pool_forward.set(argIdx++, width_); |
||||
oclk_max_pool_forward.set(argIdx++, pooled_height_); |
||||
oclk_max_pool_forward.set(argIdx++, pooled_width_); |
||||
oclk_max_pool_forward.set(argIdx++, kernel_h_); |
||||
oclk_max_pool_forward.set(argIdx++, kernel_w_); |
||||
oclk_max_pool_forward.set(argIdx++, stride_h_); |
||||
oclk_max_pool_forward.set(argIdx++, stride_w_); |
||||
oclk_max_pool_forward.set(argIdx++, pad_h_); |
||||
oclk_max_pool_forward.set(argIdx++, pad_w_); |
||||
oclk_max_pool_forward.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top)); |
||||
oclk_max_pool_forward.set(argIdx++, mask_idx_.empty() ? 0 : 1); |
||||
if (mask_idx_.empty()) |
||||
oclk_max_pool_forward.set(argIdx++, (void *)NULL); |
||||
else |
||||
oclk_max_pool_forward.set(argIdx++, ocl::KernelArg::PtrWriteOnly(mask_idx_)); |
||||
oclk_max_pool_forward.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top_mask)); |
||||
|
||||
ret = oclk_max_pool_forward.run(1, global, local, false); |
||||
} |
||||
break; |
||||
case LIBDNN_POOLING_METHOD_AVE: |
||||
{ |
||||
ocl::Kernel oclk_ave_pool_forward(CL_KERNEL_SELECT("ave_pool_forward"), |
||||
cv::ocl::dnn::ocl4dnn_pooling_oclsrc); |
||||
|
||||
if (oclk_ave_pool_forward.empty()) |
||||
return false; |
||||
|
||||
argIdx = 0; |
||||
oclk_ave_pool_forward.set(argIdx++, count_); |
||||
oclk_ave_pool_forward.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom)); |
||||
oclk_ave_pool_forward.set(argIdx++, batch_size_); |
||||
oclk_ave_pool_forward.set(argIdx++, channels_); |
||||
oclk_ave_pool_forward.set(argIdx++, height_); |
||||
oclk_ave_pool_forward.set(argIdx++, width_); |
||||
oclk_ave_pool_forward.set(argIdx++, pooled_height_); |
||||
oclk_ave_pool_forward.set(argIdx++, pooled_width_); |
||||
oclk_ave_pool_forward.set(argIdx++, kernel_h_); |
||||
oclk_ave_pool_forward.set(argIdx++, kernel_w_); |
||||
oclk_ave_pool_forward.set(argIdx++, stride_h_); |
||||
oclk_ave_pool_forward.set(argIdx++, stride_w_); |
||||
oclk_ave_pool_forward.set(argIdx++, pad_h_); |
||||
oclk_ave_pool_forward.set(argIdx++, pad_w_); |
||||
oclk_ave_pool_forward.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top)); |
||||
|
||||
ret = oclk_ave_pool_forward.run(1, global, local, false); |
||||
} |
||||
break; |
||||
case LIBDNN_POOLING_METHOD_STO: |
||||
{ |
||||
ocl::Kernel oclk_sto_pool_forward(CL_KERNEL_SELECT("sto_pool_forward_test"), |
||||
cv::ocl::dnn::ocl4dnn_pooling_oclsrc); |
||||
|
||||
if (oclk_sto_pool_forward.empty()) |
||||
return false; |
||||
|
||||
argIdx = 0; |
||||
oclk_sto_pool_forward.set(argIdx++, count_); |
||||
oclk_sto_pool_forward.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom)); |
||||
oclk_sto_pool_forward.set(argIdx++, batch_size_); |
||||
oclk_sto_pool_forward.set(argIdx++, channels_); |
||||
oclk_sto_pool_forward.set(argIdx++, height_); |
||||
oclk_sto_pool_forward.set(argIdx++, width_); |
||||
oclk_sto_pool_forward.set(argIdx++, pooled_height_); |
||||
oclk_sto_pool_forward.set(argIdx++, pooled_width_); |
||||
oclk_sto_pool_forward.set(argIdx++, kernel_h_); |
||||
oclk_sto_pool_forward.set(argIdx++, kernel_w_); |
||||
oclk_sto_pool_forward.set(argIdx++, stride_h_); |
||||
oclk_sto_pool_forward.set(argIdx++, stride_w_); |
||||
oclk_sto_pool_forward.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top)); |
||||
|
||||
ret = oclk_sto_pool_forward.run(1, global, local, false); |
||||
} |
||||
break; |
||||
default: |
||||
{ |
||||
ret = false; |
||||
LOG(FATAL)<< "Unknown pooling method."; |
||||
} |
||||
} |
||||
return ret; |
||||
} |
||||
|
||||
template class OCL4DNNPool<float>; |
||||
} // namespace ocl4dnn
|
||||
} |
||||
} |
||||
#endif // HAVE_OPENCL
|
@ -0,0 +1,135 @@ |
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include <vector> |
||||
#include "common.hpp" |
||||
#include "ocl4dnn.hpp" |
||||
#include "opencl_kernels_dnn.hpp" |
||||
|
||||
#ifdef HAVE_OPENCL |
||||
namespace cv { namespace dnn { namespace ocl4dnn { |
||||
template<typename Dtype> |
||||
OCL4DNNSoftmax<Dtype>::OCL4DNNSoftmax(OCL4DNNSoftmaxConfig config) |
||||
{ |
||||
softmax_axis_ = config.axis; |
||||
channels_ = config.channels; |
||||
|
||||
inner_num_ = 1; |
||||
outer_num_ = 1; |
||||
count_ = 1; |
||||
int32_t scale_sz = 1; |
||||
for (int32_t i = softmax_axis_ + 1; i < config.in_shape.size(); i++) |
||||
inner_num_ *= config.in_shape[i]; |
||||
use_slm_ = (config.in_shape[softmax_axis_] * inner_num_ + inner_num_ * 17) <= 8192; |
||||
for (int32_t i = 0; i < softmax_axis_; i++) |
||||
outer_num_ *= config.in_shape[i]; |
||||
count_ = inner_num_ + outer_num_; |
||||
|
||||
std::vector<int32_t> scale_dims = config.in_shape; |
||||
scale_dims[softmax_axis_] = use_slm_ ? 1 : 17; |
||||
for (int32_t i = 0; i < scale_dims.size(); i++) |
||||
scale_sz *= scale_dims[i]; |
||||
|
||||
scale_data_.create(1, scale_sz, CV_32FC1); |
||||
} |
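An illustrative decomposition of the shape handling above, assuming an NCHW input of 2 x 10 x 4 x 4 with softmax_axis_ == 1 (the values are an example, not taken from the patch):

    // outer_num_ = 2            product of dimensions before the axis
    // channels_  = 10           softmax is taken over this axis
    // inner_num_ = 4 * 4 = 16   product of dimensions after the axis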
||||
|
||||
template<typename Dtype> |
||||
OCL4DNNSoftmax<Dtype>::~OCL4DNNSoftmax() |
||||
{ |
||||
scale_data_.release(); |
||||
} |
||||
|
||||
template<typename Dtype> |
||||
bool OCL4DNNSoftmax<Dtype>::Forward(const UMat& bottom, UMat& top) |
||||
{ |
||||
bool ret = false; |
||||
ocl::Queue queue = ocl::Queue::getDefault(); |
||||
bool intel_subgroup = ocl::Device::getDefault().intelSubgroupsSupport(); |
||||
if (intel_subgroup && inner_num_ < 128) |
||||
{ |
||||
String opts = clOptionSupport("-cl-no-subgroup-ifp") ? " -cl-no-subgroup-ifp " : ""; |
||||
String kname; |
||||
ocl::Kernel oclk_softmax_forward_kernel; |
||||
|
||||
if (use_slm_) |
||||
kname = CL_KERNEL_SELECT("softmax_forward_slm"); |
||||
else |
||||
kname = CL_KERNEL_SELECT("softmax_forward"); |
||||
|
||||
if (!oclk_softmax_forward_kernel.create(kname.c_str(), ocl::dnn::softmax_loss_oclsrc, opts)) |
||||
return false; |
||||
|
||||
size_t global_size[] = { 256, (size_t)outer_num_, 1 }; |
||||
size_t local_size[] = { 256, 1, 1 }; |
||||
cl_uint argIdx = 0; |
||||
|
||||
if (use_slm_) |
||||
{ |
||||
oclk_softmax_forward_kernel.set(argIdx++, outer_num_); |
||||
oclk_softmax_forward_kernel.set(argIdx++, channels_); |
||||
oclk_softmax_forward_kernel.set(argIdx++, inner_num_); |
||||
oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrWriteOnly(scale_data_)); |
||||
oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom)); |
||||
oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top)); |
||||
oclk_softmax_forward_kernel.set(argIdx++, NULL, channels_ * inner_num_* sizeof(Dtype)); |
||||
oclk_softmax_forward_kernel.set(argIdx++, NULL, inner_num_* sizeof(Dtype)); |
||||
oclk_softmax_forward_kernel.set(argIdx++, NULL, 16 * inner_num_* sizeof(Dtype)); |
||||
} |
||||
else |
||||
{ |
||||
oclk_softmax_forward_kernel.set(argIdx++, outer_num_); |
||||
oclk_softmax_forward_kernel.set(argIdx++, channels_); |
||||
oclk_softmax_forward_kernel.set(argIdx++, inner_num_); |
||||
oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrWriteOnly(scale_data_)); |
||||
oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom)); |
||||
oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top)); |
||||
} |
||||
ret = oclk_softmax_forward_kernel.run(3, global_size, local_size, false); |
||||
} |
||||
return ret; |
||||
} |
||||
|
||||
template class OCL4DNNSoftmax<float>; |
||||
} // namespace ocl4dnn
|
||||
} |
||||
} |
||||
#endif // HAVE_OPENCL
|
@ -0,0 +1,26 @@ |
||||
|
||||
__kernel void batchnorm(__global const T *src, int src_offset,
                        __global const float *meanMat,
                        float varMeanScale,
                        __global const float *invStdMat,
                        __global const float *weight,
                        __global const float *bias,
                        int hasWeight, int hasBias,
                        int width, int height, int channel,
                        __global T *dst, int dst_offset)
{
    int x = get_global_id(0);
    int y = get_global_id(1);
    int c = get_global_id(2);

    if (x >= width || y >= height || c >= channel)
        return;

    float mean = meanMat[c] * varMeanScale;
    float invstd = invStdMat[c];
    float w = hasWeight ? weight[c] : 1;
    float b = hasBias ? bias[c] : 0;
    int index = y * width + x + c * width * height;
    T val = (src[index + src_offset] - mean) * w * invstd + b;
    dst[index + dst_offset] = val;
}
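The kernel above applies the standard inference-time batch-norm transform per channel c, with weight defaulting to 1 and bias to 0 when the corresponding flags are unset:

    dst[i] = (src[i] - meanMat[c] * varMeanScale) * weight[c] * invStdMat[c] + bias[c]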
@ -0,0 +1,45 @@ |
||||
/*M/////////////////////////////////////////////////////////////////////////////////////// |
||||
// |
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
||||
// |
||||
// By downloading, copying, installing or using the software you agree to this license. |
||||
// If you do not agree to this license, do not download, install, |
||||
// copy or use the software. |
||||
// |
||||
// |
||||
// License Agreement |
||||
// For Open Source Computer Vision Library |
||||
// |
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved. |
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved. |
||||
// Third party copyrights are property of their respective owners. |
||||
// |
||||
// Redistribution and use in source and binary forms, with or without modification, |
||||
// are permitted provided that the following conditions are met: |
||||
// |
||||
// * Redistribution's of source code must retain the above copyright notice, |
||||
// this list of conditions and the following disclaimer. |
||||
// |
||||
// * Redistribution's in binary form must reproduce the above copyright notice, |
||||
// this list of conditions and the following disclaimer in the documentation |
||||
// and/or other materials provided with the distribution. |
||||
// |
||||
// * The name of the copyright holders may not be used to endorse or promote products |
||||
// derived from this software without specific prior written permission. |
||||
// |
||||
// This software is provided by the copyright holders and contributors "as is" and |
||||
// any express or implied warranties, including, but not limited to, the implied |
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
||||
// In no event shall the Intel Corporation or contributors be liable for any direct, |
||||
// indirect, incidental, special, exemplary, or consequential damages |
||||
// (including, but not limited to, procurement of substitute goods or services; |
||||
// loss of use, data, or profits; or business interruption) however caused |
||||
// and on any theory of liability, whether in contract, strict liability, |
||||
// or tort (including negligence or otherwise) arising in any way out of |
||||
// the use of this software, even if advised of the possibility of such damage. |
||||
// |
||||
//M*/ |
||||
|
||||
__kernel void null_kernel_float(float arg) { |
||||
float out = arg; |
||||
} |
@ -0,0 +1,60 @@ |
||||
/*M/////////////////////////////////////////////////////////////////////////////////////// |
||||
// |
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
||||
// |
||||
// By downloading, copying, installing or using the software you agree to this license. |
||||
// If you do not agree to this license, do not download, install, |
||||
// copy or use the software. |
||||
// |
||||
// |
||||
// License Agreement |
||||
// For Open Source Computer Vision Library |
||||
// |
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved. |
||||
// Third party copyrights are property of their respective owners. |
||||
// |
||||
// Redistribution and use in source and binary forms, with or without modification, |
||||
// are permitted provided that the following conditions are met: |
||||
// |
||||
// * Redistribution's of source code must retain the above copyright notice, |
||||
// this list of conditions and the following disclaimer. |
||||
// |
||||
// * Redistribution's in binary form must reproduce the above copyright notice, |
||||
// this list of conditions and the following disclaimer in the documentation |
||||
// and/or other materials provided with the distribution. |
||||
// |
||||
// * The name of the copyright holders may not be used to endorse or promote products |
||||
// derived from this software without specific prior written permission. |
||||
// |
||||
// This software is provided by the copyright holders and contributors "as is" and |
||||
// any express or implied warranties, including, but not limited to, the implied |
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
||||
// In no event shall the Intel Corporation or contributors be liable for any direct, |
||||
// indirect, incidental, special, exemplary, or consequential damages |
||||
// (including, but not limited to, procurement of substitute goods or services; |
||||
// loss of use, data, or profits; or business interruption) however caused |
||||
// and on any theory of liability, whether in contract, strict liability, |
||||
// or tort (including negligence or otherwise) arising in any way out of |
||||
// the use of this software, even if advised of the possibility of such damage. |
||||
// |
||||
//M*/ |
||||
|
||||
__kernel void concat(const int nthreads,
                     __global const Dtype* in_data,
                     const int num_concats,
                     const int concat_size,
                     const int top_concat_axis,
                     const int bottom_concat_axis,
                     const int offset_concat_axis,
                     __global Dtype* out_data)
{
    for (int index = get_global_id(0); index < nthreads;
         index += get_global_size(0))
    {
        const int total_concat_size = concat_size * bottom_concat_axis;
        const int concat_num = index / total_concat_size;
        const int concat_index = index % total_concat_size;
        const int top_index = concat_index
            + (concat_num * top_concat_axis + offset_concat_axis) * concat_size;
        out_data[top_index] = in_data[index];
    }
}
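A worked example of the index mapping above, with hypothetical sizes: concatenating a 3-channel input into an 8-channel output along the channel axis, with 2x2 spatial size.

    // concat_size        = 2 * 2 = 4
    // bottom_concat_axis = 3, top_concat_axis = 8, offset_concat_axis = 2
    // total_concat_size  = 4 * 3 = 12
    // Input element index 5 (channel 1, spatial position 1):
    //   concat_num = 5 / 12 = 0, concat_index = 5 % 12 = 5
    //   top_index  = 5 + (0 * 8 + 2) * 4 = 13   (output channel 3, spatial position 1)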
File diff suppressed because it is too large
@ -0,0 +1,73 @@ |
||||
/*M/////////////////////////////////////////////////////////////////////////////////////// |
||||
// |
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
||||
// |
||||
// By downloading, copying, installing or using the software you agree to this license. |
||||
// If you do not agree to this license, do not download, install, |
||||
// copy or use the software. |
||||
// |
||||
// |
||||
// License Agreement |
||||
// For Open Source Computer Vision Library |
||||
// |
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved. |
||||
// Third party copyrights are property of their respective owners. |
||||
// |
||||
// Redistribution and use in source and binary forms, with or without modification, |
||||
// are permitted provided that the following conditions are met: |
||||
// |
||||
// * Redistribution's of source code must retain the above copyright notice, |
||||
// this list of conditions and the following disclaimer. |
||||
// |
||||
// * Redistribution's in binary form must reproduce the above copyright notice, |
||||
// this list of conditions and the following disclaimer in the documentation |
||||
// and/or other materials provided with the distribution. |
||||
// |
||||
// * The name of the copyright holders may not be used to endorse or promote products |
||||
// derived from this software without specific prior written permission. |
||||
// |
||||
// This software is provided by the copyright holders and contributors "as is" and |
||||
// any express or implied warranties, including, but not limited to, the implied |
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
||||
// In no event shall the Intel Corporation or contributors be liable for any direct, |
||||
// indirect, incidental, special, exemplary, or consequential damages |
||||
// (including, but not limited to, procurement of substitute goods or services; |
||||
// loss of use, data, or profits; or business interruption) however caused |
||||
// and on any theory of liability, whether in contract, strict liability, |
||||
// or tort (including negligence or otherwise) arising in any way out of |
||||
// the use of this software, even if advised of the possibility of such damage. |
||||
// |
||||
//M*/ |
||||
|
||||
#define CONCAT(A,B) A##_##B |
||||
#define TEMPLATE(name,type) CONCAT(name,type) |
||||
#define Dtype float |
||||
|
||||
__kernel void TEMPLATE(copyWeightsSwizzled, Dtype) |
||||
(__global Dtype* weightIn, |
||||
__global Dtype* weightOut, |
||||
const int kernel_w, |
||||
const int kernel_h, |
||||
const int channels, |
||||
const int outputs, |
||||
const int swizzleFactor) { |
||||
|
||||
unsigned int sX = get_global_id(0); |
||||
|
||||
//Original location |
||||
|
||||
//Output location |
||||
int outputSublayer = channels / swizzleFactor; |
||||
int outputSublayerIndex = channels % swizzleFactor; |
||||
|
||||
int filter = sX / (kernel_w*kernel_h*channels); |
||||
int kernel_X = sX % kernel_w; |
||||
int kernel_Y = (sX / kernel_w) % kernel_h; |
||||
int kernel_C = (sX / (kernel_w * kernel_h)) % channels; |
||||
|
||||
int FP = filter / swizzleFactor; |
||||
int F1 = filter % swizzleFactor; |
||||
|
||||
weightOut[FP*(kernel_w*kernel_h*channels*swizzleFactor) + kernel_C*(kernel_w*kernel_h*swizzleFactor) + kernel_Y*(kernel_w*swizzleFactor) + kernel_X*swizzleFactor + F1] |
||||
= weightIn[filter*(kernel_w*kernel_h*channels) + kernel_C*(kernel_w*kernel_h) + kernel_Y*kernel_w + kernel_X]; |
||||
} |
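The swizzle above regroups the OIHW weight layout so that swizzleFactor consecutive output filters become the innermost dimension, matching the sub-group width used by the GEMM-like convolution kernels. As a small illustration (assuming swizzleFactor == 8), the first eight entries of weightOut are the (c = 0, y = 0, x = 0) weights of filters 0..7:

    // filter f in [0, 8), kernel_C = kernel_Y = kernel_X = 0:
    //   FP = 0, F1 = f  ->  weightOut[f] = weightIn[f * kernel_w * kernel_h * channels]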
@ -0,0 +1,43 @@ |
||||
/*M/////////////////////////////////////////////////////////////////////////////////////// |
||||
// |
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
||||
// |
||||
// By downloading, copying, installing or using the software you agree to this license. |
||||
// If you do not agree to this license, do not download, install, |
||||
// copy or use the software. |
||||
// |
||||
// |
||||
// License Agreement |
||||
// For Open Source Computer Vision Library |
||||
// |
||||
// Third party copyrights are property of their respective owners. |
||||
// |
||||
// Redistribution and use in source and binary forms, with or without modification, |
||||
// are permitted provided that the following conditions are met: |
||||
// |
||||
// * Redistribution's of source code must retain the above copyright notice, |
||||
// this list of conditions and the following disclaimer. |
||||
// |
||||
// * Redistribution's in binary form must reproduce the above copyright notice, |
||||
// this list of conditions and the following disclaimer in the documentation |
||||
// and/or other materials provided with the distribution. |
||||
// |
||||
// * The name of the copyright holders may not be used to endorse or promote products |
||||
// derived from this software without specific prior written permission. |
||||
// |
||||
// This software is provided by the copyright holders and contributors "as is" and |
||||
// any express or implied warranties, including, but not limited to, the implied |
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
||||
// In no event shall the Intel Corporation or contributors be liable for any direct, |
||||
// indirect, incidental, special, exemplary, or consequential damages |
||||
// (including, but not limited to, procurement of substitute goods or services; |
||||
// loss of use, data, or profits; or business interruption) however caused |
||||
// and on any theory of liability, whether in contract, strict liability, |
||||
// or tort (including negligence or otherwise) arising in any way out of |
||||
// the use of this software, even if advised of the possibility of such damage. |
||||
// |
||||
//M*/ |
||||
|
||||
__kernel void dummy_kernel() |
||||
{ |
||||
} |
@ -0,0 +1,635 @@ |
||||
/*M/////////////////////////////////////////////////////////////////////////////////////// |
||||
// |
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
||||
// |
||||
// By downloading, copying, installing or using the software you agree to this license. |
||||
// If you do not agree to this license, do not download, install, |
||||
// copy or use the software. |
||||
// |
||||
// |
||||
// License Agreement |
||||
// For Open Source Computer Vision Library |
||||
// |
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved. |
||||
// Third party copyrights are property of their respective owners. |
||||
// |
||||
// Redistribution and use in source and binary forms, with or without modification, |
||||
// are permitted provided that the following conditions are met: |
||||
// |
||||
// * Redistribution's of source code must retain the above copyright notice, |
||||
// this list of conditions and the following disclaimer. |
||||
// |
||||
// * Redistribution's in binary form must reproduce the above copyright notice, |
||||
// this list of conditions and the following disclaimer in the documentation |
||||
// and/or other materials provided with the distribution. |
||||
// |
||||
// * The name of the copyright holders may not be used to endorse or promote products |
||||
// derived from this software without specific prior written permission. |
||||
// |
||||
// This software is provided by the copyright holders and contributors "as is" and |
||||
// any express or implied warranties, including, but not limited to, the implied |
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
||||
// In no event shall the Intel Corporation or contributors be liable for any direct, |
||||
// indirect, incidental, special, exemplary, or consequential damages |
||||
// (including, but not limited to, procurement of substitute goods or services; |
||||
// loss of use, data, or profits; or business interruption) however caused |
||||
// and on any theory of liability, whether in contract, strict liability, |
||||
// or tort (including negligence or otherwise) arising in any way out of |
||||
// the use of this software, even if advised of the possibility of such damage. |
||||
// |
||||
//M*/ |
||||
|
||||
#define CONCAT(A,B) A##_##B |
||||
#define TEMPLATE(name,type) CONCAT(name,type) |
||||
|
||||
// Types used for parameters, offset computations and so on |
||||
#define int_tp int |
||||
#define uint_tp unsigned int |
||||
|
||||
#define Dtype float |
||||
#define Dtype2 float2 |
||||
#define Dtype4 float4 |
||||
#define Dtype8 float8 |
||||
|
||||
#define as_Dtype as_float |
||||
#define as_Dtype2 as_float2 |
||||
#define as_Dtype4 as_float4 |
||||
#define as_Dtype8 as_float8 |
||||
|
||||
#define KERNEL_ARG_DTYPE float |
||||
|
||||
#if defined(cl_intel_subgroups) |
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable |
||||
#endif |
||||
|
||||
#define TILE_M 32 |
||||
#define TILE_K 8 |
||||
|
||||
// common block to calculate (alpha * AxB + beta * C) and output to destination image. |
||||
|
||||
#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read8( __image, __coord ) |
||||
#define SHUFFLE_TYPE2(val) val |
||||
#define SHUFFLE_TYPE8(val) val |
||||
#define READ_IMAGE(__image, __coord) read_imagef(__image, sampler, __coord) |
||||
#define SIZE_OF_ELEMENT sizeof(uint) |
||||
#define SIMD_SIZE_GEMM 8 |
||||
#define TILE_N 8 |
||||
|
||||
//#define USE_IMAGE_C |
||||
#ifdef USE_IMAGE_C |
||||
#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read8( _C, _coordC ) ) |
||||
#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) ) |
||||
#define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst |
||||
#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint)) |
||||
#else |
||||
#define BLOCKC_READ8( _C, _coordC ) \ |
||||
(Dtype8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] : 0, \ |
||||
(_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ |
||||
(_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ |
||||
(_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ |
||||
(_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? _C[ ( _coordC.y + 4 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ |
||||
(_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ |
||||
(_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \ |
||||
(_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? _C[ ( _coordC.y + 7 ) * ldc + _coordC.x + get_local_id(0) ] : 0) |
||||
|
||||
#define BLOCKC_WRITE8( _C, _coordC, _val) do {\ |
||||
if (_coordC.x + get_local_id(0) < N) { \ |
||||
if (_coordC.y < M) \ |
||||
_C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] = _val.s0; \ |
||||
if (_coordC.y + 1 < M) \ |
||||
_C[ ( _coordC.y + 1 )* ldc + _coordC.x + get_local_id(0) ] = _val.s1; \ |
||||
if (_coordC.y + 2 < M) \ |
||||
_C[ ( _coordC.y + 2 )* ldc + _coordC.x + get_local_id(0) ] = _val.s2; \ |
||||
if (_coordC.y + 3 < M) \ |
||||
_C[ ( _coordC.y + 3 )* ldc + _coordC.x + get_local_id(0) ] = _val.s3; \ |
||||
if (_coordC.y + 4 < M) \ |
||||
_C[ ( _coordC.y + 4 )* ldc + _coordC.x + get_local_id(0) ] = _val.s4; \ |
||||
if (_coordC.y + 5 < M) \ |
||||
_C[ ( _coordC.y + 5 )* ldc + _coordC.x + get_local_id(0) ] = _val.s5; \ |
||||
if (_coordC.y + 6 < M) \ |
||||
_C[ ( _coordC.y + 6 )* ldc + _coordC.x + get_local_id(0) ] = _val.s6; \ |
||||
if (_coordC.y + 7 < M) \ |
||||
_C[ ( _coordC.y + 7 )* ldc + _coordC.x + get_local_id(0) ] = _val.s7; \ |
||||
}} while(0) |
||||
#define MATC_PARAMETER __global Dtype * C, const int offC, const int M, const int N, const int ldc |
||||
#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, (C + offC), (C + offC), 1) |
||||
#endif |
||||
|
||||
#define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) \ |
||||
int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); \ |
||||
int2 coordC = coordDst; \ |
||||
Dtype8 blockC00; \ |
||||
Dtype8 blockC01; \ |
||||
Dtype8 blockC02; \ |
||||
Dtype8 blockC03; \ |
||||
if (BETA_NOT0) { \ |
||||
blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ |
||||
blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ |
||||
blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ |
||||
blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); \ |
||||
if (!ALPHA1) { \ |
||||
blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \ |
||||
blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \ |
||||
blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); \ |
||||
blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); \ |
||||
} else { \ |
||||
blockC00 += blockAxB00; \ |
||||
blockC01 += blockAxB01; \ |
||||
blockC02 += blockAxB02; \ |
||||
blockC03 += blockAxB03; \ |
||||
} \ |
||||
} else { \ |
||||
blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ |
||||
blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ |
||||
blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \ |
||||
blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); \ |
||||
if (!ALPHA1) { \ |
||||
blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \ |
||||
blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \ |
||||
blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); \ |
||||
blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); \ |
||||
} else { \ |
||||
blockC00 += blockAxB00; \ |
||||
blockC01 += blockAxB01; \ |
||||
blockC02 += blockAxB02; \ |
||||
blockC03 += blockAxB03; \ |
||||
} \ |
||||
} \ |
||||
BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; \ |
||||
BLOCKC_WRITE8( _dst, coordDst, blockC01 ); coordDst.y += 8; \ |
||||
BLOCKC_WRITE8( _dst, coordDst, blockC02 ); coordDst.y += 8; \ |
||||
BLOCKC_WRITE8( _dst, coordDst, blockC03 ); |
||||
|
||||
// Get the specified column of the block of the block |
||||
#define TRANSPOSE_BLOCK_8( _block, _col ) \ |
||||
(Dtype8)( intel_sub_group_shuffle( _block.s0, _col ), \ |
||||
intel_sub_group_shuffle( _block.s1, _col ), \ |
||||
intel_sub_group_shuffle( _block.s2, _col ), \ |
||||
intel_sub_group_shuffle( _block.s3, _col ), \ |
||||
intel_sub_group_shuffle( _block.s4, _col ), \ |
||||
intel_sub_group_shuffle( _block.s5, _col ), \ |
||||
intel_sub_group_shuffle( _block.s6, _col ), \ |
||||
intel_sub_group_shuffle( _block.s7, _col ) ); |
||||
|
||||
// Multiply A's column block by B's row block.
||||
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ |
||||
{ \ |
||||
const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ |
||||
const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ |
||||
const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ |
||||
const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ |
||||
const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ |
||||
const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ |
||||
const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ |
||||
const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ |
||||
_result = mad( (Dtype8)(_blockB.s0), acol0, _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s1), acol1, _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s2), acol2, _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s3), acol3, _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s4), acol4, _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s5), acol5, _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s6), acol6, _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s7), acol7, _result ); \ |
||||
} |
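Conceptually, one invocation of MULTIPLY_BLOCKS_8x8 accumulates the product of an 8x8 tile of A with an 8x8 tile of B into the 8x8 accumulator that is spread across the 8-wide sub-group. A scalar reference sketch of the same update (the real kernel keeps each tile distributed over the sub-group via the shuffles above):

    for (int i = 0; i < 8; ++i)          // row of the A tile / result tile
        for (int j = 0; j < 8; ++j)      // column of the B tile / result tile
            for (int k = 0; k < 8; ++k)  // reduction dimension (TILE_K)
                result[i][j] += A_tile[i][k] * B_tile[k][j];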
||||
|
||||
#define GEMM_NN(ALPHA1, BETA_NOT0) \ |
||||
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ |
||||
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ |
||||
__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \ |
||||
__read_only image2d_t A, \ |
||||
__read_only image2d_t B, \ |
||||
MATC_PARAMETER, \ |
||||
KERNEL_ARG_DTYPE alpha_in, \ |
||||
KERNEL_ARG_DTYPE beta_in, \ |
||||
int width0, \ |
||||
int isFirstColBlock) \ |
||||
{ \ |
||||
const Dtype alpha = (Dtype)alpha_in; \ |
||||
const Dtype beta = (Dtype)beta_in; \ |
||||
const int group_x = get_group_id(0); \ |
||||
const int group_y = get_group_id(1); \ |
||||
Dtype8 blockAxB00 = 0.0f; \ |
||||
Dtype8 blockAxB01 = 0.0f; \ |
||||
Dtype8 blockAxB02 = 0.0f; \ |
||||
Dtype8 blockAxB03 = 0.0f; \ |
||||
int2 coordA = (int2)( 0, group_y * TILE_M ); \ |
||||
int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \ |
||||
do \ |
||||
{ \ |
||||
int2 coordBTemp = coordB; \ |
||||
Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \ |
||||
int2 coordATemp = coordA; \ |
||||
Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ |
||||
Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ |
||||
Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ |
||||
Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \ |
||||
} \ |
||||
while( coordB.y < width0 ); \ |
||||
GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ |
||||
} |
||||
|
||||
GEMM_NN(1, 0) // ALPHA == 1, BETA == 0 |
||||
GEMM_NN(1, 1) // ALPHA == 1, BETA != 0 |
||||
GEMM_NN(0, 0) // ALPHA != 1, BETA == 0 |
||||
GEMM_NN(0, 1) // ALPHA != 1, BETA != 0 |
||||
|
||||
#undef TRANSPOSE_BLOCK_8 |
||||
#undef MULTIPLY_BLOCKS_8x8 |
||||
#undef GEMM_NN |
||||
|
||||
// replicate the first row to column block. |
||||
#define TRANSPOSE_BLOCK_8(_vec, _col) \ |
||||
(Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), \ |
||||
intel_sub_group_shuffle(_vec, _col + 1), \ |
||||
intel_sub_group_shuffle(_vec, _col + 2), \ |
||||
intel_sub_group_shuffle(_vec, _col + 3), \ |
||||
intel_sub_group_shuffle(_vec, _col + 4), \ |
||||
intel_sub_group_shuffle(_vec, _col + 5), \ |
||||
intel_sub_group_shuffle(_vec, _col + 6), \ |
||||
intel_sub_group_shuffle(_vec, _col + 7) ) |
||||
|
||||
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \ |
||||
{ \ |
||||
_result = mad( (Dtype8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0, _col), _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1, _col), _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2, _col), _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3, _col), _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4, _col), _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5, _col), _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6, _col), _result ); \ |
||||
_result = mad( (Dtype8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result ); \ |
||||
} |
||||
|
||||
#define GEMM_TN(ALPHA1, BETA_NOT0) \ |
||||
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ |
||||
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ |
||||
__kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ |
||||
__read_only image2d_t A, \ |
||||
__read_only image2d_t B, \ |
||||
MATC_PARAMETER, \ |
||||
KERNEL_ARG_DTYPE alpha_in, \ |
||||
KERNEL_ARG_DTYPE beta_in, \ |
||||
int width0, \ |
||||
int isFirstColBlock) \ |
||||
{ \ |
||||
const Dtype alpha = (Dtype)alpha_in; \ |
||||
const Dtype beta = (Dtype)beta_in; \ |
||||
const int group_x = get_group_id(0);\ |
||||
const int group_y = get_group_id(1);\ |
||||
Dtype8 blockAxB00 = 0.0f;\ |
||||
Dtype8 blockAxB01 = 0.0f;\ |
||||
Dtype8 blockAxB02 = 0.0f;\ |
||||
Dtype8 blockAxB03 = 0.0f;\ |
||||
int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 );\ |
||||
int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 );\ |
||||
do\ |
||||
{\ |
||||
int2 coordBTemp = coordB;\ |
||||
Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\ |
||||
int2 coordATemp = coordA;\ |
||||
Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ |
||||
Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ |
||||
Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\ |
||||
Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, 0 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, 0 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, 0 ); \ |
||||
} \ |
||||
while( coordB.y < width0 ); \ |
||||
GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ |
||||
} |
||||
|
||||
GEMM_TN(1, 0) // ALPHA == 1, BETA == 0 |
||||
GEMM_TN(1, 1) // ALPHA == 1, BETA != 0 |
||||
GEMM_TN(0, 0) // ALPHA != 1, BETA == 0 |
||||
GEMM_TN(0, 1) // ALPHA != 1, BETA != 0 |
||||
|
||||
#undef MULTIPLY_BLOCKS_8x8 |
||||
#undef TRANSPOSE_BLOCK_8 |
||||
#undef GEMM_TN |
||||
|
||||
// The same block-transpose and multiply helpers as used for GEMM_NN. |
||||
#define TRANSPOSE_BLOCK_8( _block, _col ) \ |
||||
(Dtype8)( intel_sub_group_shuffle( _block.s0, _col), \ |
||||
intel_sub_group_shuffle( _block.s1, _col), \ |
||||
intel_sub_group_shuffle( _block.s2, _col), \ |
||||
intel_sub_group_shuffle( _block.s3, _col), \ |
||||
intel_sub_group_shuffle( _block.s4, _col), \ |
||||
intel_sub_group_shuffle( _block.s5, _col), \ |
||||
intel_sub_group_shuffle( _block.s6, _col), \ |
||||
intel_sub_group_shuffle( _block.s7, _col) ) |
||||
|
||||
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \ |
||||
{ \ |
||||
const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \ |
||||
const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \ |
||||
const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \ |
||||
const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \ |
||||
const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \ |
||||
const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \ |
||||
const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \ |
||||
const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \ |
||||
_result = mad( (Dtype8)_blockB.s0, acol0, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s1, acol1, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s2, acol2, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s3, acol3, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s4, acol4, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s5, acol5, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s6, acol6, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s7, acol7, _result ); \ |
||||
} |
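// GEMM_NT: image-based GEMM variant with B transposed (op(B) = B^T). B is
// fetched through the BLOCKB_READ8/MATB_PARAMETER macros defined below, so
// the same kernel body is instantiated for several B access paths.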
||||
|
||||
#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \ |
||||
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ |
||||
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ |
||||
__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \ |
||||
__read_only image2d_t A, \ |
||||
MATB_PARAMETER, \ |
||||
MATC_PARAMETER, \ |
||||
KERNEL_ARG_DTYPE alpha_in, \ |
||||
KERNEL_ARG_DTYPE beta_in, \ |
||||
int padded_k, \ |
||||
int k, \ |
||||
int isFirstColBlock) \ |
||||
{ \ |
||||
const Dtype alpha = (Dtype)alpha_in; \ |
||||
const Dtype beta = (Dtype)beta_in; \ |
||||
const int group_x = get_group_id(0); \ |
||||
const int group_y = get_group_id(1); \ |
||||
Dtype8 blockAxB00 = 0.0f; \ |
||||
Dtype8 blockAxB01 = 0.0f; \ |
||||
Dtype8 blockAxB02 = 0.0f; \ |
||||
Dtype8 blockAxB03 = 0.0f; \ |
||||
int2 coordA = (int2)( 0, group_y * TILE_M ); \ |
||||
int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ |
||||
const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ |
||||
do \ |
||||
{ \ |
||||
Dtype8 blockB00; \ |
||||
BLOCKB_READ8(blockB00, B, coordB); \ |
||||
int2 coordATemp = coordA; \ |
||||
Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ |
||||
Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ |
||||
Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \ |
||||
Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \ |
||||
} \ |
||||
while( coordB.x < padded_k / VECSIZE ); \ |
||||
GEMM_OUTPUT(ALPHA1, BETA_NOT0); \ |
||||
} |
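// Three B-access specializations of GEMM_NT follow: a vec4 image2d read
// (VEC4), a plain buffer read via vload8 (BUFFER), and a scalar image read
// (SCALAR). Each redefines BLOCKB_READ8 and MATB_PARAMETER before
// instantiating the four alpha/beta cases.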
||||
|
||||
#define BLOCKB_READ8(_blockb, _B, _coordB) \ |
||||
int2 _coordBTemp = _coordB; \ |
||||
_coordBTemp.y += get_local_id(0); \ |
||||
_blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2; |
||||
|
||||
#define MATB_PARAMETER __read_only image2d_t B |
||||
|
||||
GEMM_NT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0 |
||||
GEMM_NT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0 |
||||
GEMM_NT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0 |
||||
GEMM_NT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0 |
||||
#undef BLOCKB_READ8 |
||||
#undef MATB_PARAMETER |
||||
|
||||
#define BLOCKB_READ8(_blockb, _B, _coordB) \ |
||||
int2 _coordBTemp = _coordB; \ |
||||
_coordBTemp.y += get_local_id(0); \ |
||||
const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \ |
||||
_blockb = vload8(0, B_read); \ |
||||
_coordB.x += TILE_K; |
||||
|
||||
#define MATB_PARAMETER __global Dtype *B, int offB, int ldb |
||||
|
||||
GEMM_NT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0 |
||||
GEMM_NT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0 |
||||
GEMM_NT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0 |
||||
GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0 |
||||
#undef BLOCKB_READ8 |
||||
#undef MATB_PARAMETER |
||||
|
||||
#define BLOCKB_READ8(_blockb, _B, _coordB) \ |
||||
int2 _coordBTemp = _coordB; \ |
||||
_coordBTemp.y += get_local_id(0); \ |
||||
Dtype4 temp; \ |
||||
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s0 = temp.s0; \ |
||||
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s1 = temp.s0; \ |
||||
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s2 = temp.s0; \ |
||||
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s3 = temp.s0; \ |
||||
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s4 = temp.s0; \ |
||||
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s5 = temp.s0; \ |
||||
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s6 = temp.s0; \ |
||||
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s7 = temp.s0; \ |
||||
_coordB.x += 8; |
||||
|
||||
#define MATB_PARAMETER __read_only image2d_t B |
||||
|
||||
GEMM_NT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0 |
||||
GEMM_NT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0 |
||||
GEMM_NT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0 |
||||
GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0 |
||||
#undef BLOCKB_READ8 |
||||
#undef MATB_PARAMETER |
||||
|
||||
#undef MULTIPLY_BLOCKS_8x8 |
||||
#undef TRANSPOSE_BLOCK_8 |
||||
#undef GEMM_NT |
||||
|
||||
// The same block-transpose and multiply helpers as used for GEMM_TN. |
||||
#define TRANSPOSE_BLOCK_8(_vec, _col) \ |
||||
(Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), \ |
||||
intel_sub_group_shuffle(_vec, _col + 1), \ |
||||
intel_sub_group_shuffle(_vec, _col + 2), \ |
||||
intel_sub_group_shuffle(_vec, _col + 3), \ |
||||
intel_sub_group_shuffle(_vec, _col + 4), \ |
||||
intel_sub_group_shuffle(_vec, _col + 5), \ |
||||
intel_sub_group_shuffle(_vec, _col + 6), \ |
||||
intel_sub_group_shuffle(_vec, _col + 7) )
||||
|
||||
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \ |
||||
{ \ |
||||
const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0, _col ); \ |
||||
const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1, _col ); \ |
||||
const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2, _col ); \ |
||||
const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3, _col ); \ |
||||
const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4, _col ); \ |
||||
const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5, _col ); \ |
||||
const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6, _col ); \ |
||||
const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7, _col ); \ |
||||
_result = mad( (Dtype8)_blockB.s0, acol0, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s1, acol1, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s2, acol2, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s3, acol3, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s4, acol4, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s5, acol5, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s6, acol6, _result ); \ |
||||
_result = mad( (Dtype8)_blockB.s7, acol7, _result ); \ |
||||
} |
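// GEMM_TT: image-based GEMM variant with both A and B transposed. As with
// GEMM_NT, the BLOCKB_READ8/MATB_PARAMETER definitions below select the B
// access path (vec4 image, buffer, or scalar image) before instantiation.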
||||
|
||||
#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \ |
||||
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \ |
||||
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \ |
||||
__kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \ |
||||
__read_only image2d_t A, \ |
||||
MATB_PARAMETER, \ |
||||
MATC_PARAMETER, \ |
||||
KERNEL_ARG_DTYPE alpha_in, \ |
||||
KERNEL_ARG_DTYPE beta_in, \ |
||||
int padded_k, \ |
||||
int k, \ |
||||
int isFirstColBlock) \ |
||||
{ \ |
||||
const Dtype alpha = (Dtype)alpha_in; \ |
||||
const Dtype beta = (Dtype)beta_in; \ |
||||
const int group_x = get_group_id(0); \ |
||||
const int group_y = get_group_id(1); \ |
||||
Dtype8 blockAxB00 = 0.0f; \ |
||||
Dtype8 blockAxB01 = 0.0f; \ |
||||
Dtype8 blockAxB02 = 0.0f; \ |
||||
Dtype8 blockAxB03 = 0.0f; \ |
||||
int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \ |
||||
int2 coordB = (int2)( 0, ( group_x * TILE_N )); \ |
||||
const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \ |
||||
do \ |
||||
{ \ |
||||
Dtype8 blockB00; \ |
||||
BLOCKB_READ8(blockB00, B, coordB); \ |
||||
int2 coordATemp = coordA; \ |
||||
Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ |
||||
Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ |
||||
Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \ |
||||
Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00, 0 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00, 0 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00, 0 ); \ |
||||
MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00, 0 ); \ |
||||
} \ |
||||
while( coordB.x < padded_k / VECSIZE ); \ |
||||
GEMM_OUTPUT(ALPHA1, BETA_NOT0);\ |
||||
} |
||||
|
||||
#define BLOCKB_READ8(_blockb, _B, _coordB) \ |
||||
int2 _coordBTemp = _coordB; \ |
||||
_coordBTemp.y += get_local_id(0); \ |
||||
_blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2; |
||||
|
||||
#define MATB_PARAMETER __read_only image2d_t B |
||||
|
||||
GEMM_TT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0 |
||||
GEMM_TT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0 |
||||
GEMM_TT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0 |
||||
GEMM_TT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0 |
||||
#undef BLOCKB_READ8 |
||||
#undef MATB_PARAMETER |
||||
|
||||
#define BLOCKB_READ8(_blockb, _B, _coordB) \ |
||||
int2 _coordBTemp = _coordB; \ |
||||
_coordBTemp.y += get_local_id(0); \ |
||||
const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \ |
||||
_blockb = vload8(0, B_read); \ |
||||
_coordB.x += TILE_K; |
||||
|
||||
#define MATB_PARAMETER __global Dtype *B, int offB, int ldb |
||||
|
||||
GEMM_TT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0 |
||||
GEMM_TT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0 |
||||
GEMM_TT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0 |
||||
GEMM_TT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0 |
||||
#undef BLOCKB_READ8 |
||||
#undef MATB_PARAMETER |
||||
|
||||
#define BLOCKB_READ8(_blockb, _B, _coordB) \ |
||||
int2 _coordBTemp = _coordB; \ |
||||
_coordBTemp.y += get_local_id(0); \ |
||||
Dtype4 temp; \ |
||||
temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s0 = temp.s0; \ |
||||
temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s1 = temp.s0; \ |
||||
temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s2 = temp.s0; \ |
||||
temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s3 = temp.s0; \ |
||||
temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s4 = temp.s0; \ |
||||
temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s5 = temp.s0; \ |
||||
temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s6 = temp.s0; \ |
||||
temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \ |
||||
_blockb.s7 = temp.s0; \ |
||||
_coordB.x += 8; |
||||
|
||||
#define MATB_PARAMETER __read_only image2d_t B |
||||
|
||||
GEMM_TT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0 |
||||
GEMM_TT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0 |
||||
GEMM_TT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0 |
||||
GEMM_TT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0 |
||||
#undef BLOCKB_READ8 |
||||
#undef MATB_PARAMETER |
||||
|
||||
#undef MULTIPLY_BLOCKS_8x8 |
||||
#undef TRANSPOSE_BLOCK_8 |
||||
#undef GEMM_TT |
||||
|
||||
#undef TILE_M |
||||
#undef TILE_K |
||||
#undef TILE_N |
||||
#undef SUBGROUP_BLOCK_READ8 |
||||
#undef READ_IMAGE |
||||
#undef SIZE_OF_ELEMENT |
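// Helper kernels that pack a buffer matrix into an image2d_t for the
// image-based GEMM paths above. The transpose variant writes the element
// value through write_imagef; the no-transpose variant zero-fills
// out-of-range pixels and stores the raw element bits through write_imageui.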
||||
|
||||
__kernel void TEMPLATE(gemm_buffer_copy_image_transpose,Dtype)( |
||||
__global Dtype* A, |
||||
__write_only image2d_t ImA, |
||||
int offA, |
||||
int width, |
||||
int height, |
||||
int ldA) |
||||
{ |
||||
const int gidx = get_global_id(0); |
||||
const int gidy = get_global_id(1); |
||||
int2 coord_dst = (int2)(gidx, gidy); |
||||
__global Dtype* A_off = A + offA; |
||||
Dtype srcA = A_off[gidy * ldA + gidx]; |
||||
write_imagef(ImA, coord_dst, (Dtype4)srcA); |
||||
} |
||||
|
||||
__kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose,Dtype)( |
||||
__global Dtype* A, |
||||
__write_only image2d_t ImA, |
||||
int offA, |
||||
int width, |
||||
int height, |
||||
int ldA) |
||||
{ |
||||
const int gidx = get_global_id(0); |
||||
const int gidy = get_global_id(1); |
||||
int2 coord_dst = (int2)(gidx, gidy); |
||||
if (gidx >= width || gidy >= height) { |
||||
write_imageui(ImA, coord_dst, (uint4)0); |
||||
return; |
||||
} |
||||
__global Dtype* A_off = A + offA; |
||||
uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * ldA + gidx])); |
||||
write_imageui(ImA, coord_dst, srcA); |
||||
} |
@ -0,0 +1,55 @@ |
||||
/*M/////////////////////////////////////////////////////////////////////////////////////// |
||||
// |
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
||||
// |
||||
// By downloading, copying, installing or using the software you agree to this license. |
||||
// If you do not agree to this license, do not download, install, |
||||
// copy or use the software. |
||||
// |
||||
// |
||||
// License Agreement |
||||
// For Open Source Computer Vision Library |
||||
// |
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved. |
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved. |
||||
// Third party copyrights are property of their respective owners. |
||||
// |
||||
// Redistribution and use in source and binary forms, with or without modification, |
||||
// are permitted provided that the following conditions are met: |
||||
// |
||||
// * Redistribution's of source code must retain the above copyright notice, |
||||
// this list of conditions and the following disclaimer. |
||||
// |
||||
// * Redistribution's in binary form must reproduce the above copyright notice, |
||||
// this list of conditions and the following disclaimer in the documentation |
||||
// and/or other materials provided with the distribution. |
||||
// |
||||
// * The name of the copyright holders may not be used to endorse or promote products |
||||
// derived from this software without specific prior written permission. |
||||
// |
||||
// This software is provided by the copyright holders and contributors "as is" and |
||||
// any express or implied warranties, including, but not limited to, the implied |
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
||||
// In no event shall the Intel Corporation or contributors be liable for any direct, |
||||
// indirect, incidental, special, exemplary, or consequential damages |
||||
// (including, but not limited to, procurement of substitute goods or services; |
||||
// loss of use, data, or profits; or business interruption) however caused |
||||
// and on any theory of liability, whether in contract, strict liability, |
||||
// or tort (including negligence or otherwise) arising in any way out of |
||||
// the use of this software, even if advised of the possibility of such damage. |
||||
// |
||||
//M*/ |
||||
|
||||
#define CONCAT(A,B) A##_##B |
||||
#define TEMPLATE(name,type) CONCAT(name,type) |
||||
#define Dtype float |
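// axpy: y[offy + i] += alpha * x[offx + i] for i in [0, n), using a
// grid-stride loop so any global work size covers the whole range.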
||||
|
||||
__kernel void TEMPLATE(axpy,Dtype)(const int n, const Dtype alpha, __global const Dtype* x, |
||||
const int offx, __global Dtype* y, |
||||
const int offy) { |
||||
for (int index = get_global_id(0); index < n; index += get_global_size(0)) { |
||||
Dtype src = x[offx + index]; |
||||
Dtype dst = y[offy + index]; |
||||
y[offy + index] = alpha * src + dst; |
||||
} |
||||
} |
@ -0,0 +1,191 @@ |
||||
/*M/////////////////////////////////////////////////////////////////////////////////////// |
||||
// |
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
||||
// |
||||
// By downloading, copying, installing or using the software you agree to this license. |
||||
// If you do not agree to this license, do not download, install, |
||||
// copy or use the software. |
||||
// |
||||
// |
||||
// License Agreement |
||||
// For Open Source Computer Vision Library |
||||
// |
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved. |
||||
// Third party copyrights are property of their respective owners. |
||||
// |
||||
// Redistribution and use in source and binary forms, with or without modification, |
||||
// are permitted provided that the following conditions are met: |
||||
// |
||||
// * Redistribution's of source code must retain the above copyright notice, |
||||
// this list of conditions and the following disclaimer. |
||||
// |
||||
// * Redistribution's in binary form must reproduce the above copyright notice, |
||||
// this list of conditions and the following disclaimer in the documentation |
||||
// and/or other materials provided with the distribution. |
||||
// |
||||
// * The name of the copyright holders may not be used to endorse or promote products |
||||
// derived from this software without specific prior written permission. |
||||
// |
||||
// This software is provided by the copyright holders and contributors "as is" and |
||||
// any express or implied warranties, including, but not limited to, the implied |
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
||||
// In no event shall the Intel Corporation or contributors be liable for any direct, |
||||
// indirect, incidental, special, exemplary, or consequential damages |
||||
// (including, but not limited to, procurement of substitute goods or services; |
||||
// loss of use, data, or profits; or business interruption) however caused |
||||
// and on any theory of liability, whether in contract, strict liability, |
||||
// or tort (including negligence or otherwise) arising in any way out of |
||||
// the use of this software, even if advised of the possibility of such damage. |
||||
// |
||||
//M*/ |
||||
|
||||
#define CONCAT(A,B) A##_##B |
||||
#define TEMPLATE(name,type) CONCAT(name,type) |
||||
#define Dtype float |
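// matvec_mul4: each work-group computes four consecutive rows of
// result = alpha * A * v (+ beta * result). Work-items accumulate partial
// float4 dot products with vload4, the trail_item leftover columns are added
// by the work-item that stops exactly at the last full float4, and a
// local-memory tree reduction combines the partial sums.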
||||
|
||||
__kernel void TEMPLATE(matvec_mul4,Dtype)( |
||||
__global const float * A, |
||||
int offA, |
||||
unsigned int A_col_size, |
||||
unsigned int trail_item, |
||||
__global const float * v, |
||||
int offv, |
||||
float alpha, |
||||
float beta, |
||||
__global float4 * result, |
||||
int offr, |
||||
__local float4 * work) |
||||
{ |
||||
unsigned int row_gid = get_group_id(0); |
||||
unsigned int lid = get_local_id(0); |
||||
const __global float *src0_read = A + row_gid * 4 * A_col_size + offA; |
||||
const __global float *src1_read = v + offv; |
||||
result = (__global float4*)((__global float*)result + offr); |
||||
float4 dot0 = (float4)(0.f); |
||||
float4 dot1 = (float4)(0.f); |
||||
float4 dot2 = (float4)(0.f); |
||||
float4 dot3 = (float4)(0.f); |
||||
|
||||
unsigned int i = lid; |
||||
while( i < A_col_size / 4) { |
||||
const float4 a0 = vload4(i, src0_read); |
||||
const float4 a1 = vload4(i, src0_read + A_col_size); |
||||
const float4 a2 = vload4(i, src0_read + 2 * A_col_size); |
||||
const float4 a3 = vload4(i, src0_read + 3 * A_col_size); |
||||
|
||||
const float4 b0 = vload4(i, src1_read); |
||||
|
||||
dot0 += a0 * b0; |
||||
dot1 += a1 * b0; |
||||
dot2 += a2 * b0; |
||||
dot3 += a3 * b0; |
||||
|
||||
i += get_local_size(0); |
||||
} |
||||
|
||||
work[lid].s0 = dot0.x + dot0.y + dot0.z + dot0.w; |
||||
work[lid].s1 = dot1.x + dot1.y + dot1.z + dot1.w; |
||||
work[lid].s2 = dot2.x + dot2.y + dot2.z + dot2.w; |
||||
work[lid].s3 = dot3.x + dot3.y + dot3.z + dot3.w; |
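// Exactly one work-item ends the loop with i == A_col_size / 4; it
// accumulates the trail_item leftover columns into its partial sums.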
||||
|
||||
if(i == A_col_size / 4) |
||||
{ |
||||
if(trail_item != 0) |
||||
{ |
||||
const __global float *src0_trail = src0_read + i * 4; |
||||
const __global float *src1_trail = src1_read + i * 4; |
||||
for(unsigned int i = 0; i < trail_item; ++i) { |
||||
const float at0 = src0_trail[i]; |
||||
const float at1 = src0_trail[i + A_col_size]; |
||||
const float at2 = src0_trail[i + 2 * A_col_size]; |
||||
const float at3 = src0_trail[i + 3 * A_col_size]; |
||||
|
||||
const float bt = src1_trail[i]; |
||||
|
||||
work[lid].s0 += at0 * bt; |
||||
work[lid].s1 += at1 * bt; |
||||
work[lid].s2 += at2 * bt; |
||||
work[lid].s3 += at3 * bt; |
||||
} |
||||
} |
||||
|
||||
} |
||||
|
||||
for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) { |
||||
barrier(CLK_LOCAL_MEM_FENCE); |
||||
if(lid < stride) |
||||
work[lid] += work[lid+stride]; |
||||
} |
||||
if(lid == 0) { |
||||
if(beta == (Dtype)0) |
||||
result[row_gid] = alpha * work[0]; |
||||
else |
||||
result[row_gid] = alpha * work[0] + beta * result[row_gid]; |
||||
} |
||||
} |
||||
|
||||
/* This kernel handles the trailing rows when the number of rows of A is not a multiple of 4 (row_of_A % 4 != 0). */ |
||||
__kernel void TEMPLATE(matvec_mul1,Dtype)( |
||||
__global const float * A, |
||||
int offA, |
||||
unsigned int A_col_size, |
||||
unsigned int row_offset, |
||||
unsigned int trail_item, |
||||
__global const float * v, |
||||
int offv, |
||||
float alpha, |
||||
float beta, |
||||
__global float * result, |
||||
int offr, |
||||
__local float * work) |
||||
{ |
||||
unsigned int row_gid = get_group_id(0); |
||||
unsigned int lid = get_local_id(0); |
||||
|
||||
const __global float *src0_read = A + (row_offset + row_gid) * A_col_size + offA; |
||||
const __global float *src1_read = v + offv; |
||||
result = result + offr; |
||||
float4 dot0 = (float4)(0.f); |
||||
|
||||
unsigned int i = lid; |
||||
while( i < A_col_size / 4) |
||||
{ |
||||
const float4 a0 = vload4(i, src0_read); |
||||
const float4 b0 = vload4(i, src1_read); |
||||
|
||||
dot0 += a0 * b0; |
||||
i += get_local_size(0); |
||||
} |
||||
|
||||
work[lid] = dot0.x + dot0.y + dot0.z + dot0.w; |
||||
|
||||
if(i == A_col_size / 4) |
||||
{ |
||||
if(trail_item != 0) |
||||
{ |
||||
const __global float *src0_trail = src0_read + i * 4; |
||||
const __global float *src1_trail = src1_read + i * 4; |
||||
for(unsigned int i = 0; i < trail_item; ++i) { |
||||
const float at0 = src0_trail[i]; |
||||
const float bt = src1_trail[i]; |
||||
|
||||
work[lid] += at0 * bt; |
||||
} |
||||
} |
||||
|
||||
} |
||||
for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) { |
||||
barrier(CLK_LOCAL_MEM_FENCE); |
||||
if(lid < stride) |
||||
work[lid] += work[lid+stride]; |
||||
} |
||||
|
||||
if(lid == 0) { |
||||
if(beta == (Dtype)0) { |
||||
result[row_gid+row_offset] = alpha * work[0]; |
||||
} else { |
||||
result[row_gid+row_offset] *= beta; |
||||
result[row_gid+row_offset] += alpha * work[0]; |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,96 @@ |
||||
/*M/////////////////////////////////////////////////////////////////////////////////////// |
||||
// |
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
||||
// |
||||
// By downloading, copying, installing or using the software you agree to this license. |
||||
// If you do not agree to this license, do not download, install, |
||||
// copy or use the software. |
||||
// |
||||
// |
||||
// License Agreement |
||||
// For Open Source Computer Vision Library |
||||
// |
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved. |
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved. |
||||
// Third party copyrights are property of their respective owners. |
||||
// |
||||
// Redistribution and use in source and binary forms, with or without modification, |
||||
// are permitted provided that the following conditions are met: |
||||
// |
||||
// * Redistribution's of source code must retain the above copyright notice, |
||||
// this list of conditions and the following disclaimer. |
||||
// |
||||
// * Redistribution's in binary form must reproduce the above copyright notice, |
||||
// this list of conditions and the following disclaimer in the documentation |
||||
// and/or other materials provided with the distribution. |
||||
// |
||||
// * The name of the copyright holders may not be used to endorse or promote products |
||||
// derived from this software without specific prior written permission. |
||||
// |
||||
// This software is provided by the copyright holders and contributors "as is" and |
||||
// any express or implied warranties, including, but not limited to, the implied |
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
||||
// In no event shall the Intel Corporation or contributors be liable for any direct, |
||||
// indirect, incidental, special, exemplary, or consequential damages |
||||
// (including, but not limited to, procurement of substitute goods or services; |
||||
// loss of use, data, or profits; or business interruption) however caused |
||||
// and on any theory of liability, whether in contract, strict liability, |
||||
// or tort (including negligence or otherwise) arising in any way out of |
||||
// the use of this software, even if advised of the possibility of such damage. |
||||
// |
||||
//M*/ |
||||
|
||||
#define CONCAT(A,B) A##_##B |
||||
#define TEMPLATE(name,type) CONCAT(name,type) |
||||
#define Dtype float |
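// LRN across channels: for each (n, h, w) location a sliding window of
// `size` channels is maintained incrementally (add the entering channel's
// square, subtract the leaving one's), and the output is
// in * (k + alpha_over_size * accum)^negative_beta via native_powr.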
||||
|
||||
__kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int nthreads, __global const Dtype* in, |
||||
const int num, const int channels, |
||||
const int height, const int width, const int size, |
||||
const Dtype alpha_over_size, const Dtype k, |
||||
__global Dtype* const out, |
||||
const Dtype negative_beta) { |
||||
for (int index = get_global_id(0); index < nthreads; |
||||
index += get_global_size(0)) { |
||||
// find out the local offset |
||||
const int w = index % width; |
||||
const int h = (index / width) % height; |
||||
const int n = index / width / height; |
||||
const int offset = (n * channels * height + h) * width + w; |
||||
const int step = height * width; |
||||
__global const Dtype* in_off = in + offset; |
||||
__global Dtype* out_off = out + offset; |
||||
Dtype scale_val; |
||||
int head = 0; |
||||
const int pre_pad = (size - 1) / 2; |
||||
const int post_pad = size - pre_pad - 1; |
||||
Dtype accum_scale = 0; |
||||
// fill the scale at [n, :, h, w] |
||||
// accumulate values |
||||
while (head < post_pad && head < channels) { |
||||
accum_scale += in_off[head * step] * in_off[head * step]; |
||||
++head; |
||||
} |
||||
// both add and subtract |
||||
while (head < channels) { |
||||
accum_scale += in_off[head * step] * in_off[head * step]; |
||||
if (head - size >= 0) { |
||||
accum_scale -= in_off[(head - size) * step] |
||||
* in_off[(head - size) * step]; |
||||
} |
||||
scale_val = k + accum_scale * alpha_over_size; |
||||
out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta); |
||||
++head; |
||||
} |
||||
// subtract only |
||||
while (head < channels + post_pad) { |
||||
if (head - size >= 0) { |
||||
accum_scale -= in_off[(head - size) * step] |
||||
* in_off[(head - size) * step]; |
||||
} |
||||
scale_val = k + accum_scale * alpha_over_size; |
||||
out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta); |
||||
++head; |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,177 @@ |
||||
/*M/////////////////////////////////////////////////////////////////////////////////////// |
||||
// |
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
||||
// |
||||
// By downloading, copying, installing or using the software you agree to this license. |
||||
// If you do not agree to this license, do not download, install, |
||||
// copy or use the software. |
||||
// |
||||
// |
||||
// License Agreement |
||||
// For Open Source Computer Vision Library |
||||
// |
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved. |
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved. |
||||
// Third party copyrights are property of their respective owners. |
||||
// |
||||
// Redistribution and use in source and binary forms, with or without modification, |
||||
// are permitted provided that the following conditions are met: |
||||
// |
||||
// * Redistribution's of source code must retain the above copyright notice, |
||||
// this list of conditions and the following disclaimer. |
||||
// |
||||
// * Redistribution's in binary form must reproduce the above copyright notice, |
||||
// this list of conditions and the following disclaimer in the documentation |
||||
// and/or other materials provided with the distribution. |
||||
// |
||||
// * The name of the copyright holders may not be used to endorse or promote products |
||||
// derived from this software without specific prior written permission. |
||||
// |
||||
// This software is provided by the copyright holders and contributors "as is" and |
||||
// any express or implied warranties, including, but not limited to, the implied |
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
||||
// In no event shall the Intel Corporation or contributors be liable for any direct, |
||||
// indirect, incidental, special, exemplary, or consequential damages |
||||
// (including, but not limited to, procurement of substitute goods or services; |
||||
// loss of use, data, or profits; or business interruption) however caused |
||||
// and on any theory of liability, whether in contract, strict liability, |
||||
// or tort (including negligence or otherwise) arising in any way out of |
||||
// the use of this software, even if advised of the possibility of such damage. |
||||
// |
||||
//M*/ |
||||
|
||||
#define CONCAT(A,B) A##_##B |
||||
#define TEMPLATE(name,type) CONCAT(name,type) |
||||
#define Dtype float |
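// Pooling kernels: max pooling (optionally recording the argmax in either an
// int mask or a Dtype top_mask), average pooling where the pool size counts
// padded positions before the window is clamped to the input, and stochastic
// pooling in test mode (activation-weighted average of the window).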
||||
|
||||
void TEMPLATE(max_pool_forward_impl, Dtype)( |
||||
const int nthreads, __global const Dtype* bottom_data, const int num, |
||||
const int channels, const int height, const int width, |
||||
const int pooled_height, const int pooled_width, const int kernel_h, |
||||
const int kernel_w, const int stride_h, const int stride_w, const int pad_h, |
||||
const int pad_w, |
||||
__global Dtype* top_data, |
||||
const int use_mask, __global int* mask, __global Dtype* top_mask, bool no_mask) |
||||
{ |
||||
for (int index = get_global_id(0); index < nthreads; |
||||
index += get_global_size(0)) |
||||
{ |
||||
const int pw = index % pooled_width; |
||||
const int ph = (index / pooled_width) % pooled_height; |
||||
const int c = (index / pooled_width / pooled_height) % channels; |
||||
const int n = index / pooled_width / pooled_height / channels; |
||||
int hstart = ph * stride_h - pad_h; |
||||
int wstart = pw * stride_w - pad_w; |
||||
const int hend = min(hstart + kernel_h, height); |
||||
const int wend = min(wstart + kernel_w, width); |
||||
hstart = max(hstart, (int)0); |
||||
wstart = max(wstart, (int)0); |
||||
Dtype maxval = -FLT_MAX; |
||||
int maxidx = -1; |
||||
__global const Dtype* bottom_slice = bottom_data |
||||
+ (n * channels + c) * height * width; |
||||
for (int h = hstart; h < hend; ++h) { |
||||
for (int w = wstart; w < wend; ++w) { |
||||
if (bottom_slice[h * width + w] > maxval) { |
||||
maxidx = h * width + w; |
||||
maxval = bottom_slice[maxidx]; |
||||
} |
||||
} |
||||
} |
||||
top_data[index] = maxval; |
||||
if (!no_mask) { |
||||
if (use_mask == 1) { |
||||
mask[index] = maxidx; |
||||
} else { |
||||
top_mask[index] = maxidx; |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
__kernel void TEMPLATE(max_pool_forward, Dtype)( |
||||
const int nthreads, __global const Dtype* bottom_data, const int num, |
||||
const int channels, const int height, const int width, |
||||
const int pooled_height, const int pooled_width, const int kernel_h, |
||||
const int kernel_w, const int stride_h, const int stride_w, const int pad_h, |
||||
const int pad_w, |
||||
__global Dtype* top_data, |
||||
const int use_mask, __global int* mask, __global Dtype* top_mask) |
||||
{ |
||||
TEMPLATE(max_pool_forward_impl, Dtype)( |
||||
nthreads, bottom_data, num, channels, height, width, |
||||
pooled_height, pooled_width, kernel_h, |
||||
kernel_w, stride_h, stride_w, pad_h, pad_w, top_data, use_mask, mask, top_mask, false |
||||
); |
||||
} |
||||
|
||||
__kernel void TEMPLATE(ave_pool_forward, Dtype)( |
||||
const int nthreads, __global const Dtype* const bottom_data, const int num, |
||||
const int channels, const int height, const int width, |
||||
const int pooled_height, const int pooled_width, const int kernel_h, |
||||
const int kernel_w, const int stride_h, const int stride_w, const int pad_h, |
||||
const int pad_w, __global Dtype* top_data) |
||||
{ |
||||
for (int index = get_global_id(0); index < nthreads; |
||||
index += get_global_size(0)) |
||||
{ |
||||
{ |
||||
const int pw = index % pooled_width; |
||||
const int ph = (index / pooled_width) % pooled_height; |
||||
const int c = (index / pooled_width / pooled_height) % channels; |
||||
const int n = index / pooled_width / pooled_height / channels; |
||||
int hstart = ph * stride_h - pad_h; |
||||
int wstart = pw * stride_w - pad_w; |
||||
int hend = min(hstart + kernel_h, height + pad_h); |
||||
int wend = min(wstart + kernel_w, width + pad_w); |
||||
const int pool_size = (hend - hstart) * (wend - wstart); |
||||
hstart = max(hstart, (int)0); |
||||
wstart = max(wstart, (int)0); |
||||
hend = min(hend, height); |
||||
wend = min(wend, width); |
||||
Dtype aveval = 0; |
||||
__global const Dtype* bottom_slice = bottom_data |
||||
+ (n * channels + c) * height * width; |
||||
for (int h = hstart; h < hend; ++h) { |
||||
for (int w = wstart; w < wend; ++w) { |
||||
aveval += bottom_slice[h * width + w]; |
||||
} |
||||
} |
||||
top_data[index] = aveval / pool_size; |
||||
} |
||||
} |
||||
} |
||||
|
||||
__kernel void TEMPLATE(sto_pool_forward_test,Dtype)( |
||||
const int nthreads, __global const Dtype* const bottom_data, const int num, |
||||
const int channels, const int height, const int width, |
||||
const int pooled_height, const int pooled_width, const int kernel_h, |
||||
const int kernel_w, const int stride_h, const int stride_w, |
||||
__global Dtype* top_data) |
||||
{ |
||||
for (int index = get_global_id(0); index < nthreads; |
||||
index += get_global_size(0)) |
||||
{ |
||||
const int pw = index % pooled_width; |
||||
const int ph = (index / pooled_width) % pooled_height; |
||||
const int c = (index / pooled_width / pooled_height) % channels; |
||||
const int n = index / pooled_width / pooled_height / channels; |
||||
const int hstart = ph * stride_h; |
||||
const int hend = min(hstart + kernel_h, height); |
||||
const int wstart = pw * stride_w; |
||||
const int wend = min(wstart + kernel_w, width); |
||||
// Seed cumsum with FLT_MIN (not 0) to avoid divide-by-zero problems |
||||
Dtype cumsum = FLT_MIN; |
||||
Dtype cumvalues = 0.; |
||||
__global const Dtype* bottom_slice = bottom_data |
||||
+ (n * channels + c) * height * width; |
||||
// First pass: get sum |
||||
for (int h = hstart; h < hend; ++h) { |
||||
for (int w = wstart; w < wend; ++w) { |
||||
cumsum += bottom_slice[h * width + w]; |
||||
cumvalues += bottom_slice[h * width + w] * bottom_slice[h * width + w]; |
||||
} |
||||
} |
||||
top_data[index] = cumvalues / cumsum; |
||||
} |
||||
} |
@ -0,0 +1,182 @@ |
||||
/*M/////////////////////////////////////////////////////////////////////////////////////// |
||||
// |
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
||||
// |
||||
// By downloading, copying, installing or using the software you agree to this license. |
||||
// If you do not agree to this license, do not download, install, |
||||
// copy or use the software. |
||||
// |
||||
// |
||||
// License Agreement |
||||
// For Open Source Computer Vision Library |
||||
// |
||||
// Copyright (C) 2017, Intel Corporation, all rights reserved. |
||||
// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved. |
||||
// Third party copyrights are property of their respective owners. |
||||
// |
||||
// Redistribution and use in source and binary forms, with or without modification, |
||||
// are permitted provided that the following conditions are met: |
||||
// |
||||
// * Redistribution's of source code must retain the above copyright notice, |
||||
// this list of conditions and the following disclaimer. |
||||
// |
||||
// * Redistribution's in binary form must reproduce the above copyright notice, |
||||
// this list of conditions and the following disclaimer in the documentation |
||||
// and/or other materials provided with the distribution. |
||||
// |
||||
// * The name of the copyright holders may not be used to endorse or promote products |
||||
// derived from this software without specific prior written permission. |
||||
// |
||||
// This software is provided by the copyright holders and contributors "as is" and |
||||
// any express or implied warranties, including, but not limited to, the implied |
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
||||
// In no event shall the Intel Corporation or contributors be liable for any direct, |
||||
// indirect, incidental, special, exemplary, or consequential damages |
||||
// (including, but not limited to, procurement of substitute goods or services; |
||||
// loss of use, data, or profits; or business interruption) however caused |
||||
// and on any theory of liability, whether in contract, strict liability, |
||||
// or tort (including negligence or otherwise) arising in any way out of |
||||
// the use of this software, even if advised of the possibility of such damage. |
||||
// |
||||
//M*/ |
||||
|
||||
#define CONCAT(A,B) A##_##B |
||||
#define TEMPLATE(name,type) CONCAT(name,type) |
||||
#define Dtype float |
||||
|
||||
#if defined(cl_intel_subgroups) |
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable |
||||
#endif |
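// Channel-wise softmax using cl_intel_subgroups reductions. Per spatial
// location: find the channel max, subtract it and exponentiate, sum the
// exponentials, then normalize. The _slm variant keeps the intermediate
// buffers in local memory; the plain variant round-trips them through the
// `scale` buffer with global memory fences. Values are scaled by 100000
// before sub_group_reduce_* and scaled back afterwards.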
||||
|
||||
__kernel void TEMPLATE(softmax_forward_slm,Dtype)(const int num, const int channels, |
||||
const int spatial_dim, |
||||
__global Dtype* scale, |
||||
__global const Dtype* data, |
||||
__global Dtype* out, |
||||
__local Dtype *out_tmp, |
||||
__local Dtype *scale_tmp, |
||||
__local Dtype *group_tmp) { |
||||
|
||||
int n = get_global_id(1); |
||||
for (int index = get_global_id(0), s = 0; index < spatial_dim * get_local_size(0); index += |
||||
get_global_size(0), ++s) { |
||||
float maxval = -FLT_MAX; |
||||
for (int c = get_global_id(0); c < channels; c += get_global_size(0)) { |
||||
Dtype tmp = data[(n * channels + c) * spatial_dim + s]; |
||||
maxval = max((Dtype)tmp, (Dtype)maxval); |
||||
} |
||||
maxval = sub_group_reduce_max(maxval * 100000); |
||||
//if (get_sub_group_local_id() == 0) |
||||
group_tmp[get_sub_group_id() * spatial_dim + s] = maxval; |
||||
} |
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE); |
||||
|
||||
for (int index = get_global_id(0); index < spatial_dim * get_max_sub_group_size(); index += |
||||
get_global_size(0)) { |
||||
int s = index / get_max_sub_group_size(); |
||||
Dtype maxval = sub_group_reduce_max(group_tmp[get_sub_group_local_id() * spatial_dim + s]); |
||||
//if (get_sub_group_local_id() == 0) |
||||
scale_tmp[s] = maxval / 100000; |
||||
} |
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE); |
||||
|
||||
for (int index = get_global_id(0); index < channels * spatial_dim; |
||||
index += get_global_size(0)) { |
||||
int s = index % spatial_dim; |
||||
out_tmp[index] = exp(data[n * channels * spatial_dim + index] - scale_tmp[s]); |
||||
} |
||||
barrier(CLK_LOCAL_MEM_FENCE); |
||||
|
||||
for (int index = get_global_id(0), s = 0; index < spatial_dim * get_local_size(0); index += |
||||
get_global_size(0), ++s) { |
||||
Dtype sum = 0; |
||||
for (int c = get_global_id(0); c < channels; c += get_global_size(0)) { |
||||
sum += out_tmp[c * spatial_dim + s]; |
||||
} |
||||
sum = sub_group_reduce_add(sum * 100000); |
||||
group_tmp[get_sub_group_id() * spatial_dim + s] = sum; |
||||
} |
||||
barrier(CLK_LOCAL_MEM_FENCE); |
||||
|
||||
for (int index = get_global_id(0); index < spatial_dim * get_max_sub_group_size(); index += |
||||
get_global_size(0)) { |
||||
int s = index / get_max_sub_group_size(); |
||||
Dtype sum = sub_group_reduce_add(group_tmp[get_sub_group_local_id() * spatial_dim + s]); |
||||
//if (get_sub_group_local_id() == 0) |
||||
scale_tmp[s] = sum / 100000; |
||||
} |
||||
barrier(CLK_LOCAL_MEM_FENCE); |
||||
|
||||
for (int index = get_global_id(0); index < channels * spatial_dim; |
||||
index += get_global_size(0)) { |
||||
int s = index % spatial_dim; |
||||
out[n * channels * spatial_dim + index] = out_tmp[index] / scale_tmp[s]; |
||||
} |
||||
} |
||||
|
||||
__kernel void TEMPLATE(softmax_forward,Dtype)(const int num, const int channels, |
||||
const int spatial_dim, |
||||
__global Dtype* scale, |
||||
__global const Dtype* data, |
||||
__global Dtype* out) { |
||||
|
||||
int n = get_global_id(1); |
||||
__global Dtype *group_tmp = scale + spatial_dim * num + n * get_max_sub_group_size() * spatial_dim; |
||||
for (int index = get_global_id(0), s = 0; index < spatial_dim * get_local_size(0); index += |
||||
get_global_size(0), ++s) { |
||||
float maxval = -FLT_MAX; |
||||
for (int c = get_global_id(0); c < channels; c += get_global_size(0)) { |
||||
Dtype tmp = data[(n * channels + c) * spatial_dim + s]; |
||||
maxval = max((Dtype)tmp, (Dtype)maxval); |
||||
} |
||||
maxval = sub_group_reduce_max(maxval * 100000); |
||||
//if (get_sub_group_local_id() == 0) |
||||
group_tmp[get_sub_group_id() * spatial_dim + s] = maxval; |
||||
} |
||||
barrier(CLK_GLOBAL_MEM_FENCE); |
||||
|
||||
for (int index = get_global_id(0); index < spatial_dim * get_max_sub_group_size(); index += |
||||
get_global_size(0)) { |
||||
int s = index / get_max_sub_group_size(); |
||||
Dtype maxval = sub_group_reduce_max(group_tmp[get_sub_group_local_id() * spatial_dim + s]); |
||||
//if (get_sub_group_local_id() == 0) |
||||
scale[n * spatial_dim + s] = maxval / 100000; |
||||
} |
||||
|
||||
barrier(CLK_GLOBAL_MEM_FENCE); |
||||
|
||||
for (int index = get_global_id(0); index < channels * spatial_dim; |
||||
index += get_global_size(0)) { |
||||
int s = index % spatial_dim; |
||||
out[n * channels * spatial_dim + index] = exp(data[n * channels * spatial_dim + index] - scale[n * spatial_dim + s]); |
||||
} |
||||
barrier(CLK_GLOBAL_MEM_FENCE); |
||||
|
||||
for (int index = get_global_id(0), s = 0; index < spatial_dim * get_local_size(0); index += |
||||
get_global_size(0), ++s) { |
||||
Dtype sum = 0; |
||||
for (int c = get_global_id(0); c < channels; c += get_global_size(0)) { |
||||
sum += out[n * channels * spatial_dim + c * spatial_dim + s]; |
||||
} |
||||
sum = sub_group_reduce_add(sum * 100000); |
||||
group_tmp[get_sub_group_id() * spatial_dim + s] = sum; |
||||
} |
||||
barrier(CLK_GLOBAL_MEM_FENCE); |
||||
|
||||
for (int index = get_global_id(0); index < spatial_dim * get_max_sub_group_size(); index += |
||||
get_global_size(0)) { |
||||
int s = index / get_max_sub_group_size(); |
||||
Dtype sum = sub_group_reduce_add(group_tmp[get_sub_group_local_id() * spatial_dim + s]); |
||||
//if (get_sub_group_local_id() == 0) |
||||
scale[n * spatial_dim + s] = sum / 100000; |
||||
} |
||||
barrier(CLK_GLOBAL_MEM_FENCE); |
||||
|
||||
for (int index = get_global_id(0); index < channels * spatial_dim; |
||||
index += get_global_size(0)) { |
||||
int s = index % spatial_dim; |
||||
out[n * channels * spatial_dim + index] /= scale[n * spatial_dim + s]; |
||||
} |
||||
} |