Merge pull request #21036 from fengyuentau:timvx_backend_support
dnn: TIM-VX NPU backend support

* Add TimVX NPU backend for DNN module.
* Use official branch from tim-vx repo; fix detecting VIV SDK.

Co-authored-by: fytao <yuantao.feng@outlook.com>
parent 9390c56831
commit 7b582b71ba
37 changed files with 2982 additions and 30 deletions
@ -0,0 +1,73 @@
set(TIMVX_COMMIT_HASH "1d9c7ab941b3d8d9c4d28d80058402725731e3d6")
set(OCV_TIMVX_DIR "${OpenCV_BINARY_DIR}/3rdparty/libtim-vx")
set(OCV_TIMVX_SOURCE_PATH "${OCV_TIMVX_DIR}/TIM-VX-${TIMVX_COMMIT_HASH}")

# Download TIM-VX source code
if(EXISTS "${OCV_TIMVX_SOURCE_PATH}")
    message(STATUS "TIM-VX: Use cache of TIM-VX source code at ${OCV_TIMVX_SOURCE_PATH}")
    set(TIMVX_FOUND ON)
else()
    set(OCV_TIMVX_FILENAME "${TIMVX_COMMIT_HASH}.zip")
    set(OCV_TIMVX_URL "https://github.com/VeriSilicon/TIM-VX/archive/")
    set(timvx_zip_md5sum 92619cc4498014ac7a09834d5e33ebd5)

    ocv_download(FILENAME ${OCV_TIMVX_FILENAME}
                 HASH ${timvx_zip_md5sum}
                 URL "${OCV_TIMVX_URL}"
                 DESTINATION_DIR "${OCV_TIMVX_DIR}"
                 ID "TIM-VX"
                 STATUS res
                 UNPACK RELATIVE_URL)
    if(res)
        set(TIMVX_FOUND ON)
        message(STATUS "TIM-VX: Source code downloaded at ${OCV_TIMVX_SOURCE_PATH}.")
    else()
        set(TIMVX_FOUND OFF)
        message(STATUS "TIM-VX: Failed to download source code from github. Turning off TIMVX_FOUND")
        return()
    endif()
endif()

# set VIVANTE SDK especially for x86_64 which comes along with TIM-VX source code
if(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64)
    set(VIVANTE_SDK_DIR "${OCV_TIMVX_SOURCE_PATH}/prebuilt-sdk/x86_64_linux")
    message(STATUS "TIM-VX: Build from source using prebuilt x86_64 VIVANTE SDK.")
endif()

# Verify if requested VIVANTE SDK libraries are all found
find_vivante_sdk_libs(missing ${VIVANTE_SDK_DIR})
if(missing)
    message(STATUS "TIM-VX: Failed to find ${missing} in ${VIVANTE_SDK_DIR}/lib. Turning off TIMVX_VIV_FOUND")
    set(TIMVX_VIV_FOUND OFF)
else()
    message(STATUS "TIM-VX: dependent VIVANTE SDK libraries are found at ${VIVANTE_SDK_DIR}/lib.")
    set(TIMVX_VIV_FOUND ON)
endif()

if(TIMVX_VIV_FOUND)
    # vars used by TIM-VX CMake scripts
    set(EXTERNAL_VIV_SDK "${VIVANTE_SDK_DIR}" CACHE INTERNAL "" FORCE)
    set(VIV_SDK_DRIVER_PREFIX "lib" CACHE INTERNAL "" FORCE)
endif()

if(TIMVX_FOUND AND TIMVX_VIV_FOUND)
    set(BUILD_TIMVX ON)
else()
    return()
endif()

if(BUILD_TIMVX)
    set(HAVE_TIMVX 1)

    ocv_warnings_disable(CMAKE_C_FLAGS -Wunused-parameter -Wstrict-prototypes -Wundef -Wsign-compare -Wmissing-prototypes -Wmissing-declarations -Wstrict-aliasing -Wunused-but-set-variable -Wmaybe-uninitialized -Wshadow -Wsuggest-override -Wswitch)
    ocv_warnings_disable(CMAKE_CXX_FLAGS -Wunused-parameter -Wstrict-prototypes -Wundef -Wsign-compare -Wunused-but-set-variable -Wshadow -Wsuggest-override -Wmissing-declarations -Wswitch)

    set(TIMVX_INC_DIR "${OCV_TIMVX_SOURCE_PATH}/include" CACHE INTERNAL "TIM-VX include directory")
    if(EXISTS "${OCV_TIMVX_SOURCE_PATH}/CMakeLists.txt")
        add_subdirectory("${OCV_TIMVX_SOURCE_PATH}" "${OCV_TIMVX_DIR}/build")
    else()
        message(WARNING "TIM-VX: Missing 'CMakeLists.txt' in the source code: ${OCV_TIMVX_SOURCE_PATH}")
    endif()
    ocv_install_target(tim-vx EXPORT OpenCVModules ARCHIVE DESTINATION ${OPENCV_3P_LIB_INSTALL_PATH} COMPONENT dev)
    set(TIMVX_LIB "tim-vx")
endif()
@ -0,0 +1,69 @@
set(TIMVX_INSTALL_DIR "" CACHE PATH "Path to libtim-vx installation")
set(VIVANTE_SDK_DIR "" CACHE PATH "Path to VIVANTE SDK needed by TIM-VX.")
set(VIVANTE_SDK_LIB_CANDIDATES "OpenVX;VSC;GAL;ArchModelSw;NNArchPerf" CACHE STRING "VIVANTE SDK library candidates")

# Ensure VIVANTE SDK library candidates are present in given search path
function(find_vivante_sdk_libs _viv_notfound _viv_search_path)
    foreach(one ${VIVANTE_SDK_LIB_CANDIDATES})
        # NO_DEFAULT_PATH is used to ensure VIVANTE SDK libs are from one only source
        find_library(VIV_${one}_LIB ${one} PATHS "${_viv_search_path}/lib" NO_DEFAULT_PATH)
        if(NOT VIV_${one}_LIB)
            list(APPEND _viv_notfound_list ${one})
        endif()
    endforeach()
    set(${_viv_notfound} ${_viv_notfound_list} PARENT_SCOPE)
endfunction()

# Default value for VIVANTE_SDK_DIR: /usr
if(NOT VIVANTE_SDK_DIR)
    set(VIVANTE_SDK_DIR "/usr")
endif()

# Environment variable VIVANTE_SDK_DIR overrides the one in this script
if(DEFINED ENV{VIVANTE_SDK_DIR})
    set(VIVANTE_SDK_DIR $ENV{VIVANTE_SDK_DIR})
    message(STATUS "TIM-VX: Load VIVANTE_SDK_DIR from system environment: ${VIVANTE_SDK_DIR}")
endif()

# Compile with pre-installed TIM-VX, or compile together with TIM-VX from source
if(TIMVX_INSTALL_DIR AND NOT BUILD_TIMVX)
    message(STATUS "TIM-VX: Use binaries at ${TIMVX_INSTALL_DIR}")
    set(BUILD_TIMVX OFF)

    set(TIMVX_INC_DIR "${TIMVX_INSTALL_DIR}/include" CACHE INTERNAL "TIM-VX include directory")
    find_library(TIMVX_LIB "tim-vx" PATHS "${TIMVX_INSTALL_DIR}/lib")
    if(TIMVX_LIB)
        set(TIMVX_FOUND ON)
    else()
        set(TIMVX_FOUND OFF)
    endif()

    # Verify if requested VIVANTE SDK libraries are all found
    find_vivante_sdk_libs(missing ${VIVANTE_SDK_DIR})
    if(missing)
        message(STATUS "TIM-VX: Failed to find ${missing} in ${VIVANTE_SDK_DIR}/lib. Turning off TIMVX_VIV_FOUND")
        set(TIMVX_VIV_FOUND OFF)
    else()
        message(STATUS "TIM-VX: dependent VIVANTE SDK libraries are found at ${VIVANTE_SDK_DIR}/lib.")
        set(TIMVX_VIV_FOUND ON)
    endif()
else()
    message(STATUS "TIM-VX: Build from source")
    include("${OpenCV_SOURCE_DIR}/3rdparty/libtim-vx/tim-vx.cmake")
endif()

if(TIMVX_FOUND AND TIMVX_VIV_FOUND)
    set(HAVE_TIMVX 1)

    message(STATUS "TIM-VX: Found TIM-VX includes: ${TIMVX_INC_DIR}")
    message(STATUS "TIM-VX: Found TIM-VX library: ${TIMVX_LIB}")
    set(TIMVX_LIBRARY ${TIMVX_LIB})
    set(TIMVX_INCLUDE_DIR ${TIMVX_INC_DIR})

    message(STATUS "TIM-VX: Found VIVANTE SDK libraries: ${VIVANTE_SDK_DIR}/lib")
    link_directories(${VIVANTE_SDK_DIR}/lib)
endif()

MARK_AS_ADVANCED(
    TIMVX_INC_DIR
    TIMVX_LIB
)
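Once the checks above succeed and HAVE_TIMVX is defined, an application can opt into the new backend through the regular dnn backend/target API. A minimal sketch, not part of the patch: the model path and input size are hypothetical, while DNN_BACKEND_TIMVX and DNN_TARGET_NPU are the identifiers this PR introduces.

#include <opencv2/dnn.hpp>

int main()
{
    // Hypothetical int8-quantized model; any model whose layers the
    // TIM-VX backend supports would be handled the same way.
    cv::dnn::Net net = cv::dnn::readNet("model_int8.onnx");
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_TIMVX);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_NPU);

    int sizes[] = {1, 3, 224, 224};                   // assumed NCHW input blob
    cv::Mat blob = cv::Mat::zeros(4, sizes, CV_32F);
    net.setInput(blob);
    cv::Mat out = net.forward();                      // layers supported by TIM-VX run on the NPU
    return 0;
}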
@ -0,0 +1,931 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
//
|
||||
// Copyright (C) 2019-2021, Shenzhen Institute of Artificial Intelligence and
|
||||
// Robotics for Society, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
|
||||
#include "precomp.hpp" |
||||
#include <opencv2/dnn/shape_utils.hpp> |
||||
#include "op_timvx.hpp" |
||||
#include "net_impl.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
#ifdef HAVE_TIMVX |
||||
|
||||
CV__DNN_INLINE_NS_BEGIN |
||||
|
||||
// update all consumers
||||
void Net::Impl::tvUpdateConfictMap(int graphIndex, LayerData& ld, std::vector<std::vector<int> >& graphConflictMap) |
||||
{ |
||||
if (ld.consumers.empty()) |
||||
return; |
||||
for (int i = 0; i < ld.consumers.size(); i++) |
||||
{ |
||||
LayerData &consumerld = layers[ld.consumers[i].lid]; |
||||
std::vector<int>::iterator it = std::find(graphConflictMap[ld.consumers[i].lid].begin(), |
||||
graphConflictMap[ld.consumers[i].lid].end(), graphIndex); |
||||
|
||||
if (it == graphConflictMap[ld.consumers[i].lid].end()) |
||||
{ |
||||
graphConflictMap[ld.consumers[i].lid].push_back(graphIndex); |
||||
tvUpdateConfictMap(graphIndex, consumerld, graphConflictMap); |
||||
} |
||||
else |
||||
continue; |
||||
} |
||||
} |
||||
|
||||
// Convert TRANSIENT to OUTPUT
|
||||
void Net::Impl::tvConvertToOutputNode(const LayerData& ld, Ptr<TimVXBackendWrapper>& targetWrap) |
||||
{ |
||||
// find right layer.
|
||||
for (auto& inputLayerId : ld.inputLayersId) |
||||
{ |
||||
LayerData &inputld = layers[inputLayerId]; |
||||
auto itWrap = std::find(inputld.outputBlobsWrappers.begin(), |
||||
inputld.outputBlobsWrappers.end(), targetWrap); |
||||
if (itWrap != inputld.outputBlobsWrappers.end()) |
||||
{ |
||||
auto outputWrap = (*itWrap).dynamicCast<TimVXBackendWrapper>(); |
||||
if (!outputWrap->isTensor()) |
||||
continue; |
||||
|
||||
auto inputNode = inputld.backendNodes[DNN_BACKEND_TIMVX].dynamicCast<TimVXBackendNode>(); |
||||
if (!inputNode->isLast && inputNode->opIndex != -1) |
||||
{ |
||||
CV_Assert(outputWrap->getTensorAttr() == tim::vx::TRANSIENT); |
||||
// set last
|
||||
inputNode->isLast = true; |
||||
|
||||
auto shapeType = getShapeTypeFromMat(outputWrap->getMat()); |
||||
auto outQuant = outputWrap->getTensorQuantization(); |
||||
|
||||
outputWrap->setTensorShape(shapeType); |
||||
outputWrap->createTensor(inputNode->tvGraph->graph, |
||||
tim::vx::TensorAttribute::OUTPUT, outQuant); |
||||
int outIndex = inputNode->tvGraph->addWrapper(outputWrap); |
||||
inputNode->outputIndexList.clear(); |
||||
inputNode->outputIndexList.push_back(outIndex); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
void Net::Impl::initTimVXBackend() |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_Assert(preferableBackend == DNN_BACKEND_TIMVX); |
||||
|
||||
// Build TimVX graphs from sets of layers that support the TimVX backend.
// Split the whole model into several TimVX graphs if some layers are not implemented by the TimVX backend.
||||
if (!haveTimVX()) |
||||
return; |
||||
|
||||
// Allocate graphConflictMap
|
||||
if (timVxInfo.graphConflictMap.empty()) |
||||
timVxInfo.graphConflictMap.resize(layers.size()); |
||||
|
||||
auto it = layers.begin(); |
||||
bool isLast = false; // If the node is the last node in current tvGraph.
|
||||
|
||||
for (; it != layers.end(); it++) |
||||
{ |
||||
isLast = false; |
||||
LayerData &ld = it->second; |
||||
if(ld.skip) |
||||
continue; |
||||
Ptr<Layer> layer = ld.layerInstance; |
||||
if (!layer->supportBackend(preferableBackend)) |
||||
{ |
||||
continue; |
||||
} |
||||
|
||||
// If a layer has more than one consumer, set isLast to true.
// For now, the TimVX backend splits multiple branches into multiple tvGraphs.
||||
if (ld.consumers.size() == 0) |
||||
{ |
||||
isLast = true; |
||||
} |
||||
else if(ld.consumers.size() == 1) |
||||
{ |
||||
LayerData* consumerld = &layers[ld.consumers[0].lid]; |
||||
|
||||
while (consumerld) |
||||
{ |
||||
if (consumerld->skip) |
||||
{ |
||||
if (consumerld->consumers.size() == 1) |
||||
{ |
||||
int nextLayerId = consumerld->consumers[0].lid; |
||||
consumerld = &layers[nextLayerId]; |
||||
} |
||||
else |
||||
{ |
||||
isLast = true; |
||||
break; |
||||
} |
||||
} |
||||
else |
||||
{ |
||||
break; |
||||
} |
||||
} |
||||
Ptr<Layer>& consumerLayer = consumerld->layerInstance; |
||||
|
||||
if (!isLast && !consumerLayer->supportBackend(preferableBackend)) |
||||
{ |
||||
isLast = true; |
||||
} |
||||
} |
||||
else |
||||
{ |
||||
// If there are multiple consumers and only one of them is supported.
||||
int tvSupportNum = 0; |
||||
for (int i = 0; i<ld.consumers.size(); i++) |
||||
{ |
||||
LayerData* consumerld = &layers[ld.consumers[0].lid]; |
||||
|
||||
while (consumerld) |
||||
{ |
||||
if (consumerld->skip) |
||||
{ |
||||
if (consumerld->consumers.size() == 1) |
||||
{ |
||||
int nextLayerId = consumerld->consumers[0].lid; |
||||
consumerld = &layers[nextLayerId]; |
||||
} |
||||
else |
||||
{ |
||||
isLast = true; |
||||
break; |
||||
} |
||||
} |
||||
else |
||||
{ |
||||
break; |
||||
} |
||||
} |
||||
Ptr<Layer>& consumerLayer = consumerld->layerInstance; |
||||
|
||||
if (consumerLayer->supportBackend(preferableBackend)) |
||||
{ |
||||
tvSupportNum++; |
||||
} |
||||
} |
||||
|
||||
if (tvSupportNum != 1) |
||||
isLast = true; |
||||
} |
||||
|
||||
int graphIndex = -1; |
||||
bool needRecorrect = !timVxInfo.findGraphIndex(ld.inputBlobsWrappers, graphIndex); |
||||
|
||||
if (graphIndex != -1 && !needRecorrect) |
||||
{ |
||||
needRecorrect = timVxInfo.isConflict(ld.id, graphIndex); |
||||
} |
||||
|
||||
// Recorrect the input layer.
|
||||
if (needRecorrect) |
||||
{ |
||||
// Set all input layers as last layers, and convert TRANSIENT tensors to OUTPUT.
||||
for (int i = 0; i < ld.inputBlobsWrappers.size(); i++) |
||||
{ |
||||
auto inputWrap = ld.inputBlobsWrappers[i]; |
||||
auto tvInputWrap = inputWrap.dynamicCast<TimVXBackendWrapper>(); |
||||
if (!tvInputWrap->isTensor()) |
||||
continue; |
||||
|
||||
auto attr = tvInputWrap->getTensorAttr(); |
||||
if (attr == tim::vx::TensorAttribute::OUTPUT) |
||||
{ |
||||
continue; |
||||
} |
||||
else if (attr == tim::vx::TensorAttribute::INPUT) |
||||
{ |
||||
Mat matTmp = tvInputWrap->getMat(); |
||||
tvInputWrap = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(matTmp)); |
||||
|
||||
} |
||||
else if (attr == tim::vx::TensorAttribute::TRANSIENT) |
||||
{ |
||||
tvConvertToOutputNode(ld, tvInputWrap); |
||||
// updateConflictMap
|
||||
tvUpdateConfictMap(graphIndex, ld, timVxInfo.graphConflictMap); |
||||
} |
||||
} |
||||
graphIndex = -1; |
||||
} |
||||
|
||||
if (graphIndex == -1) |
||||
{ |
||||
graphIndex = timVxInfo.createGraph(); |
||||
} |
||||
timVxInfo.setTmpGraphIndex(graphIndex); |
||||
|
||||
ld.backendNodes[DNN_BACKEND_TIMVX] = |
||||
layer->initTimVX(&timVxInfo, ld.inputBlobsWrappers, ld.outputBlobsWrappers, isLast); |
||||
|
||||
// post process, create last node correctly.
|
||||
if (isLast && ld.backendNodes[DNN_BACKEND_TIMVX]) |
||||
{ |
||||
auto tmpNode = ld.backendNodes[DNN_BACKEND_TIMVX].dynamicCast<TimVXBackendNode>(); |
||||
tmpNode->isLast = true; |
||||
// update graphConflictMap
|
||||
tvUpdateConfictMap(graphIndex, ld, timVxInfo.graphConflictMap); |
||||
} |
||||
|
||||
// post process for failing to create timvx Node.
|
||||
if (!ld.backendNodes[DNN_BACKEND_TIMVX]) |
||||
{ |
||||
for (int i = 0; i < ld.inputBlobsWrappers.size(); i++) |
||||
{ |
||||
auto inputWrap = ld.inputBlobsWrappers[i]; |
||||
auto tvInputWrap = inputWrap.dynamicCast<TimVXBackendWrapper>(); |
||||
if (!tvInputWrap->isTensor()) |
||||
continue; |
||||
|
||||
auto attr = tvInputWrap->getTensorAttr(); |
||||
if (attr == tim::vx::TensorAttribute::TRANSIENT) |
||||
{ |
||||
tvConvertToOutputNode(ld, tvInputWrap); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Op Binding
|
||||
it = layers.begin(); |
||||
Ptr<TimVXBackendNode> node; |
||||
std::vector<Ptr<TimVXGraph> > tmpGrapList; |
||||
for (; it != layers.end(); it++) |
||||
{ |
||||
LayerData &ld = it->second; |
||||
|
||||
if (ld.backendNodes[DNN_BACKEND_TIMVX]) |
||||
node = ld.backendNodes[DNN_BACKEND_TIMVX].dynamicCast<TimVXBackendNode>(); |
||||
else |
||||
continue; |
||||
|
||||
// Binding tvTensor and tvOp
|
||||
if (node->opIndex >= 0) |
||||
node->opBinding(); |
||||
} |
||||
} |
||||
|
||||
CV__DNN_INLINE_NS_END |
||||
|
||||
// from CPU to NPU
|
||||
bool copyToTensor(std::shared_ptr<tim::vx::Tensor> &dst, const Mat &src) |
||||
{ |
||||
CV_Assert(src.isContinuous() && (src.type() == CV_8S || src.type() == CV_32F)); |
||||
if (dst->CopyDataToTensor(src.data, src.total())) |
||||
{ |
||||
return true; |
||||
} |
||||
else |
||||
return false; |
||||
} |
||||
|
||||
// from NPU to CPU
|
||||
bool copyToMat(const Mat &dst, std::shared_ptr<tim::vx::Tensor> &src) |
||||
{ |
||||
CV_Assert(dst.isContinuous() && (dst.type() == CV_8S || dst.type() == CV_32F)); |
||||
if (src->CopyDataFromTensor(dst.data)) |
||||
{ |
||||
return true; |
||||
} |
||||
else |
||||
return false; |
||||
} |
||||
|
||||
tvActivationType getTimVXActType(String & actString) |
||||
{ |
||||
if (actString == "ReLUInt8") return tvActReLU; |
||||
if (actString == "ReLU6Int8") return tvActReLU6; |
||||
if (actString == "TanHInt8") return tvActTanH; |
||||
if (actString == "SwishInt8") return tvActSwish; |
||||
if (actString == "MishInt8") return tvActMish; |
||||
if (actString == "SigmoidInt8") return tvActSigmoid; |
||||
if (actString == "ELUInt8") return tvActELU; |
||||
|
||||
return tvActNotSupported; |
||||
} |
||||
|
||||
tim::vx::ShapeType getShapeTypeFromMat(const Mat& mat, bool ifConst)
{
    /* Convert Mat shape to TimVX Tensor shape.
       DataLayout in TimVX is WHCN, while NCHW in OpenCV.
       So we do vector reverse.
    */
    CV_Assert(!mat.empty());
    tim::vx::ShapeType tvInputShape;
    auto matShape = shape(mat);
    tvInputShape.assign(matShape.begin(), matShape.end());

    if (matShape.size() > 1) // TODO: check when we need to reverse the shape vector.
    {
        if (ifConst && tvInputShape.size() == 2 && tvInputShape[1] == 1)
        {   // if bias vector, shape [n, 1] to [n].
            tvInputShape.resize(1);
        }
        else
            std::reverse(tvInputShape.begin(), tvInputShape.end());
    }
    return tvInputShape;
}
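For illustration only (not part of the patch), the layout conversion above just reverses the dimension order; the concrete shape values below are assumed.

// Assumed example values:
//   OpenCV Mat shape (NCHW):  {1, 3, 224, 224}
//   TIM-VX ShapeType (WHCN):  {224, 224, 3, 1}
std::vector<int> nchwShape = {1, 3, 224, 224};
tim::vx::ShapeType whcnShape(nchwShape.rbegin(), nchwShape.rend()); // reversed order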
||||
|
||||
bool getQuantType(const std::vector<float>& scales, int numOutput) |
||||
{ |
||||
CV_Assert(!scales.empty()); |
||||
if (numOutput == -1) |
||||
{ |
||||
numOutput = scales.size(); |
||||
} |
||||
bool tvSymmetric = false; |
||||
|
||||
for (int i =1; i < numOutput; i++) |
||||
{ |
||||
if (std::abs(scales[0] - scales[i]) > std::numeric_limits<float>::epsilon()) |
||||
{ |
||||
tvSymmetric = true; |
||||
break; |
||||
} |
||||
} |
||||
|
||||
return tvSymmetric; |
||||
} |
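A hedged usage sketch of the check above (scale values invented): identical per-channel scales mean a single per-tensor scale suffices and getQuantType() returns false, while differing scales make it return true.

std::vector<float> perTensorScales  = {0.5f, 0.5f, 0.5f};
std::vector<float> perChannelScales = {0.5f, 0.25f, 0.5f};
bool needPerChannel1 = getQuantType(perTensorScales);   // false: all scales equal
bool needPerChannel2 = getQuantType(perChannelScales);  // true: scales differ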
||||
|
||||
// convert mat Depth to tensorDataType
|
||||
tim::vx::DataType dataTypeConvert(int matDepth) |
||||
{ |
||||
tim::vx::DataType tensorDataType; |
||||
switch(matDepth) |
||||
{ |
||||
case CV_8U: |
||||
{ |
||||
tensorDataType = tim::vx::DataType::UINT8; |
||||
break; |
||||
} |
||||
case CV_8S: |
||||
{ |
||||
tensorDataType = tim::vx::DataType::INT8; |
||||
break; |
||||
} |
||||
case CV_16U: |
||||
{ |
||||
tensorDataType = tim::vx::DataType::UINT16; |
||||
break; |
||||
} |
||||
case CV_16S: |
||||
{ |
||||
tensorDataType = tim::vx::DataType::INT16; |
||||
break; |
||||
} |
||||
case CV_32S: |
||||
{ |
||||
tensorDataType = tim::vx::DataType::INT32; |
||||
break; |
||||
} |
||||
case CV_32F: |
||||
{ |
||||
tensorDataType = tim::vx::DataType::FLOAT32; |
||||
break; |
||||
} |
||||
case CV_16F: |
||||
{ |
||||
tensorDataType = tim::vx::DataType::FLOAT16; |
||||
break; |
||||
} |
||||
default: |
||||
{ |
||||
tensorDataType = tim::vx::DataType::UNKNOWN; |
||||
break; |
||||
} |
||||
} |
||||
return tensorDataType; |
||||
} |
||||
|
||||
std::vector<Ptr<TimVXBackendWrapper> > getWrappers(const std::vector<int> wrappersIndex, |
||||
Ptr<TimVXGraph> tvGraph) |
||||
{ |
||||
std::vector<Ptr<TimVXBackendWrapper> > wrappers; |
||||
for (int i = 0; i<wrappersIndex.size(); i++) |
||||
{ |
||||
auto wrapper = tvGraph->getWrapper(wrappersIndex[i]); |
||||
if (wrapper) |
||||
wrappers.push_back(wrapper); |
||||
} |
||||
|
||||
return wrappers; |
||||
} |
||||
|
||||
// *********************** TimVXGraph ********************
|
||||
TimVXGraph::TimVXGraph() |
||||
{ |
||||
// new TimVX Graph
|
||||
context = tim::vx::Context::Create(); |
||||
graph = context->CreateGraph(); |
||||
isCompiled = false; |
||||
} |
||||
|
||||
TimVXGraph::~TimVXGraph() |
||||
{ |
||||
|
||||
// release tensorList
||||
for (auto& tensor: tensorList) |
||||
tensor.reset(); |
||||
|
||||
// release opList
||||
for (auto& op: opList) |
||||
op.reset(); |
||||
|
||||
// release graph
|
||||
graph.reset(); |
||||
|
||||
// release context
|
||||
context.reset(); |
||||
} |
||||
|
||||
std::shared_ptr<tim::vx::Operation> TimVXGraph::getOp(const int opIndex) |
||||
{ |
||||
CV_Assert(0 <= opIndex && !opList.empty() && opIndex < opList.size()); |
||||
return opList[opIndex]; |
||||
} |
||||
|
||||
int TimVXGraph::addWrapper(Ptr<TimVXBackendWrapper>& tensorWrapper) |
||||
{ |
||||
CV_Assert(tensorWrapper->isTensor()); |
||||
tim::vx::TensorAttribute tensorAttr = tensorWrapper->getTensorAttr(); |
||||
|
||||
wrapperList.push_back(tensorWrapper); |
||||
tensorList.push_back(tensorWrapper->getTensor()); |
||||
int wrapperIndex = wrapperList.size() -1; |
||||
|
||||
if (tensorAttr == tim::vx::TensorAttribute::INPUT) |
||||
{ |
||||
inputWrappersIndex.push_back(wrapperIndex); |
||||
} |
||||
|
||||
if (tensorAttr == tim::vx::TensorAttribute::OUTPUT) |
||||
{ |
||||
outputWrappersIndex.push_back(wrapperIndex); |
||||
} |
||||
|
||||
return wrapperIndex; |
||||
} |
||||
|
||||
Ptr<TimVXBackendWrapper> TimVXGraph::getWrapper(int wrapperIndex) |
||||
{ |
||||
CV_Assert(wrapperIndex>=0 && wrapperIndex < wrapperList.size()); |
||||
return wrapperList[wrapperIndex]; |
||||
} |
||||
|
||||
int TimVXGraph::addOp(const std::shared_ptr<tim::vx::Operation>& op) |
||||
{ |
||||
CV_Assert(op); |
||||
opList.emplace_back(op); |
||||
return opList.size()-1; |
||||
} |
||||
|
||||
int TimVXGraph::getTensorIndex(const std::shared_ptr<tim::vx::Tensor>& tensor) |
||||
{ |
||||
auto it = find(tensorList.begin(), tensorList.end(), tensor); |
||||
if (it != tensorList.end()) |
||||
return it - tensorList.begin(); |
||||
else |
||||
return -1; |
||||
} |
||||
|
||||
void TimVXGraph::forward() |
||||
{ |
||||
CV_Assert(!inputWrappersIndex.empty() && !outputWrappersIndex.empty()); |
||||
|
||||
// Every TimVXGraph Instance only compiles once.
|
||||
if (!this->isCompiled) |
||||
{ |
||||
if (!graph->Compile()) |
||||
CV_Error(cv::Error::StsBadArg, " Fail to compile TimVX graph!"); |
||||
this->isCompiled = true; |
||||
} |
||||
|
||||
if (!graph->Run()) |
||||
CV_Error(cv::Error::StsBadArg, " Fail to run TimVX graph!"); |
||||
} |
||||
|
||||
// *********************** TimVXBackendNode ********************
|
||||
TimVXBackendNode::TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph_): BackendNode(DNN_BACKEND_TIMVX) |
||||
{ |
||||
opIndex = -1; |
||||
tvGraph = tvGraph_; |
||||
isLast = false; |
||||
} |
||||
|
||||
TimVXBackendNode::TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph_, |
||||
const std::shared_ptr<tim::vx::Operation>& op_): BackendNode(DNN_BACKEND_TIMVX) |
||||
{ |
||||
tvGraph = tvGraph_; |
||||
opIndex = tvGraph->addOp(op_); |
||||
isLast = false; |
||||
} |
||||
|
||||
TimVXBackendNode::TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph_, std::shared_ptr<tim::vx::Operation>& op_, |
||||
std::vector<int>& inputsIndex, std::vector<int>& outpusIndex) |
||||
:BackendNode(DNN_BACKEND_TIMVX) |
||||
{ |
||||
tvGraph = tvGraph_; |
||||
opIndex = tvGraph->addOp(op_); |
||||
isLast = false; |
||||
|
||||
if (!inputsIndex.empty()) |
||||
inputIndexList.assign(inputsIndex.begin(), inputsIndex.end()); |
||||
|
||||
if (!outpusIndex.empty()) |
||||
outputIndexList.assign(outpusIndex.begin(), outpusIndex.end()); |
||||
} |
||||
|
||||
bool TimVXBackendNode::opBinding() |
||||
{ |
||||
if (!tvGraph || tvGraph->isCompiled || opIndex == -1) |
||||
return false; |
||||
|
||||
std::shared_ptr<tim::vx::Operation> op = tvGraph->getOp(opIndex); |
||||
|
||||
if (!inputIndexList.empty()) |
||||
{ |
||||
std::vector<Ptr<TimVXBackendWrapper> > inputsWrapper = getWrappers(inputIndexList, tvGraph); |
||||
// Binding input Tensor.
|
||||
for (auto& warpper: inputsWrapper) |
||||
{ |
||||
op->BindInput(warpper->getTensor()); |
||||
} |
||||
} |
||||
|
||||
if (!outputIndexList.empty()) |
||||
{ |
||||
std::vector<Ptr<TimVXBackendWrapper> > outputsWrapper = getWrappers(outputIndexList, tvGraph); |
||||
for (auto& warpper: outputsWrapper) |
||||
{ |
||||
op->BindOutput(warpper->getTensor()); |
||||
} |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
void TimVXBackendNode::setInputTensor() |
||||
{ |
||||
if (!tvGraph || opIndex == -1) |
||||
return; |
||||
|
||||
if (!inputIndexList.empty()) |
||||
{ |
||||
std::vector<Ptr<TimVXBackendWrapper> > inputsWrapper = getWrappers(inputIndexList, tvGraph); |
||||
|
||||
// Binding input Tensor.
|
||||
for (auto& warpper: inputsWrapper) |
||||
{ |
||||
if (warpper->getTensorAttr() == tim::vx::TensorAttribute::INPUT) |
||||
{ |
||||
warpper->setHostDirty(); |
||||
warpper->copyToDevice(); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// *********************** TimVXBackendWrapper ********************
|
||||
// Default Constructor
|
||||
TimVXBackendWrapper::TimVXBackendWrapper() : BackendWrapper(DNN_BACKEND_TIMVX, DNN_TARGET_NPU) |
||||
{ |
||||
isTensor_ = false; |
||||
deviceDirty = false; |
||||
hostDirty = false; |
||||
tensorType = tim::vx::DataType::UNKNOWN; |
||||
tensorShape = {}; |
||||
tensorIndex = -1; |
||||
tensorAttr = tim::vx::TensorAttribute::CONSTANT; |
||||
} |
||||
|
||||
TimVXBackendWrapper::TimVXBackendWrapper(Mat& m) : BackendWrapper(DNN_BACKEND_TIMVX, |
||||
DNN_TARGET_NPU) |
||||
{ |
||||
host = m; |
||||
isTensor_ = false; |
||||
deviceDirty = false; |
||||
hostDirty = true; |
||||
tensorType = dataTypeConvert(m.type()); |
||||
tensorShape = {}; |
||||
tensorIndex = -1; |
||||
tensorAttr = tim::vx::TensorAttribute::CONSTANT; |
||||
|
||||
// TODO: unsupported data by TimVX should run convert function first.
|
||||
CV_Assert(tensorType != tim::vx::DataType::UNKNOWN); |
||||
} |
||||
|
||||
TimVXBackendWrapper::TimVXBackendWrapper(const Ptr<BackendWrapper>& baseBuffer, Mat& m) |
||||
:BackendWrapper(DNN_BACKEND_TIMVX, DNN_TARGET_NPU) |
||||
{ |
||||
Ptr<TimVXBackendWrapper> base = baseBuffer.dynamicCast<TimVXBackendWrapper>(); |
||||
CV_Assert(!base.empty()); |
||||
tensor = base->tensor; |
||||
isTensor_ = base->isTensor_; |
||||
tensorIndex = base->tensorIndex; |
||||
tensorType = base->tensorType; |
||||
tensorAttr = base->tensorAttr; |
||||
tensorShape = base->tensorShape; |
||||
deviceDirty = base->deviceDirty; |
||||
hostDirty = base->hostDirty; |
||||
host = m; |
||||
} |
||||
|
||||
TimVXBackendWrapper::TimVXBackendWrapper(std::shared_ptr<tim::vx::Tensor>& tensor_) |
||||
:BackendWrapper(DNN_BACKEND_TIMVX, DNN_TARGET_NPU) |
||||
{ |
||||
tensor = tensor_; |
||||
isTensor_ = true; |
||||
deviceDirty = true; |
||||
hostDirty = false; |
||||
tensorType = tensor_->GetDataType(); // getTensor DataType.
|
||||
tensorAttr = tensor_->GetSpec().attr_; // getTensor Attribution.
|
||||
tensorShape = tensor_->GetShape(); |
||||
tensorIndex = -1; |
||||
} |
||||
|
||||
void TimVXBackendWrapper::setTensorShape(const tim::vx::ShapeType & matShape) |
||||
{ |
||||
CV_Assert(!matShape.empty()); |
||||
tensorShape.assign(matShape.begin(), matShape.end()); |
||||
} |
||||
|
||||
int TimVXBackendWrapper::getTensorIndex() |
||||
{ |
||||
CV_Assert(isTensor_); |
||||
return tensorIndex; |
||||
} |
||||
|
||||
tim::vx::TensorAttribute TimVXBackendWrapper::getTensorAttr() |
||||
{ |
||||
CV_Assert(isTensor_); |
||||
return tensorAttr; |
||||
} |
||||
|
||||
// Create tensor
|
||||
void TimVXBackendWrapper::createTensor(std::shared_ptr<tim::vx::Graph>& graph, |
||||
tim::vx::TensorAttribute tensorAttribute) |
||||
{ |
||||
Ptr<tim::vx::Quantization> emptyQuant = nullptr;
return this->createTensor(graph, tensorAttribute, emptyQuant);
||||
} |
||||
|
||||
// Create tensor
|
||||
void TimVXBackendWrapper::createTensor(std::shared_ptr<tim::vx::Graph>& graph, |
||||
tim::vx::TensorAttribute tensorAttribute, Ptr<tim::vx::Quantization>& tvQuant) |
||||
{ |
||||
CV_Assert(graph); |
||||
tim::vx::TensorSpec tensorSpec; |
||||
|
||||
if (tensorAttribute == tim::vx::INPUT) |
||||
{ |
||||
CV_Assert(!host.empty()); |
||||
tensorShape = getShapeTypeFromMat(host); |
||||
} |
||||
else if (tensorAttribute == tim::vx::OUTPUT) |
||||
{ |
||||
CV_Assert(!tensorShape.empty() && !host.empty()); |
||||
tensorShape = getShapeTypeFromMat(host); |
||||
} |
||||
else if (tensorAttribute == tim::vx::CONSTANT) |
||||
{ |
||||
if (!host.empty()) |
||||
tensorShape = getShapeTypeFromMat(host, true); |
||||
} |
||||
else |
||||
{ |
||||
if (!host.empty()) |
||||
tensorShape = getShapeTypeFromMat(host); |
||||
} |
||||
|
||||
// Tensor shape
|
||||
if (tvQuant) |
||||
{ |
||||
tensorSpec = tim::vx::TensorSpec(tensorType, tensorShape, tensorAttribute, *tvQuant); |
||||
} |
||||
else |
||||
{ |
||||
tensorSpec = tim::vx::TensorSpec(tensorType, tensorShape, tensorAttribute); |
||||
} |
||||
|
||||
if (!host.empty() && tensorAttribute != tim::vx::INPUT && tensorAttribute != tim::vx::OUTPUT && tensorAttribute != tim::vx::TRANSIENT) |
||||
{ |
||||
tensor = graph->CreateTensor(tensorSpec, (void *)(host.data)); |
||||
} |
||||
else |
||||
{ |
||||
tensor = graph->CreateTensor(tensorSpec); |
||||
} |
||||
isTensor_ = true; |
||||
|
||||
// set Attribution
|
||||
tensorAttr = tensorAttribute; |
||||
} |
||||
|
||||
Ptr<tim::vx::Quantization> TimVXBackendWrapper::getTensorQuantization() |
||||
{ |
||||
CV_Assert(isTensor_ && tensor); |
||||
auto quantize = tensor->GetQuantization(); |
||||
return makePtr<tim::vx::Quantization>(quantize); |
||||
} |
||||
|
||||
std::shared_ptr<tim::vx::Tensor> TimVXBackendWrapper::getTensor() |
||||
{ |
||||
CV_Assert(isTensor_); |
||||
return tensor; |
||||
} |
||||
|
||||
Mat TimVXBackendWrapper::getMat() |
||||
{ |
||||
if (host.empty()) |
||||
return {}; |
||||
return host; |
||||
} |
||||
|
||||
|
||||
bool TimVXBackendWrapper::isTensor() |
||||
{ |
||||
return isTensor_; |
||||
} |
||||
|
||||
void TimVXBackendWrapper::copyToHost() |
||||
{ |
||||
if (deviceDirty && !host.empty()) |
||||
{ |
||||
copyToMat(host, tensor); |
||||
deviceDirty = false; |
||||
} |
||||
} |
||||
|
||||
void TimVXBackendWrapper::setHostDirty() |
||||
{ |
||||
hostDirty = true; |
||||
} |
||||
|
||||
void TimVXBackendWrapper::setDeviceDirty() |
||||
{ |
||||
deviceDirty = true; |
||||
} |
||||
|
||||
void TimVXBackendWrapper::copyToDevice() |
||||
{ |
||||
if (isTensor_ && hostDirty && !host.empty()) |
||||
{ |
||||
copyToTensor(tensor, host); |
||||
hostDirty = false; |
||||
} |
||||
} |
||||
|
||||
// *********************** TimVXInfo ********************
|
||||
TimVXInfo::TimVXInfo() |
||||
{ |
||||
graphIndex = -1; |
||||
} |
||||
|
||||
TimVXInfo::~TimVXInfo() |
||||
{} |
||||
|
||||
int TimVXInfo::createGraph() |
||||
{ |
||||
Ptr<TimVXGraph> tmpGraph = Ptr<TimVXGraph>(new TimVXGraph()); |
||||
this->tvGraphList.push_back(tmpGraph); |
||||
return this->tvGraphList.size() - 1; |
||||
} |
||||
|
||||
bool TimVXInfo::findGraphIndex(const std::vector<Ptr<BackendWrapper> > &inputsWrapper, int& graphIndex) |
||||
{ |
||||
graphIndex = -1; |
||||
int wrapperSize = inputsWrapper.size(); |
||||
int graphSize = tvGraphList.size(); |
||||
|
||||
if (wrapperSize != 0 && graphSize == 0) |
||||
{ |
||||
return true; |
||||
} |
||||
|
||||
int tensorIndex = -1; |
||||
Ptr<TimVXBackendWrapper> wrapper; |
||||
Ptr<TimVXGraph> tvGraph; |
||||
|
||||
for (int i = 0; i < graphSize; i++) |
||||
{ |
||||
tvGraph = tvGraphList[i]; |
||||
for (int j = 0; j < wrapperSize; j++ ) |
||||
{ |
||||
wrapper = inputsWrapper[j].dynamicCast<TimVXBackendWrapper>(); |
||||
|
||||
if (!wrapper->isTensor()) // Skip wrapper without Tensor.
|
||||
continue; |
||||
|
||||
tensorIndex = tvGraph->getTensorIndex(wrapper->getTensor()); |
||||
if (tensorIndex != -1 && wrapper->getTensorAttr() == tim::vx::TensorAttribute::TRANSIENT) |
||||
{ |
||||
if (graphIndex == -1) |
||||
graphIndex = i; |
||||
else if (graphIndex != i) // if inputs of the same layer come from different tvGraphs.
||||
{ |
||||
graphIndex = -1; |
||||
return false; |
||||
} |
||||
} |
||||
} |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
void TimVXInfo::setTmpGraphIndex(int graphIndex) |
||||
{ |
||||
this->graphIndex = graphIndex; |
||||
} |
||||
|
||||
int TimVXInfo::getTmpGraphIndex() |
||||
{ |
||||
int res = -1; |
||||
if (graphIndex != -1) |
||||
{ |
||||
res = graphIndex; |
||||
graphIndex = -1; |
||||
} |
||||
return res; |
||||
} |
||||
|
||||
bool TimVXInfo::isConflict(int layerId, int graphIndex) |
||||
{ |
||||
if (graphConflictMap[layerId].empty()) |
||||
return false; |
||||
|
||||
std::vector<int>::iterator it = std::find(graphConflictMap[layerId].begin(), |
||||
graphConflictMap[layerId].end(), graphIndex); |
||||
if (it != graphConflictMap[layerId].end()) |
||||
return true; |
||||
else |
||||
return false; |
||||
} |
||||
|
||||
Ptr<TimVXGraph> TimVXInfo::getGraph() |
||||
{ |
||||
int index = getTmpGraphIndex(); |
||||
if (0 <= index && index < tvGraphList.size()) |
||||
return tvGraphList[index]; |
||||
else |
||||
return {}; |
||||
} |
||||
|
||||
#endif |
||||
|
||||
void forwardTimVX(std::vector<Ptr<BackendWrapper> >& outputs, const Ptr<BackendNode>& node_) |
||||
{ |
||||
#ifdef HAVE_TIMVX |
||||
CV_Assert(!node_.empty()); |
||||
Ptr<TimVXBackendNode> node = node_.dynamicCast<TimVXBackendNode>(); |
||||
|
||||
if (node) |
||||
{ |
||||
// set input
|
||||
node->setInputTensor(); |
||||
|
||||
// graph Forward
|
||||
if (node->isLast) |
||||
{ |
||||
node->tvGraph->forward(); |
||||
} |
||||
} |
||||
else |
||||
return; |
||||
|
||||
// set output
||||
Ptr<TimVXBackendWrapper> outWarpper; |
||||
for (int i = 0; i < outputs.size(); i++) |
||||
{ |
||||
outWarpper = outputs[i].dynamicCast<TimVXBackendWrapper>(); |
||||
if (outWarpper->isTensor() && outWarpper->getTensorAttr() == tim::vx::TensorAttribute::OUTPUT) |
||||
{ |
||||
outWarpper->setDeviceDirty(); |
||||
outWarpper->copyToHost(); |
||||
} |
||||
} |
||||
#endif |
||||
} |
||||
|
||||
bool haveTimVX() |
||||
{ |
||||
#ifdef HAVE_TIMVX |
||||
return true; |
||||
#else |
||||
return false; |
||||
#endif |
||||
} |
||||
} // namespace dnn
|
||||
} // namespace cv
|
@ -0,0 +1,187 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2019-2021, Shenzhen Institute of Artificial Intelligence and
// Robotics for Society, all rights reserved.
// Third party copyrights are property of their respective owners.

#ifndef OPENCV_DNN_OP_TIMVX_HPP
#define OPENCV_DNN_OP_TIMVX_HPP

#include <opencv2/dnn/shape_utils.hpp>

// TimVX header files.
#ifdef HAVE_TIMVX
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/operation.h"
#include "tim/vx/ops.h"
#include "tim/vx/tensor.h"
#endif // HAVE_TIMVX

namespace cv
{
namespace dnn
{
#ifdef HAVE_TIMVX

enum tvActivationType
{
    tvActNotSupported = -1,
    tvActReLU,
    tvActReLU6,
    tvActTanH,
    tvActSwish,
    tvActMish,
    tvActSigmoid,
    tvActELU
};

// Data copied from/to Mat to/from Tensor. Change the shape of dst if
// needed to make it the same shape as src.
bool copyToTensor(Ptr<tim::vx::Tensor> &dst, const Mat &src);
bool copyToMat(const Mat &dst, Ptr<tim::vx::Tensor> &src);
tvActivationType getTimVXActType(String & actString);

// Convert Mat shape to TimVX TensorShape
tim::vx::ShapeType getShapeTypeFromMat(const Mat& mat, bool ifConst = false);

// Return true if the per-channel scales differ from each other, i.e. per-channel quantization is needed.
bool getQuantType(const std::vector<float>& scales, int numOutput = -1);

class TimVXInfo;
class TimVXGraph;
class TimVXBackendNode;
class TimVXBackendWrapper;

// Maintains the tvGraph and tvTensor lists. For now, every tvGraph has only one output node, and each node
// in a tvGraph has only one output too. It could be optimized in the future.
// TODO: tvGraph supports multiple output nodes.
class TimVXGraph
{
public:
    TimVXGraph();
    ~TimVXGraph();
    std::shared_ptr<tim::vx::Operation> getOp(const int opIndex);

    // Add tensorWrapper to wrapperList and return its index.
    // Also add the tensor Ptr to tensorList.
    int addWrapper(Ptr<TimVXBackendWrapper>& tensorWrapper);

    void forward();

    // Add a new op to opList, and return the index.
    int addOp(const std::shared_ptr<tim::vx::Operation>& op);

    // If the tensor exists in tensorList, return its tensorIndex, otherwise return -1.
    int getTensorIndex(const std::shared_ptr<tim::vx::Tensor>& tensor);

    Ptr<TimVXBackendWrapper> getWrapper(int wrapperIndex);

    std::shared_ptr<tim::vx::Graph> graph;
    bool isCompiled; // Every tvGraph can only be compiled once.

private:
    std::shared_ptr<tim::vx::Context> context;
    std::vector<int> inputWrappersIndex;
    std::vector<int> outputWrappersIndex;
    std::vector<Ptr<TimVXBackendWrapper> > wrapperList;
    std::vector<std::shared_ptr<tim::vx::Tensor> > tensorList;
    std::vector<std::shared_ptr<tim::vx::Operation> > opList;
};

class TimVXBackendNode : public BackendNode
{
public:
    TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph);
    TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph, const std::shared_ptr<tim::vx::Operation>& op);
    TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph, std::shared_ptr<tim::vx::Operation>& op,
                     std::vector<int>& inputsIndex, std::vector<int>& outpusIndex);

    void setInputTensor();
    bool opBinding();

    // Flag marking the output node of a tvGraph: true if this node is the last node in its TimVX graph.
    bool isLast;
    int opIndex;

    // index of tensor and wrapper.
    std::vector<int> inputIndexList;
    std::vector<int> outputIndexList;
    Ptr<TimVXGraph> tvGraph;
};

class TimVXBackendWrapper : public BackendWrapper
{
public:
    TimVXBackendWrapper();
    TimVXBackendWrapper(Mat& m);
    TimVXBackendWrapper(const Ptr<BackendWrapper>& baseBuffer, Mat& m);
    TimVXBackendWrapper(std::shared_ptr<tim::vx::Tensor>& tensor);

    // Create Output Tensor
    void createTensor(std::shared_ptr<tim::vx::Graph>& graph, tim::vx::TensorAttribute tensorAttribute);
    void createTensor(std::shared_ptr<tim::vx::Graph>& graph, tim::vx::TensorAttribute tensorAttribute,
                      Ptr<tim::vx::Quantization>& tvQuant);
    std::shared_ptr<tim::vx::Tensor> getTensor();
    Mat getMat();

    // The output tensor in TimVX does not have a host Mat, so only its shape can be given.
    void setTensorShape(const tim::vx::ShapeType & matShape);
    int getTensorIndex();
    Ptr<tim::vx::Quantization> getTensorQuantization();
    tim::vx::TensorAttribute getTensorAttr();
    bool isTensor();

    // Data Copy, CPU <==> NPU
    virtual void copyToHost() CV_OVERRIDE;
    virtual void setHostDirty() CV_OVERRIDE;
    void setDeviceDirty();
    void copyToDevice();

private:
    tim::vx::DataType tensorType;
    bool deviceDirty;
    bool hostDirty;
    int tensorIndex; // index into the tensorList of a specific TimVXGraph.
    bool isTensor_;
    Mat host;

    tim::vx::ShapeType tensorShape;
    std::shared_ptr<tim::vx::Tensor> tensor;
    tim::vx::TensorAttribute tensorAttr;
};

// Contains the list of all created tvGraphs.
class TimVXInfo
{
public:
    TimVXInfo();
    ~TimVXInfo();

    // Return the graph whose index was set as graphIndex; if it cannot be found, return an empty Ptr.
    Ptr<TimVXGraph> getGraph();
    bool findGraphIndex(const std::vector<Ptr<BackendWrapper> > &inputsWrapper, int& graphIndex);
    void setTmpGraphIndex(int graphIndex);
    bool isConflict(int layerId, int graphIndex);

    // Create a TimVXGraph, add it to tvGraphList, and return its index in tvGraphList.
    int createGraph();

    // graphConflictMap[layerIndex] stores the conflicting graph indices, which should be excluded.
    std::vector<std::vector<int> > graphConflictMap;

private:
    int getTmpGraphIndex();
    std::vector<Ptr<TimVXGraph> > tvGraphList;
    int graphIndex;
};

#endif

void forwardTimVX(std::vector<Ptr<BackendWrapper> > &outputs, const Ptr<BackendNode>& node);
bool haveTimVX();
}  // namespace dnn
}  // namespace cv

#endif  // OPENCV_DNN_OP_TIMVX_HPP