Merge pull request #3614 from hipudding:ascendc

Enable AscendC kernel operator
pull/3608/head^2
Alexander Smorkalov 1 year ago committed by GitHub
commit bbce2ef9d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      modules/cannops/CMakeLists.txt
  2. 17
      modules/cannops/ascendc_kernels/CMakeLists.txt
  3. 22
      modules/cannops/ascendc_kernels/kernel_tiling_types.h
  4. 379
      modules/cannops/ascendc_kernels/threshold_opencv_kernel.cpp
  5. 77
      modules/cannops/ascendc_kernels/vector_tiling.h
  6. 7
      modules/cannops/include/opencv2/ascendc_kernels.hpp
  7. 28
      modules/cannops/include/opencv2/cann_call.hpp
  8. 20
      modules/cannops/perf/perf_element_operations.cpp
  9. 6
      modules/cannops/src/ascend_mat.cpp
  10. 12
      modules/cannops/src/cann_call.cpp
  11. 128
      modules/cannops/src/element_operations.cpp
  12. 2
      modules/cannops/src/precomp.hpp
  13. 33
      modules/cannops/test/test_element_operations.cpp
  14. 51
      modules/cannops/test/test_kernel.cpp
  15. 1
      modules/cannops/test/test_precomp.hpp

@ -15,3 +15,9 @@ ocv_include_directories(${CMAKE_SOURCE_DIR}/modules/ts/include)
ocv_add_accuracy_tests(DEPENDS_ON opencv_cannops)
ocv_add_perf_tests(DEPENDS_ON opencv_cannops)
ocv_add_samples(opencv_cannops)
# Compile AscendC kernels.
add_subdirectory(ascendc_kernels)
ocv_include_directories(${CMAKE_BINARY_DIR}/include/ascendc_kernels)
ocv_target_link_libraries(opencv_cannops PRIVATE ascendc_kernels)
ocv_target_link_libraries(opencv_test_cannops PRIVATE ascendc_kernels)

@ -0,0 +1,17 @@
# Build configuration for the AscendC kernel operators.
# SOC_VERSION must match the target Ascend chip; see the CANN toolkit docs.
set(SOC_VERSION "ascend310p3" CACHE STRING "system on chip type")
set(ASCEND_CANN_PACKAGE_PATH "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "ASCEND CANN package installation directory")
set(RUN_MODE "npu" CACHE STRING "run mode: npu/sim/cpu")

# Locate the AscendC kernel CMake helper shipped with the CANN toolkit.
# Its location differs between toolkit layouts, so probe both known paths.
# Paths are quoted so an installation directory containing spaces still works.
if(EXISTS "${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake")
    set(ASCENDC_CMAKE_DIR "${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake")
elseif(EXISTS "${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake")
    set(ASCENDC_CMAKE_DIR "${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake")
else()
    message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the compiler package is installed.")
endif()
include("${ASCENDC_CMAKE_DIR}/ascendc.cmake")

# Static library holding all AscendC kernels; linked into opencv_cannops
# and its tests by the parent CMakeLists.
ascendc_library(ascendc_kernels STATIC
    threshold_opencv_kernel.cpp
)

@ -0,0 +1,22 @@
#ifndef KERNEL_TILING_H
#define KERNEL_TILING_H

// Fixed-width integer types; the header is included by both host code
// (via ascendc_kernels.hpp) and device kernel code, so it must be
// self-contained rather than relying on a prior include of <stdint.h>.
#include <stdint.h>

/*
 * Tiling data shared between the host launcher and the AscendC threshold
 * kernel. The host fills one instance and uploads it to device memory;
 * the kernel copies it back out byte-by-byte, so the layout must be
 * identical on both sides (hence the explicit packing below).
 *
 * threshType (mirrors cv::ThresholdTypes):
 * * THRESH_BINARY = 0,
 * * THRESH_BINARY_INV = 1,
 * * THRESH_TRUNC = 2,
 * * THRESH_TOZERO = 3,
 * * THRESH_TOZERO_INV = 4,
 */
#pragma pack(push, 8)
struct ThresholdOpencvTilingData
{
    float maxVal;         // value written for BINARY / BINARY_INV matches
    float thresh;         // threshold to compare against
    uint32_t totalLength; // total element count (rows * cols * channels)
    uint8_t threshType;   // one of the THRESH_* values listed above
    uint8_t dtype;        // OpenCV depth code (CV_8U = 0, ... CV_16F = 7)
};
#pragma pack(pop)

#endif // KERNEL_TILING_H

@ -0,0 +1,379 @@
#include "kernel_operator.h"
#include "vector_tiling.h"
#include "kernel_tiling_types.h"

using namespace AscendC;

// Make the compiler happy. These two functions will never be called:
// when T == C the cast branches in KernelThreshold are compiled but never
// taken, and these same-type no-op overloads satisfy overload resolution.
__aicore__ static inline void Cast(const LocalTensor<half>& dstLocal,
                                   const LocalTensor<half>& srcLocal, const RoundMode& round_mode,
                                   const uint32_t calCount){};
__aicore__ static inline void Cast(const LocalTensor<float>& dstLocal,
                                   const LocalTensor<float>& srcLocal, const RoundMode& round_mode,
                                   const uint32_t calCount){};
/**
 * AscendC implementation of cv::threshold over a flat element stream.
 *
 * T: input data type.
 * C: data type for calculate.
 * if T != C, data should cast from T to C.
 *
 * Data flow per loop: CopyIn (GM -> inputQueue) -> optional cast T->C ->
 * Threshold compute -> optional cast C->T -> CopyOut (outputQueue -> GM).
 */
template <typename T, typename C>
class KernelThreshold
{
public:
    __aicore__ inline KernelThreshold() {}

    /**
     * Bind tiling data, split the workload across blocks, and initialize
     * the local-memory queues/buffers sized for one processing loop.
     */
    __aicore__ inline void Init(ThresholdOpencvTilingData* tiling, GM_ADDR x, GM_ADDR y)
    {
        tilingData = tiling;

        /**
         * Calculate memory use per element.
         * 1. InputQueue: sizeof(T) * BUFFER_NUM
         * 2. OutputQueue: sizeof(T) * BUFFER_NUM
         * 3. maskBuffer: 1 byte at most.
         */
        uint64_t bytesPerElem = sizeof(T) * BUFFER_NUM * 2 + sizeof(uint8_t) * 1;

        /**
         * If need cast, should init two more cast buffers.
         * Memory use per element:
         * 1. InputCastBuffer: sizeof(C)
         * 2. OutputCastBuffer: sizeof(C)
         */
        if (!std::is_same<T, C>::value)
        {
            bytesPerElem += sizeof(C) * 2;
        }

        // Most of AscendC APIs need align to 32 Bytes, but Compare and Select need
        // align to 256 Bytes, 256/sizeof(C) means how many element can be process
        // in one loop.
        vecTiling.calculate(tilingData->totalLength, GetBlockNum(), GetBlockIdx(), bytesPerElem,
                            256 / sizeof(C));

        // Each block sees only its own slice of the input/output arrays.
        xGM.SetGlobalBuffer((__gm__ T*)x + vecTiling.blockOffset, vecTiling.blockLength);
        yGM.SetGlobalBuffer((__gm__ T*)y + vecTiling.blockOffset, vecTiling.blockLength);

        // Cast buffer.
        if (!std::is_same<T, C>::value)
        {
            pipe.InitBuffer(InputCastBuffer, vecTiling.loopLength * sizeof(C));
            pipe.InitBuffer(outputCastBuffer, vecTiling.loopLength * sizeof(C));
        }
        pipe.InitBuffer(inputQueue, BUFFER_NUM, vecTiling.loopLength * sizeof(T));
        pipe.InitBuffer(outputQueue, BUFFER_NUM, vecTiling.loopLength * sizeof(T));
        pipe.InitBuffer(maskBuffer, vecTiling.loopLength * sizeof(uint8_t));
    }

    // Process this block's slice one full loop at a time, then the tail.
    __aicore__ inline void Run()
    {
        for (uint32_t loop = 0; loop < vecTiling.loopCount; loop++)
        {
            uint32_t offset = loop * vecTiling.loopLength;
            Compute(offset, vecTiling.loopLength);
        }
        if (vecTiling.loopTailLength != 0)
        {
            uint32_t offset = vecTiling.loopCount * vecTiling.loopLength;
            Compute(offset, vecTiling.loopTailLength);
        }
    }

private:
    // One full pipeline pass over `len` elements starting at `offset`.
    __aicore__ inline void Compute(uint32_t offset, uint32_t len)
    {
        CopyIn(offset, len);
        // Get local Tensor, if cast is needed, local tensors come from
        // cast buffer. otherwise, local tensors come from input/output queue.
        LocalTensor<C> xLocal = CastInput(inputQueue, InputCastBuffer, len);
        LocalTensor<C> yLocal = GetOutput(outputQueue, outputCastBuffer);
        Threshold(xLocal, yLocal, len);
        // Free local input tensor if tensor is not from cast buffer.
        FreeInput(inputQueue, xLocal);
        // Cast output tensor to output queue if output tensor is from cast buffer.
        CastOutput(outputQueue, yLocal, len);
        CopyOut(offset, len);
    }

    /**
     * If need cast:
     * 1. Get data from input queue, this data can't be calculated directly.
     * 2. Get buffer with type C, which satisfies AscendC APIs.
     * 3. Cast data from T to C.
     *
     * If not need cast:
     * 1. Only need get data from queue.
     */
    __aicore__ inline LocalTensor<C> CastInput(TQue<QuePosition::VECIN, BUFFER_NUM>& queue,
                                               TBuf<TPosition::VECCALC>& buffer, uint32_t len)
    {
        LocalTensor<C> xLocal;
        if (std::is_same<T, C>::value)
        {
            xLocal = queue.DeQue<C>();
        }
        else
        {
            xLocal = buffer.Get<C>();
            LocalTensor<T> xCast = queue.DeQue<T>();
            Cast(xLocal, xCast, RoundMode::CAST_NONE, len);
            queue.FreeTensor(xCast);
        }
        return xLocal;
    }

    /**
     * If need cast:
     * 1. Get local tensor from cast buffer.
     *
     * If not need cast:
     * 1. Alloc local tensor from output queue.
     */
    __aicore__ inline LocalTensor<C> GetOutput(TQue<QuePosition::VECOUT, BUFFER_NUM>& queue,
                                               TBuf<TPosition::VECCALC>& buffer)
    {
        if (std::is_same<T, C>::value)
        {
            return queue.AllocTensor<C>();
        }
        else
        {
            return buffer.Get<C>();
        }
    }

    /**
     * If need cast:
     * 1. Input local tensors are from the cast buffer, which do not need free.
     *
     * If not need cast:
     * 1. Input local tensors are alloced from input queue, which need free.
     */
    __aicore__ inline void FreeInput(TQue<QuePosition::VECIN, BUFFER_NUM>& queue,
                                     LocalTensor<C>& xLocal)
    {
        if (std::is_same<T, C>::value)
        {
            queue.FreeTensor(xLocal);
        }
    }

    /**
     * If need cast:
     * 1. Alloc local tensor from output queue.
     * 2. Cast from C to T.
     * 3. Put casted local tensor in queue.
     *
     * If not need cast:
     * 1. Only put local tensor in queue.
     *
     */
    __aicore__ inline void CastOutput(TQue<QuePosition::VECOUT, BUFFER_NUM>& queue,
                                      LocalTensor<C>& yLocal, uint32_t len)
    {
        if (std::is_same<T, C>::value)
        {
            queue.EnQue(yLocal);
        }
        else
        {
            LocalTensor<T> yCast = queue.AllocTensor<T>();
            RoundMode roundMode = RoundMode::CAST_NONE;
            // Ref to AscendC cast API: integer targets need explicit rounding.
            if (std::is_same<T, int16_t>::value)
            {
                roundMode = RoundMode::CAST_RINT;
            }
            else if (std::is_same<T, int32_t>::value)
            {
                roundMode = RoundMode::CAST_ROUND;
            }
            Cast(yCast, yLocal, roundMode, len);
            queue.EnQue(yCast);
        }
    }

    // Stage `len` input elements from global memory into the input queue.
    __aicore__ inline void CopyIn(uint32_t offset, uint32_t len)
    {
        LocalTensor<T> xLocal = inputQueue.AllocTensor<T>();
        DataCopy(xLocal, xGM[offset], len);
        inputQueue.EnQue(xLocal);
    }

    // Drain one finished tensor from the output queue back to global memory.
    __aicore__ inline void CopyOut(uint32_t offset, uint32_t len)
    {
        LocalTensor<T> yLocal = outputQueue.DeQue<T>();
        DataCopy(yGM[offset], yLocal, len);
        outputQueue.FreeTensor(yLocal);
    }

    /**
     * AscendC API Compare Wrapper.
     * AscendC Compare level2 API needs input length aligned to 256, process
     * tail data by level0 API.
     */
    __aicore__ inline void CompareWrap(const LocalTensor<uint8_t>& dstLocal,
                                       const LocalTensor<C>& src0Local,
                                       const LocalTensor<C>& src1Local, CMPMODE cmpMode,
                                       uint32_t calCount)
    {
        // Elements total count for one loop inside Compare.
        uint32_t batchCount = 256 / sizeof(C);
        // Tail elements count.
        uint32_t tailCount = calCount % batchCount;

        // Level2 API, calCount should align to 256.
        Compare(dstLocal, src0Local, src1Local, cmpMode, calCount - tailCount);

        // Data blocks are already cut aligned to 256, tail count will be 0 for
        // all process loops except last one.
        if (tailCount != 0)
        {
            BinaryRepeatParams repeatParams = {1, 1, 1, 8, 8, 8};
            uint32_t tailIdx = calCount - tailCount;
            // NOTE(review): sizeof(uint8_t) == 1, so maskIdx == tailIdx here.
            // If Compare packs 8 results per mask byte this looks like it
            // should be tailIdx / 8 — confirm against the AscendC mask layout.
            uint32_t maskIdx = tailIdx / sizeof(uint8_t);
            Compare(dstLocal[maskIdx], src0Local[tailIdx], src1Local[tailIdx], cmpMode, tailCount,
                    1, repeatParams);
        }
    }

    /**
     * AscendC API Select Wrapper.
     * AscendC Select level2 API needs input length aligned to 256, process
     * tail data by level0 API.
     */
    __aicore__ inline void SelectWrap(const LocalTensor<C>& dstLocal,
                                      const LocalTensor<uint8_t>& selMask,
                                      const LocalTensor<C>& src0Local, C src1Local, SELMODE selMode,
                                      uint32_t calCount)
    {
        uint32_t batchCount = 256 / sizeof(C);
        uint32_t tailCount = calCount % batchCount;

        Select(dstLocal, selMask, src0Local, src1Local, selMode, calCount - tailCount);
        if (tailCount != 0)
        {
            BinaryRepeatParams repeatParams = {1, 1, 1, 8, 8, 8};
            uint32_t tailIdx = calCount - tailCount;
            // NOTE(review): same maskIdx question as in CompareWrap above.
            uint32_t maskIdx = tailIdx / sizeof(uint8_t);
            Select(dstLocal[tailIdx], selMask[maskIdx], src0Local[tailIdx], src1Local, selMode,
                   tailCount, 1, repeatParams);
        }
    }

    /**
     * Apply the thresholding rule selected by tilingData->threshType
     * (0..4, mirroring cv::ThresholdTypes; see kernel_tiling_types.h)
     * to xLocal, writing the result into yLocal.
     */
    __aicore__ inline void Threshold(LocalTensor<C>& xLocal, LocalTensor<C>& yLocal, uint32_t len)
    {
        LocalTensor<uint8_t> mask = maskBuffer.Get<uint8_t>();
        // yLocal starts out filled with the threshold for the comparisons.
        Duplicate(yLocal, static_cast<C>(tilingData->thresh), len);
        switch (tilingData->threshType)
        {
        case 0: // THRESH_BINARY
            CompareWrap(mask, xLocal, yLocal, CMPMODE::LE, len);
            Duplicate(yLocal, static_cast<C>(0), len);
            SelectWrap(yLocal, mask, yLocal, static_cast<C>(tilingData->maxVal),
                       SELMODE::VSEL_TENSOR_SCALAR_MODE, len);
            break;
        case 1: // THRESH_BINARY_INV
            CompareWrap(mask, xLocal, yLocal, CMPMODE::GT, len);
            Duplicate(yLocal, static_cast<C>(0), len);
            SelectWrap(yLocal, mask, yLocal, static_cast<C>(tilingData->maxVal),
                       SELMODE::VSEL_TENSOR_SCALAR_MODE, len);
            break;
        case 2: // THRESH_TRUNC
            CompareWrap(mask, xLocal, yLocal, CMPMODE::LE, len);
            SelectWrap(yLocal, mask, xLocal, static_cast<C>(tilingData->thresh),
                       SELMODE::VSEL_TENSOR_SCALAR_MODE, len);
            break;
        case 3: // THRESH_TOZERO
            CompareWrap(mask, xLocal, yLocal, CMPMODE::GT, len);
            SelectWrap(yLocal, mask, xLocal, static_cast<C>(0),
                       SELMODE::VSEL_TENSOR_SCALAR_MODE, len);
            break;
        case 4: // THRESH_TOZERO_INV
            CompareWrap(mask, xLocal, yLocal, CMPMODE::LE, len);
            SelectWrap(yLocal, mask, xLocal, static_cast<C>(0),
                       SELMODE::VSEL_TENSOR_SCALAR_MODE, len);
            break;
        default:
            break;
        }
    }

    TPipe pipe;
    TQue<QuePosition::VECIN, BUFFER_NUM> inputQueue;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outputQueue;
    TBuf<TPosition::VECCALC> InputCastBuffer, outputCastBuffer, maskBuffer;
    GlobalTensor<T> xGM, yGM;
    VectorTiling vecTiling;
    ThresholdOpencvTilingData* tilingData;
};
// Generate one non-template launcher per supported OpenCV depth. AscendC can
// only call inline functions (no function pointers), so threshold_opencv
// below dispatches to these via a switch.
#define LAUNCH_THRESHOLD_KERNEL(NAME, T, C)                                                   \
    __aicore__ inline void launch_threshold_kernel_##NAME(ThresholdOpencvTilingData* tilingData, \
                                                          GM_ADDR x, GM_ADDR y)               \
    {                                                                                         \
        KernelThreshold<T, C> op;                                                             \
        op.Init(tilingData, x, y);                                                            \
        op.Run();                                                                             \
    }

LAUNCH_THRESHOLD_KERNEL(CV_8U, uint8_t, half)   // CV_8U
LAUNCH_THRESHOLD_KERNEL(CV_8S, int8_t, half)    // CV_8S
// CV_16U is not supported.
LAUNCH_THRESHOLD_KERNEL(CV_16S, int16_t, half)  // CV_16S
LAUNCH_THRESHOLD_KERNEL(CV_32S, int32_t, float) // CV_32S
LAUNCH_THRESHOLD_KERNEL(CV_32F, float, float)   // CV_32F
// CV_64F is not supported.
LAUNCH_THRESHOLD_KERNEL(CV_16F, half, half)     // CV_16F
#undef LAUNCH_THRESHOLD_KERNEL

// Map an OpenCV depth name to its generated launcher function.
#define CALL_THRESHOLD_KERNEL(NAME) launch_threshold_kernel_##NAME
/**
 * Kernel entry point for the threshold operator.
 *
 * @param tilingGM device buffer holding a ThresholdOpencvTilingData uploaded
 *                 by the host (see kernel_launch in cann_call.hpp).
 * @param x        input data in global memory.
 * @param y        output data in global memory.
 */
extern "C" __global__ __aicore__ void threshold_opencv(GM_ADDR tilingGM, GM_ADDR x, GM_ADDR y)
{
    ThresholdOpencvTilingData tilingData;
    auto tempTilingGM = (__gm__ uint8_t*)tilingGM;
    auto tempTiling = (uint8_t*)&tilingData;
    // Copy the tiling struct byte-by-byte out of global memory. Use an
    // unsigned index to match the unsigned sizeof() bound (the original
    // int32_t loop variable triggered a signed/unsigned comparison).
    for (size_t i = 0; i < sizeof(ThresholdOpencvTilingData); ++i, ++tempTilingGM, ++tempTiling)
    {
        *tempTiling = *tempTilingGM;
    }

    // AscendC can only call inline functions, function pointer can't be used here.
    // Use Macro and switch case instead. The case values are OpenCV depth
    // codes (tilingData.dtype); CV_16U (2) and CV_64F (6) are unsupported.
    switch (tilingData.dtype)
    {
        case 0:
            CALL_THRESHOLD_KERNEL(CV_8U)(&tilingData, x, y);
            break;
        case 1:
            CALL_THRESHOLD_KERNEL(CV_8S)(&tilingData, x, y);
            break;
        case 3:
            CALL_THRESHOLD_KERNEL(CV_16S)(&tilingData, x, y);
            break;
        case 4:
            CALL_THRESHOLD_KERNEL(CV_32S)(&tilingData, x, y);
            break;
        case 5:
            CALL_THRESHOLD_KERNEL(CV_32F)(&tilingData, x, y);
            break;
        case 7:
            CALL_THRESHOLD_KERNEL(CV_16F)(&tilingData, x, y);
            break;
        case 2: case 6: default: // CV_16U, CV_64F
            break;
    }
    // Clear tiling GM cache manually. (cce compiler bug)
    dcci(tilingGM, 1);
}

@ -0,0 +1,77 @@
#ifndef TILING_KERNEL_H
#define TILING_KERNEL_H

// Under the CPU-side kernel test harness (__CCE_KT_TEST__) the [aicore]
// attribute does not exist, so __aicore__ is defined away.
#ifdef __CCE_KT_TEST__
#define __aicore__
#else
#define __aicore__ [aicore]
#endif

// Round n up to a multiple of `align` (align must be a power of two).
// NOTE(review): this computes (n + align), not (n + align - 1), so an
// already-aligned n is bumped to the NEXT multiple (e.g. 32 -> 64) and
// n == 0 yields `align`. That overshoot keeps fullBlockLength in
// GetBlockLengthAndOffset from ever being zero (avoiding a division by
// zero there) — confirm the overshoot is intentional before changing it.
inline __aicore__ int32_t AlignNCeil(int32_t n, int32_t align) { return ((n + align) & ~(align-1)); }

// Round n down to a multiple of `align` (align must be a power of two).
inline __aicore__ int32_t AlignNFloor(int32_t n, int32_t align) { return (n & ~(align-1)); }

// Double buffering factor for the input/output queues.
constexpr int32_t BUFFER_NUM = 2;
// Usable unified-buffer bytes per core (248 KiB).
constexpr int32_t UB_BUF_LEN = 248 * 1024;

/**
 * Splits a flat element range across AI cores (blocks), then splits each
 * block's share into loops sized to fit the unified buffer.
 * Outputs: blockLength/blockOffset (this core's slice of the total) and
 * loopLength/loopCount/loopTailLength (per-loop processing sizes).
 */
struct VectorTiling {
    // _variableBytesPerElem: total local-memory bytes consumed per element;
    // _align: element granularity each loop length must be a multiple of.
    __aicore__ inline void calculate(uint64_t _totalLength, uint64_t _blockNum,
                                     uint64_t _blockIdx, uint64_t _variableBytesPerElem, uint32_t _align) {
        totalLength = _totalLength;
        blockNum = _blockNum;
        blockIdx = _blockIdx;
        variableBytesPerElem = _variableBytesPerElem;
        blockLength = 0;
        blockOffset = 0;
        align = _align;
        GetBlockLengthAndOffset();
        GetLoopLengthAndCount();
#ifdef __CCE_KT_TEST__
        // Debug trace, only in the CPU test build.
        std::cout << "Block(" << blockIdx << "): BlockLength = " << blockLength
                  << ", BlockOffset = " << blockOffset
                  << ", LoopLength = " << loopLength
                  << ", LoopCount = " << loopCount
                  << ", LoopTailLength = " << loopTailLength << std::endl;
#endif
    }

    __aicore__ inline void GetBlockLengthAndOffset() {
        // Data should Align by 32B.
        uint32_t fullBlockLength = AlignNCeil(totalLength / blockNum, 32);
        // Some core may get no data after Align32 Ceil.
        uint32_t fullBlockNum = totalLength / fullBlockLength;
        uint32_t blockTailLength = totalLength % fullBlockLength;

        if (blockIdx < fullBlockNum) {
            blockLength = fullBlockLength;
            blockOffset = blockIdx * blockLength;
            // Last block must less than full block num.
        } else if (blockTailLength != 0 && blockIdx == fullBlockNum) {
            blockLength = blockTailLength;
            blockOffset = blockIdx * fullBlockLength;
        }
        // NOTE(review): any blockIdx > fullBlockNum keeps the zero
        // blockLength set in calculate(), i.e. that core does no work.
    }

    /**
     * @brief Get length for one loop and loop count.
     * Use as much UB buf as possible.
     */
    __aicore__ inline void GetLoopLengthAndCount() {
        loopLength = AlignNFloor(UB_BUF_LEN / variableBytesPerElem, align);
        loopCount = blockLength / loopLength;
        loopTailLength = blockLength - (loopLength * loopCount);
    }

    uint64_t totalLength;          // total elements across all cores
    uint64_t blockNum;             // number of cores (blocks)
    uint64_t blockIdx;             // this core's index
    uint64_t variableBytesPerElem; // local-memory bytes needed per element
    uint32_t blockLength;          // elements assigned to this core
    uint32_t blockOffset;          // this core's start offset in the stream
    uint32_t loopLength;           // elements processed per loop
    uint32_t loopCount;            // number of full loops
    uint32_t loopTailLength;       // leftover elements after full loops
    uint32_t align;                // loopLength granularity (elements)
};

#endif // TILING_KERNEL_H

@ -0,0 +1,7 @@
#ifndef ASCENDC_KERNELS_H
#define ASCENDC_KERNELS_H

// Host-side entry points for the AscendC kernels: the tiling structs shared
// with device code, and the aclrtlaunch_* launcher header generated by
// ascendc_library() at build time.
#include "../../ascendc_kernels/kernel_tiling_types.h"
#include "aclrtlaunch_threshold_opencv.h"

#endif //ASCENDC_KERNELS_H

@ -9,7 +9,9 @@
#include <set>
#include <string>
#include <acl/acl_base.h>
#include "opencv2/cann.hpp"
#include "cann.hpp"
#include "stream_accessor.hpp"
#include "ascendc_kernels.hpp"
class aclopAttr;
@ -17,6 +19,15 @@ namespace cv
{
namespace cann
{
CV_EXPORTS void checkAclError(aclError err, const char* file, const int line, const char* func);
void checkAclPtr(void* ptr, const char* file, const int line, const char* func);
#define CV_ACL_SAFE_CALL(expr) checkAclError((expr), __FILE__, __LINE__, CV_Func)
#define CV_ACL_SAFE_CALL_PTR(expr) \
({ \
auto ptr = (expr); \
checkAclPtr(ptr, __FILE__, __LINE__, CV_Func); \
ptr; \
})
// Wrapper for functions in CANN; callers should not call CANN's API directly, but should call the
// functions provided in cann_call.
void aclrtMallocWarpper(void** data, size_t size);
@ -39,7 +50,7 @@ void aclrtMemsetWarpper(std::shared_ptr<uchar>& ptr, int32_t value, size_t count
//! Type mapping between opencv and cann.
aclDataType getACLType(int opencvdepth);
//! Malloc and upload raw data to devices.
std::shared_ptr<uchar> mallocAndUpload(const void* data, size_t size, AscendStream& stream,
CV_EXPORTS std::shared_ptr<uchar> mallocAndUpload(const void* data, size_t size, AscendStream& stream,
AscendMat::Allocator* allocator);
/**
* @brief Warpper of CANN streams.
@ -151,6 +162,19 @@ public:
OperatorRunner& run(AscendStream& stream);
};
template <typename KERNEL_TYPE, typename TILING_TYPE, typename... ARGS>
void kernel_launch(KERNEL_TYPE kernel, AscendStream& stream, TILING_TYPE& tiling, ARGS... args)
{
std::shared_ptr<uchar> tilingDevice =
mallocAndUpload(&tiling, sizeof(TILING_TYPE), stream, AscendMat::defaultAllocator());
aclrtStream rawStream = AscendStreamAccessor::getStream(stream);
CV_ACL_SAFE_CALL(kernel(1, rawStream, tilingDevice.get(), args...));
if (rawStream == nullptr)
{
stream.waitForCompletion();
}
}
} // namespace cann
} // namespace cv

@ -207,5 +207,25 @@ PERF_TEST_P(CPU, MAT_BITWISE_NOT_MAT, testing::Combine(TYPICAL_ASCEND_MAT_SIZES,
SANITY_CHECK_NOTHING();
}
// Benchmark the AscendC threshold kernel on device-resident data.
PERF_TEST_P(NPU, THRESHOLD_ASCENDC, testing::Combine(TYPICAL_ASCEND_MAT_SIZES, Values(CV_8U, CV_16S, CV_32F)))
{
    Mat input(GET_PARAM(0), GET_PARAM(1));
    AscendMat devDst;
    AscendMat devSrc;
    devSrc.upload(input);
    declare.in(input, WARMUP_RNG);
    TEST_CYCLE_N(10) { cv::cann::threshold(devSrc, devDst, 100.0, 255.0, cv::THRESH_BINARY); }
    SANITY_CHECK_NOTHING();
}
// CPU baseline for the threshold benchmark above.
PERF_TEST_P(CPU, THRESHOLD, testing::Combine(TYPICAL_ASCEND_MAT_SIZES, Values(CV_8U, CV_16S, CV_32F)))
{
    Mat input(GET_PARAM(0), GET_PARAM(1));
    Mat output;
    declare.in(input, WARMUP_RNG);
    TEST_CYCLE_N(10) { cv::threshold(input, output, 100.0, 255.0, cv::THRESH_BINARY); }
    SANITY_CHECK_NOTHING();
}
} // namespace
} // namespace opencv_test

@ -23,7 +23,11 @@ std::shared_ptr<uchar> DefaultAllocator::allocate(size_t size)
bool DefaultAllocator::allocate(cv::cann::AscendMat* mat, int rows, int cols, size_t elemSize)
{
mat->data = allocate(elemSize * cols * rows);
size_t totalBytes = elemSize * cols * rows;
// align by 32B.
totalBytes = ((totalBytes + 32) & ~31);
mat->data = allocate(totalBytes);
mat->step = cols * elemSize;
return true;

@ -11,7 +11,7 @@ namespace cv
namespace cann
{
/*******************************Acl Error Checker*****************************/
static inline void checkAclError(aclError err, const char* file, const int line, const char* func)
void checkAclError(aclError err, const char* file, const int line, const char* func)
{
if (ACL_SUCCESS != err)
{
@ -20,7 +20,7 @@ static inline void checkAclError(aclError err, const char* file, const int line,
}
}
static inline void checkAclPtr(void* ptr, const char* file, const int line, const char* func)
void checkAclPtr(void* ptr, const char* file, const int line, const char* func)
{
if (nullptr == ptr)
{
@ -29,14 +29,6 @@ static inline void checkAclPtr(void* ptr, const char* file, const int line, cons
}
}
#define CV_ACL_SAFE_CALL(expr) checkAclError((expr), __FILE__, __LINE__, CV_Func)
#define CV_ACL_SAFE_CALL_PTR(expr) \
({ \
auto ptr = (expr); \
checkAclPtr(ptr, __FILE__, __LINE__, CV_Func); \
ptr; \
})
/******************************Acl Runtime Warpper****************************/
void aclrtMallocWarpper(void** data, size_t size)
{

@ -3,6 +3,7 @@
// of this distribution and at http://opencv.org/license.html.
#include "precomp.hpp"
namespace cv
{
namespace cann
@ -110,8 +111,8 @@ static void convert(const Scalar& src, Scalar& dst, AscendStream& stream)
}
template <typename T1, typename T2>
static void arithm_op(const T1& src1, const T2& src2, AscendMat& dst, const AscendMat& mask, float scale,
int dtype, const char* op, AscendStream& stream)
static void arithm_op(const T1& src1, const T2& src2, AscendMat& dst, const AscendMat& mask,
float scale, int dtype, const char* op, AscendStream& stream)
{
T1 castedSrc1;
T2 castedSrc2;
@ -170,8 +171,9 @@ static void arithm_op(const T1& src1, const T2& src2, AscendMat& dst, const Asce
}
}
static void arithm_op(const InputArray _src1, const InputArray _src2, OutputArray _dst, const InputArray _mask,
float scale, int dtype, const char* op, AscendStream& stream)
static void arithm_op(const InputArray _src1, const InputArray _src2, OutputArray _dst,
const InputArray _mask, float scale, int dtype, const char* op,
AscendStream& stream)
{
const bool isScalar1 = (_src1.kind() == _InputArray::MATX);
const bool isScalar2 = (_src2.kind() == _InputArray::MATX);
@ -213,56 +215,54 @@ static void arithm_op(const InputArray _src1, const InputArray _src2, OutputArra
}
// In order to supply more interfaces, different function declarations should be done.
void add(const InputArray src1, const InputArray src2, OutputArray dst, const InputArray mask, int dtype,
AscendStream& stream)
void add(const InputArray src1, const InputArray src2, OutputArray dst, const InputArray mask,
int dtype, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, dtype, "Add", stream);
}
void add(const AscendMat& src1, const AscendMat& src2, AscendMat& dst, const AscendMat& mask, int dtype,
AscendStream& stream)
void add(const AscendMat& src1, const AscendMat& src2, AscendMat& dst, const AscendMat& mask,
int dtype, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, dtype, "Add", stream);
}
void add(const AscendMat& src1, const Scalar& src2, AscendMat& dst, const AscendMat& mask, int dtype,
AscendStream& stream)
void add(const AscendMat& src1, const Scalar& src2, AscendMat& dst, const AscendMat& mask,
int dtype, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, dtype, "Add", stream);
}
void add(const Scalar& src1, const AscendMat& src2, AscendMat& dst, const AscendMat& mask, int dtype,
AscendStream& stream)
void add(const Scalar& src1, const AscendMat& src2, AscendMat& dst, const AscendMat& mask,
int dtype, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, dtype, "Add", stream);
}
void subtract(const InputArray src1, const InputArray src2, OutputArray dst, const InputArray mask, int dtype,
AscendStream& stream)
void subtract(const InputArray src1, const InputArray src2, OutputArray dst, const InputArray mask,
int dtype, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, dtype, "Sub", stream);
}
void subtract(const AscendMat& src1, const AscendMat& src2, AscendMat& dst, const AscendMat& mask, int dtype,
AscendStream& stream)
void subtract(const AscendMat& src1, const AscendMat& src2, AscendMat& dst, const AscendMat& mask,
int dtype, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, dtype, "Sub", stream);
}
void subtract(const AscendMat& src1, const Scalar& src2, AscendMat& dst, const AscendMat& mask, int dtype,
AscendStream& stream)
void subtract(const AscendMat& src1, const Scalar& src2, AscendMat& dst, const AscendMat& mask,
int dtype, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, dtype, "Sub", stream);
}
void subtract(const Scalar& src1, const AscendMat& src2, AscendMat& dst, const AscendMat& mask, int dtype,
AscendStream& stream)
void subtract(const Scalar& src1, const AscendMat& src2, AscendMat& dst, const AscendMat& mask,
int dtype, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, dtype, "Sub", stream);
}
void multiply(const InputArray src1, const InputArray src2, OutputArray dst, float scale, int dtype,
AscendStream& stream)
{
@ -287,7 +287,6 @@ void multiply(const Scalar& src1, const AscendMat& src2, AscendMat& dst, float s
arithm_op(src1, src2, dst, AscendMat(), scale, dtype, "Mul", stream);
}
void divide(const InputArray src1, const InputArray src2, OutputArray dst, float scale, int dtype,
AscendStream& stream)
{
@ -312,15 +311,14 @@ void divide(const Scalar& src1, const AscendMat& src2, AscendMat& dst, float sca
arithm_op(src1, src2, dst, AscendMat(), scale, dtype, "RealDiv", stream);
}
void bitwise_and(const InputArray src1, const InputArray src2, OutputArray dst, const InputArray mask,
AscendStream& stream)
void bitwise_and(const InputArray src1, const InputArray src2, OutputArray dst,
const InputArray mask, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, -1, "BitwiseAnd", stream);
}
void bitwise_and(const AscendMat& src1, const AscendMat& src2, AscendMat& dst, const AscendMat& mask,
AscendStream& stream)
void bitwise_and(const AscendMat& src1, const AscendMat& src2, AscendMat& dst,
const AscendMat& mask, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, -1, "BitwiseAnd", stream);
}
@ -337,9 +335,8 @@ void bitwise_and(const Scalar& src1, const AscendMat& src2, AscendMat& dst, cons
arithm_op(src1, src2, dst, mask, 1, -1, "BitwiseAnd", stream);
}
void bitwise_or(const InputArray src1, const InputArray src2, OutputArray dst, const InputArray mask,
AscendStream& stream)
void bitwise_or(const InputArray src1, const InputArray src2, OutputArray dst,
const InputArray mask, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, -1, "BitwiseOr", stream);
}
@ -362,15 +359,14 @@ void bitwise_or(const Scalar& src1, const AscendMat& src2, AscendMat& dst, const
arithm_op(src1, src2, dst, mask, 1, -1, "BitwiseOr", stream);
}
void bitwise_xor(const InputArray src1, const InputArray src2, OutputArray dst, const InputArray mask,
AscendStream& stream)
void bitwise_xor(const InputArray src1, const InputArray src2, OutputArray dst,
const InputArray mask, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, -1, "BitwiseXor", stream);
}
void bitwise_xor(const AscendMat& src1, const AscendMat& src2, AscendMat& dst, const AscendMat& mask,
AscendStream& stream)
void bitwise_xor(const AscendMat& src1, const AscendMat& src2, AscendMat& dst,
const AscendMat& mask, AscendStream& stream)
{
arithm_op(src1, src2, dst, mask, 1, -1, "BitwiseXor", stream);
}
@ -387,7 +383,6 @@ void bitwise_xor(const Scalar& src1, const AscendMat& src2, AscendMat& dst, cons
arithm_op(src1, src2, dst, mask, 1, -1, "BitwiseXor", stream);
}
void bitwise_not(const InputArray src, OutputArray dst, const InputArray mask, AscendStream& stream)
{
arithm_op(src, noArray(), dst, mask, 1, -1, "Invert", stream);
@ -398,9 +393,8 @@ void bitwise_not(const AscendMat& src, AscendMat& dst, const AscendMat& mask, As
arithm_op(src, AscendMat(), dst, mask, 1, -1, "Invert", stream);
}
void addWeighted(const AscendMat& src1, double alpha, const AscendMat& src2, double beta, double gamma,
AscendMat& dst, int dtype, AscendStream& stream)
void addWeighted(const AscendMat& src1, double alpha, const AscendMat& src2, double beta,
double gamma, AscendMat& dst, int dtype, AscendStream& stream)
{
if (dtype < 0)
dtype = src1.depth();
@ -421,8 +415,8 @@ void addWeighted(const AscendMat& src1, double alpha, const AscendMat& src2, dou
arithm_op(srcWeightedSumRet, (float)gamma, dst, "Adds", stream);
}
void addWeighted(const InputArray _src1, double alpha, const InputArray _src2, double beta, double gamma,
OutputArray _dst, int dtype, AscendStream& stream)
void addWeighted(const InputArray _src1, double alpha, const InputArray _src2, double beta,
double gamma, OutputArray _dst, int dtype, AscendStream& stream)
{
AscendMat src1, src2, dst;
src1.upload(_src1, stream);
@ -442,45 +436,23 @@ double threshold(const AscendMat& src, AscendMat& dst, double thresh, double max
dst.create(src.rows, src.cols, src.type());
OperatorRunner runner;
runner.setOp("Threshold")
.addInput(src, "x")
.addOutput(threshMat, "y")
.addAttr((float)thresh, "threshold")
.run(stream);
// THRESH_*_INV, THRESH_TRUNC need a inverse threshMat.
// THRESH_BINARY_INV = 1, THRESH_TRUNC = 2, THRESH_TOZERO_INV = 4,
if (type == 1 || type == 2 || type == 4)
if (src.depth() == CV_8U || src.depth() == CV_8S || src.depth() == CV_16S ||
src.depth() == CV_32S || src.depth() == CV_32F || src.depth() == CV_16F)
{
AscendMat threshInvMat(src.size(), src.type());
AscendMat ones(src.size(), src.type());
Scalar s(1, 1, 1, 1);
ones.setTo(s, stream);
arithm_op(ones, threshMat, threshInvMat, "Sub", stream);
if (type == 1)
arithm_op(threshInvMat, (float)maxval, dst, "Muls", stream);
else if (type == 2)
{
AscendMat ToZeroInvMat(src.size(), src.type());
AscendMat TruncMat(src.size(), src.type());
arithm_op(threshInvMat, src, ToZeroInvMat, "Mul", stream);
arithm_op(threshMat, (float)thresh, TruncMat, "Muls", stream);
arithm_op(ToZeroInvMat, TruncMat, dst, "Add", stream);
}
else
arithm_op(threshInvMat, src, dst, "Mul", stream);
ThresholdOpencvTilingData tiling;
tiling.maxVal = maxval;
tiling.thresh = thresh;
// AscendMat memory will be aligned to 32B, so it's safe to set totalLength a little bigger.
size_t totalBytes = src.rows * src.cols * src.channels();
tiling.totalLength = ALIGN_UP(totalBytes, 32);
tiling.threshType = type;
tiling.dtype = src.depth();
kernel_launch(aclrtlaunch_threshold_opencv, stream, tiling, src.data.get(), dst.data.get());
}
else
{
if (type == 0) /* THRESH_BINARY = 0 */
arithm_op(threshMat, (float)maxval, dst, "Muls", stream);
else if (type == 3) /* THRESH_TOZERO = 3 */
arithm_op(threshMat, src, dst, "Mul", stream);
else
CV_Error(Error::StsError, "Unknown/unsupported threshold type");
}
CV_Error(Error::StsUnsupportedFormat, "");
return thresh;
}

@ -10,5 +10,7 @@
#include "opencv2/cann_call.hpp"
#include "opencv2/cann_interface.hpp"
#include "opencv2/cann_private.hpp"
#include "opencv2/ascendc_kernels.hpp"
#define ALIGN_UP(num, align) (((num) + (align) - 1) & ~((align) - 1))
#endif /* __OPENCV_PRECOMP_H__ */

@ -678,7 +678,6 @@ TEST(ELEMENTWISE_OP, MAT_THRESHOLD)
for (int i = 0; i <= 4; i++)
{
cv::threshold(cpuMat, cpuOpRet, 128, 250, i);
// TODO find the reason empty AscendMat is not continuous.
cv::cann::threshold(ascendMat16F, aclOpRet, 128, 250, i);
aclOpRet.convertTo(aclOpRet16S, CV_16S);
aclOpRet16S.download(checker);
@ -693,5 +692,37 @@ TEST(ELEMENTWISE_OP, MAT_THRESHOLD)
cv::cann::resetDevice();
}
// Compare cv::cann::threshold (AscendC path) against cv::threshold on the
// CPU for every threshold type and every depth both backends support.
TEST(ELEMENTWISE_OP, MAT_THRESHOLD_ASCENDC)
{
    cv::cann::setDevice(DEVICE_ID);
    Mat cpuRet, npuRet;
    AscendMat npuImg, npuTmpMat;

    // opencv do not support CV_8S, CV_32S, CV_16F
    // ascend do not support CV_16U, CV_64F
    const uint8_t depths[] = {CV_8U, CV_16S, CV_32F};
    const double thresh = 90.5;
    const double maxVal = 85.2;

    for (uint type = 0; type <= 4; type++)
    {
        for (uint d = 0; d < sizeof(depths) / sizeof(depths[0]); d++)
        {
            Mat img = randomMat(10, 10, CV_MAKETYPE(depths[d], 3), 0.0f, 128.0f);
            npuImg.upload(img);
            npuTmpMat.create(npuImg.rows, npuImg.cols, npuImg.type());

            cv::threshold(img, cpuRet, thresh, maxVal, type);
            cv::cann::threshold(npuImg, npuTmpMat, thresh, maxVal, type);

            npuTmpMat.download(npuRet);
            EXPECT_MAT_NEAR(cpuRet, npuRet, 10.0f);
        }
    }
    cv::cann::resetDevice();
}
} // namespace
} // namespace opencv_test

@ -0,0 +1,51 @@
#include "test_precomp.hpp"
#include "opencv2/cann_call.hpp"
namespace opencv_test
{
namespace
{
// Launch the raw threshold_opencv kernel (bypassing cv::cann::threshold)
// and compare its output against the CPU cv::threshold reference.
TEST(ASCENDC_KERNEL, THRESHOLD)
{
    cv::cann::setDevice(DEVICE_ID);
    Mat cpuRet, npuRet;
    AscendMat npuImg, npuTmpMat;

    // opencv do not support CV_8S, CV_32S, CV_16F
    // ascend do not support CV_16U, CV_64F
    const uint8_t depths[] = {CV_8U, CV_16S, CV_32F};
    const double thresh = 90.5;
    const double maxVal = 85.2;

    for (uint type = 0; type <= 4; type++)
    {
        for (uint d = 0; d < sizeof(depths) / sizeof(depths[0]); d++)
        {
            Mat img = randomMat(10, 10, CV_MAKETYPE(depths[d], 3), 0.0f, 128.0f);
            npuImg.upload(img);
            npuTmpMat.create(npuImg.rows, npuImg.cols, npuImg.type());

            cv::threshold(img, cpuRet, thresh, maxVal, type);

            ThresholdOpencvTilingData tiling;
            tiling.maxVal = maxVal;
            tiling.thresh = thresh;
            size_t totalBytes = img.rows * img.cols * img.channels();
            // AscendMat memory will be aligned to 32B, so it's safe to set
            // totalLength a little bigger than the true element count.
            tiling.totalLength = ((totalBytes + 32) & ~31);
            tiling.threshType = type;
            tiling.dtype = depths[d];
            kernel_launch(aclrtlaunch_threshold_opencv, AscendStream::Null(), tiling,
                          npuImg.data.get(), npuTmpMat.data.get());

            npuTmpMat.download(npuRet);
            EXPECT_MAT_NEAR(cpuRet, npuRet, 10.0f);
        }
    }
    cv::cann::resetDevice();
}
} // namespace
} // namespace opencv_test

@ -9,6 +9,7 @@
#include "opencv2/cann.hpp"
#include "opencv2/ts/cuda_test.hpp"
#include "opencv2/cann_interface.hpp"
#include "opencv2/ascendc_kernels.hpp"
using namespace cv;
using namespace cv::cann;

Loading…
Cancel
Save