Merge pull request #21036 from fengyuentau:timvx_backend_support

dnn: TIM-VX NPU backend support

* Add TimVX NPU backend for DNN module.

* Use the official branch from the TIM-VX repo; fix detection of the VIV SDK.

Co-authored-by: fytao <yuantao.feng@outlook.com>
Author: Zihao Mu, 2022-04-01 05:42:11 +08:00 (committed by GitHub)
parent 9390c56831
commit 7b582b71ba
37 changed files with 2982 additions and 30 deletions
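
For context, selecting the new backend from user code follows the usual DNN API pattern. A minimal sketch, assuming an int8-quantized ONNX model; the file name and input size are illustrative:

#include <opencv2/dnn.hpp>

int main()
{
    // Load an int8-quantized model (path is hypothetical).
    cv::dnn::Net net = cv::dnn::readNetFromONNX("model_int8.onnx");

    // Select the TIM-VX backend and NPU target added by this PR.
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_TIMVX);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_NPU);

    // Run a dummy 224x224 input through the network.
    cv::Mat img = cv::Mat::zeros(224, 224, CV_8UC3);
    net.setInput(cv::dnn::blobFromImage(img));
    cv::Mat out = net.forward();
    return 0;
}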

3rdparty/libtim-vx/tim-vx.cmake (new vendored file)

@ -0,0 +1,73 @@
set(TIMVX_COMMIT_HASH "1d9c7ab941b3d8d9c4d28d80058402725731e3d6")
set(OCV_TIMVX_DIR "${OpenCV_BINARY_DIR}/3rdparty/libtim-vx")
set(OCV_TIMVX_SOURCE_PATH "${OCV_TIMVX_DIR}/TIM-VX-${TIMVX_COMMIT_HASH}")
# Download TIM-VX source code
if(EXISTS "${OCV_TIMVX_SOURCE_PATH}")
message(STATUS "TIM-VX: Use cache of TIM-VX source code at ${OCV_TIMVX_SOURCE_PATH}")
set(TIMVX_FOUND ON)
else()
set(OCV_TIMVX_FILENAME "${TIMVX_COMMIT_HASH}.zip")
set(OCV_TIMVX_URL "https://github.com/VeriSilicon/TIM-VX/archive/")
set(timvx_zip_md5sum 92619cc4498014ac7a09834d5e33ebd5)
ocv_download(FILENAME ${OCV_TIMVX_FILENAME}
HASH ${timvx_zip_md5sum}
URL "${OCV_TIMVX_URL}"
DESTINATION_DIR "${OCV_TIMVX_DIR}"
ID "TIM-VX"
STATUS res
UNPACK RELATIVE_URL)
if(res)
set(TIMVX_FOUND ON)
message(STATUS "TIM-VX: Source code downloaded at ${OCV_TIMVX_SOURCE_PATH}.")
else()
set(TIMVX_FOUND OFF)
message(STATUS "TIM-VX: Failed to download source code from github. Turning off TIMVX_FOUND")
return()
endif()
endif()
# Set VIVANTE SDK, specifically for x86_64, where it comes bundled with the TIM-VX source code
if(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64)
set(VIVANTE_SDK_DIR "${OCV_TIMVX_SOURCE_PATH}/prebuilt-sdk/x86_64_linux")
message(STATUS "TIM-VX: Build from source using prebuilt x86_64 VIVANTE SDK.")
endif()
# Verify if requested VIVANTE SDK libraries are all found
find_vivante_sdk_libs(missing ${VIVANTE_SDK_DIR})
if(missing)
message(STATUS "TIM-VX: Failed to find ${missing} in ${VIVANTE_SDK_DIR}/lib. Turning off TIMVX_VIV_FOUND")
set(TIMVX_VIV_FOUND OFF)
else()
message(STATUS "TIM-VX: dependent VIVANTE SDK libraries are found at ${VIVANTE_SDK_DIR}/lib.")
set(TIMVX_VIV_FOUND ON)
endif()
if(TIMVX_VIV_FOUND)
# vars used by TIM-VX CMake scripts
set(EXTERNAL_VIV_SDK "${VIVANTE_SDK_DIR}" CACHE INTERNAL "" FORCE)
set(VIV_SDK_DRIVER_PREFIX "lib" CACHE INTERNAL "" FORCE)
endif()
if(TIMVX_FOUND AND TIMVX_VIV_FOUND)
set(BUILD_TIMVX ON)
else()
return()
endif()
if(BUILD_TIMVX)
set(HAVE_TIMVX 1)
ocv_warnings_disable(CMAKE_C_FLAGS -Wunused-parameter -Wstrict-prototypes -Wundef -Wsign-compare -Wmissing-prototypes -Wmissing-declarations -Wstrict-aliasing -Wunused-but-set-variable -Wmaybe-uninitialized -Wshadow -Wsuggest-override -Wswitch)
ocv_warnings_disable(CMAKE_CXX_FLAGS -Wunused-parameter -Wstrict-prototypes -Wundef -Wsign-compare -Wunused-but-set-variable -Wshadow -Wsuggest-override -Wmissing-declarations -Wswitch)
set(TIMVX_INC_DIR "${OCV_TIMVX_SOURCE_PATH}/include" CACHE INTERNAL "TIM-VX include directory")
if(EXISTS "${OCV_TIMVX_SOURCE_PATH}/CMakeLists.txt")
add_subdirectory("${OCV_TIMVX_SOURCE_PATH}" "${OCV_TIMVX_DIR}/build")
else()
message(WARNING "TIM-VX: Missing 'CMakeLists.txt' in the source code: ${OCV_TIMVX_SOURCE_PATH}")
endif()
ocv_install_target(tim-vx EXPORT OpenCVModules ARCHIVE DESTINATION ${OPENCV_3P_LIB_INSTALL_PATH} COMPONENT dev)
set(TIMVX_LIB "tim-vx")
endif()


@ -453,6 +453,9 @@ OCV_OPTION(WITH_TENGINE "Include Arm Inference Tengine support" OFF
OCV_OPTION(WITH_ONNX "Include Microsoft ONNX Runtime support" OFF
VISIBLE_IF TRUE
VERIFY HAVE_ONNX)
OCV_OPTION(WITH_TIMVX "Include Tim-VX support" OFF
VISIBLE_IF TRUE
VERIFY HAVE_TIMVX)
# OpenCV build components
# ===================================================
@ -733,6 +736,9 @@ include(cmake/OpenCVFindProtobuf.cmake)
if(WITH_TENGINE)
include(cmake/OpenCVFindTengine.cmake)
endif()
if(WITH_TIMVX)
include(cmake/OpenCVFindTIMVX.cmake)
endif()
# ----------------------------------------------------------------------------
# Detect other 3rd-party libraries/tools
@ -1645,6 +1651,16 @@ if(WITH_WEBNN OR HAVE_WEBNN)
endif()
endif()
if(WITH_TIMVX)
status("")
status(" Tim-VX:" HAVE_TIMVX THEN "YES" ELSE "NO")
if(HAVE_TIMVX)
status(" Include path" TIMVX_INCLUDE_DIR THEN "${TIMVX_INCLUDE_DIR}" ELSE "NO")
status(" Link libraries:" TIMVX_LIBRARY THEN "${TIMVX_LIBRARY}" ELSE "NO")
status(" VIVANTE SDK path" VIVANTE_SDK_DIR THEN "${VIVANTE_SDK_DIR}" ELSE "NO")
endif()
endif()
if(WITH_OPENCL OR HAVE_OPENCL)
ocv_build_features_string(opencl_features
IF HAVE_OPENCL_SVM THEN "SVM"


@ -0,0 +1,69 @@
set(TIMVX_INSTALL_DIR "" CACHE PATH "Path to libtim-vx installation")
set(VIVANTE_SDK_DIR "" CACHE PATH "Path to VIVANTE SDK needed by TIM-VX.")
set(VIVANTE_SDK_LIB_CANDIDATES "OpenVX;VSC;GAL;ArchModelSw;NNArchPerf" CACHE STRING "VIVANTE SDK library candidates")
# Ensure VIVANTE SDK library candidates are present in given search path
function(find_vivante_sdk_libs _viv_notfound _viv_search_path)
foreach(one ${VIVANTE_SDK_LIB_CANDIDATES})
# NO_DEFAULT_PATH ensures the VIVANTE SDK libs come from a single source only
find_library(VIV_${one}_LIB ${one} PATHS "${_viv_search_path}/lib" NO_DEFAULT_PATH)
if(NOT VIV_${one}_LIB)
list(APPEND _viv_notfound_list ${one})
endif()
endforeach()
set(${_viv_notfound} ${_viv_notfound_list} PARENT_SCOPE)
endfunction()
# Default value for VIVANTE_SDK_DIR: /usr
if(NOT VIVANTE_SDK_DIR)
set(VIVANTE_SDK_DIR "/usr")
endif()
# Environment variable VIVANTE_SDK_DIR overrides the one in this script
if(DEFINED ENV{VIVANTE_SDK_DIR})
set(VIVANTE_SDK_DIR $ENV{VIVANTE_SDK_DIR})
message(STATUS "TIM-VX: Load VIVANTE_SDK_DIR from system environment: ${VIVANTE_SDK_DIR}")
endif()
# Compile with a pre-installed TIM-VX, or build TIM-VX together with OpenCV from source
if(TIMVX_INSTALL_DIR AND NOT BUILD_TIMVX)
message(STATUS "TIM-VX: Use binaries at ${TIMVX_INSTALL_DIR}")
set(BUILD_TIMVX OFF)
set(TIMVX_INC_DIR "${TIMVX_INSTALL_DIR}/include" CACHE INTERNAL "TIM-VX include directory")
find_library(TIMVX_LIB "tim-vx" PATHS "${TIMVX_INSTALL_DIR}/lib")
if(TIMVX_LIB)
set(TIMVX_FOUND ON)
else()
set(TIMVX_FOUND OFF)
endif()
# Verify if requested VIVANTE SDK libraries are all found
find_vivante_sdk_libs(missing ${VIVANTE_SDK_DIR})
if(missing)
message(STATUS "TIM-VX: Failed to find ${missing} in ${VIVANTE_SDK_DIR}/lib. Turning off TIMVX_VIV_FOUND")
set(TIMVX_VIV_FOUND OFF)
else()
message(STATUS "TIM-VX: dependent VIVANTE SDK libraries are found at ${VIVANTE_SDK_DIR}/lib.")
set(TIMVX_VIV_FOUND ON)
endif()
else()
message(STATUS "TIM-VX: Build from source")
include("${OpenCV_SOURCE_DIR}/3rdparty/libtim-vx/tim-vx.cmake")
endif()
if(TIMVX_FOUND AND TIMVX_VIV_FOUND)
set(HAVE_TIMVX 1)
message(STATUS "TIM-VX: Found TIM-VX includes: ${TIMVX_INC_DIR}")
message(STATUS "TIM-VX: Found TIM-VX library: ${TIMVX_LIB}")
set(TIMVX_LIBRARY ${TIMVX_LIB})
set(TIMVX_INCLUDE_DIR ${TIMVX_INC_DIR})
message(STATUS "TIM-VX: Found VIVANTE SDK libraries: ${VIVANTE_SDK_DIR}/lib")
link_directories(${VIVANTE_SDK_DIR}/lib)
endif()
MARK_AS_ADVANCED(
TIMVX_INC_DIR
TIMVX_LIB
)


@ -23,6 +23,10 @@ if(WITH_WEBNN AND HAVE_WEBNN)
add_definitions(-DHAVE_WEBNN=1)
endif()
if(HAVE_TIMVX)
add_definitions(-DHAVE_TIMVX=1)
endif()
ocv_option(OPENCV_DNN_CUDA "Build with CUDA support"
HAVE_CUDA
AND HAVE_CUBLAS
@ -146,6 +150,11 @@ if(HAVE_TENGINE)
list(APPEND libs -Wl,--whole-archive ${TENGINE_LIBRARIES} -Wl,--no-whole-archive)
endif()
if(HAVE_TIMVX)
list(APPEND include_dirs ${TIMVX_INCLUDE_DIR})
list(APPEND libs -Wl,--whole-archive ${TIMVX_LIBRARY} -Wl,--no-whole-archive)
endif()
set(webnn_srcs "")
if(NOT EMSCRIPTEN)
if(HAVE_WEBNN)


@ -262,7 +262,7 @@ CV__DNN_INLINE_NS_BEGIN
{
public:
int input_zp, output_zp;
float input_sc, output_sc;
static Ptr<BaseConvolutionLayer> create(const LayerParams& params);
};
@ -322,6 +322,7 @@ CV__DNN_INLINE_NS_BEGIN
{
public:
int input_zp, output_zp;
float input_sc, output_sc;
static Ptr<PoolingLayerInt8> create(const LayerParams& params);
};
@ -365,7 +366,8 @@ CV__DNN_INLINE_NS_BEGIN
class CV_EXPORTS InnerProductLayerInt8 : public InnerProductLayer
{
public:
int input_zp, output_zp;
float input_sc, output_sc;
static Ptr<InnerProductLayerInt8> create(const LayerParams& params);
};


@ -75,6 +75,7 @@ CV__DNN_INLINE_NS_BEGIN
DNN_BACKEND_VKCOM,
DNN_BACKEND_CUDA,
DNN_BACKEND_WEBNN,
DNN_BACKEND_TIMVX,
#ifdef __OPENCV_BUILD
DNN_BACKEND_INFERENCE_ENGINE_NGRAPH = 1000000, // internal - use DNN_BACKEND_INFERENCE_ENGINE + setInferenceEngineBackendType()
DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019, // internal - use DNN_BACKEND_INFERENCE_ENGINE + setInferenceEngineBackendType()
@ -95,7 +96,8 @@ CV__DNN_INLINE_NS_BEGIN
DNN_TARGET_FPGA, //!< FPGA device with CPU fallbacks using Inference Engine's Heterogeneous plugin.
DNN_TARGET_CUDA,
DNN_TARGET_CUDA_FP16,
DNN_TARGET_HDDL,
DNN_TARGET_NPU,
};
CV_EXPORTS std::vector< std::pair<Backend, Target> > getAvailableBackends();
@ -321,6 +323,19 @@ CV__DNN_INLINE_NS_BEGIN
const std::vector<Ptr<BackendWrapper>>& outputs
);
/**
* @brief Returns a TimVX backend node.
*
* @param timVxInfo void pointer to a TimVXInfo object
* @param inputsWrapper layer inputs
* @param outputsWrapper layer outputs
* @param isLast whether this node is the last one in the TimVX graph
*/
virtual Ptr<BackendNode> initTimVX(void* timVxInfo,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast);
/**
* @brief Automatic Halide scheduling based on layer hyper-parameters.
* @param[in] node Backend node with Halide functions.


@ -4,6 +4,8 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../op_timvx.hpp"
#include <opencv2/dnn/shape_utils.hpp>
namespace cv
@ -103,6 +105,11 @@ public:
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
if (backendId == DNN_BACKEND_TIMVX && haveTimVX())
{
return true;
}
return backendId == DNN_BACKEND_OPENCV;
}
@ -116,6 +123,121 @@ public:
return false;
}
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
// tvGraph Initialization.
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
const int numChannels = (int)origin_bias.total();
Mat tvGamma = origin_weights.reshape(1, numChannels);
Mat tvBeta = origin_bias.reshape(1, numChannels);
std::vector<int> inputsIndex;
std::vector<int> outputsIndex;
Mat tvMean = Mat::zeros(1, numChannels, CV_32F);
tvMean = tvMean.reshape(1, numChannels);
Mat tvVar = Mat::ones(1, numChannels, CV_32F);
tvVar = tvVar.reshape(1, numChannels);
CV_Assert(inputsWrapper.size() == 1);
if (outputsWrapper.size() > 1)
return Ptr<BackendNode>();
Ptr<tim::vx::Quantization> tvInputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, input_sc, input_zp));
// input Tensor
auto inputWrapper = inputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
Mat tmpInput = inputWrapper->getMat();
if (tmpInput.dims != 4) // Only 4-dimensional input is supported.
return Ptr<BackendNode>();
int input_index = -1, mean_index = -1, var_index = -1, gamma_index = -1, beta_index = -1, output_index = -1;
if (inputWrapper->isTensor())
{
input_index = tvGraph->getTensorIndex(inputWrapper->getTensor());
if (input_index == -1)
{
// Copy To New inputWrapper
Mat tmp = inputWrapper->getMat();
inputWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tmp));
}
}
if (!inputWrapper->isTensor())
{
inputWrapper->createTensor(graph,tim::vx::TensorAttribute::INPUT, tvInputQuant);
input_index = tvGraph->addWrapper(inputWrapper);
}
inputsIndex.push_back(input_index);
// Mean tensor
Ptr<TimVXBackendWrapper> meanWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tvMean));
Ptr<tim::vx::Quantization> meanQuant;
meanWrapper->createTensor(graph, tim::vx::TensorAttribute::CONSTANT);
mean_index = tvGraph->addWrapper(meanWrapper);
inputsIndex.push_back(mean_index);
// Var tensor
Ptr<TimVXBackendWrapper> varWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tvVar));
varWrapper->createTensor(graph,tim::vx::TensorAttribute::CONSTANT);
var_index = tvGraph->addWrapper(varWrapper);
inputsIndex.push_back(var_index);
// Gamma tensor
Ptr<TimVXBackendWrapper> gammaWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tvGamma));
gammaWrapper->createTensor(graph,tim::vx::TensorAttribute::CONSTANT);
gamma_index = tvGraph->addWrapper(gammaWrapper);
inputsIndex.push_back(gamma_index);
// Beta tensor
Ptr<TimVXBackendWrapper> betaWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tvBeta));
betaWrapper->createTensor(graph,tim::vx::TensorAttribute::CONSTANT);
beta_index = tvGraph->addWrapper(betaWrapper);
inputsIndex.push_back(beta_index);
// Output tensor
CV_Assert(outputsWrapper.size() == 1);
Ptr<TimVXBackendWrapper> outputWrapper = outputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
Ptr<tim::vx::Quantization> outputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, output_sc, output_zp));
if (isLast)
{
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, outputQuant);
}
else
{
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, outputQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
std::shared_ptr<tim::vx::Operation> tvBatchNorm = graph->CreateOperation<tim::vx::ops::BatchNorm>(0.f);
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvBatchNorm, inputsIndex, outputsIndex);
return tvBackendNode;
#endif // HAVE_TIMVX
return Ptr<BackendNode>();
}
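
The zero-mean and unit-variance constants above are the whole trick: TIM-VX's BatchNorm (created with eps = 0) computes $y = \gamma (x - \mu)/\sqrt{\sigma^2 + \varepsilon} + \beta$, which with $\mu = 0$, $\sigma^2 = 1$, $\varepsilon = 0$ collapses to the per-channel affine $y = \gamma x + \beta$, i.e. exactly the scale/shift pair the quantized layer stores in origin_weights and origin_bias.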
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();


@ -9,6 +9,7 @@
#include "opencv2/core/hal/hal.hpp"
#include "opencv2/core/hal/intrin.hpp"
#include "../op_timvx.hpp"
#include <iostream>
#include <numeric>
@ -46,6 +47,7 @@ public:
int ngroups = params.get<int>("group", 1);
CV_Assert(numOutput % ngroups == 0);
input_sc = params.get<float>("input_scale");
input_zp = params.get<int>("input_zeropoint");
output_zp = params.get<int>("zeropoints");
output_sc = params.get<float>("scales");
@ -181,6 +183,16 @@ public:
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
size_t ksize = kernel_size.size();
#ifdef HAVE_TIMVX
if (backendId == DNN_BACKEND_TIMVX)
{
/* only Conv1d and Conv2d supported. */
if (ksize == 2 || ksize == 1)
return true;
return false;
}
#endif
// Only default backend and Conv1D/Conv2D/Conv3D are supported
return backendId == DNN_BACKEND_OPENCV && ksize >= 1 && ksize <= 3;
}
@ -261,6 +273,11 @@ public:
bool setActivation(const Ptr<ActivationLayer>& layer) CV_OVERRIDE
{
// TODO! add activation in convolution.
#ifdef HAVE_TIMVX
if (preferableTarget == DNN_TARGET_NPU)
return false;
#endif
Ptr<ActivationLayerInt8> activ_int8 = layer.dynamicCast<ActivationLayerInt8>();
if (!activ_int8.empty())
{
@ -300,6 +317,249 @@ public:
outputMultiplier[outCn] = outputMultiplier[outCn+1] = outputMultiplier[outCn-1];
}
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
/* TODO: support GroupConv.
   Ref: https://github.com/VeriSilicon/TIM-VX/blob/main/docs/Operators.md#conv2d
   Ref: https://github.com/VeriSilicon/TIM-VX/blob/main/src/tim/vx/ops/conv1d_test.cc
*/
// tvGraph Initialization.
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
Mat tvWeightMat = blobs[0];
std::vector<int> tvBiasVec;
tvBiasVec.assign(biasvec.begin(), biasvec.end() - 2);
Mat tvBiasMat(tvBiasVec);
for (int i = 0; i < numOutput; i++)
{
tvBiasVec[i] += input_zp * (cv::sum(blobs[0].row(i))[0]);
}
// Padding Type
tim::vx::PadType tvPadType;
if (padMode.empty())
{
tvPadType = tim::vx::PadType::AUTO; // TODO! check the padding type.
}
else if(padMode == "VALID")
{
tvPadType = tim::vx::PadType::VALID;
}
else if (padMode == "SAME")
{
tvPadType = tim::vx::PadType::SAME;
}
else
{
CV_Error(Error::StsError, "Unsupported padding mode in TimVXBackend!");
}
size_t ksize = kernel_size.size();
std::vector<int> inputsIndex;
std::vector<int> outputsIndex;
CV_Assert(inputsWrapper.size() == 1);
CV_Assert(ksize == 2 || ksize == 1);
std::vector<float> weight_scs, bias_scs;
std::vector<int32_t> weight_zps, bias_zps;
weight_scs.resize(numOutput);
bias_scs.resize(numOutput);
for (int i = 0; i < numOutput; i++)
{
bias_scs[i] = outputMultiplier[i] * output_sc;
weight_scs[i] = bias_scs[i] / input_sc;
}
weight_zps.assign(numOutput, 0);
bias_zps.assign(numOutput, 0);
bool tvSymmetric;
tvSymmetric = getQuantType(weight_scs, numOutput);
// input Tensor
auto inputWrapper = inputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
int input_index = -1, weight_index = -1, bias_index = -1, output_index = -1;
if (inputWrapper->isTensor())
{
input_index = tvGraph->getTensorIndex(inputWrapper->getTensor());
if (input_index == -1)
{
// Copy To New inputWrapper
Mat tmp = inputWrapper->getMat();
inputWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tmp));
}
}
if (!inputWrapper->isTensor())
{
Ptr<tim::vx::Quantization> tvInputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, input_sc, input_zp));
inputWrapper->createTensor(graph, tim::vx::TensorAttribute::INPUT, tvInputQuant);
input_index = tvGraph->addWrapper(inputWrapper);
}
inputsIndex.push_back(input_index);
// weight Tensor
auto tvConvWeightShape = shape(tvWeightMat);
Mat tvInputMat = inputWrapper->getMat();
// calculate group value.
int group = tvInputMat.size[1] / tvWeightMat.size[1];
// TODO! It will be supported in future.
if (tvSymmetric && tvWeightMat.total() == tvConvWeightShape[0])
return Ptr<TimVXBackendNode>();
// Reverse weight shape From OpenCV NCHW to TimVX WHCN.
std::reverse(tvConvWeightShape.begin(), tvConvWeightShape.end());
Ptr<TimVXBackendWrapper> weightWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tvWeightMat));
Ptr<tim::vx::Quantization> weightQuant;
if (tvSymmetric)
{
int wtChanneldim = tvWeightMat.dims - 1;
weightQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::SYMMETRIC_PER_CHANNEL, wtChanneldim,
weight_scs, weight_zps));
}
else
{
weightQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, weight_scs[0], 0));
}
weightWrapper->createTensor(graph,tim::vx::TensorAttribute::CONSTANT, weightQuant);
weight_index = tvGraph->addWrapper(weightWrapper);
inputsIndex.push_back(weight_index);
// Bias Tensor
Ptr<TimVXBackendWrapper> biasWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tvBiasMat));
Ptr<tim::vx::Quantization> biasQuant;
if (tvSymmetric)
{
biasQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::SYMMETRIC_PER_CHANNEL, 0,
bias_scs, bias_zps));
}
else
{
biasQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, weight_scs[0] * input_sc, 0));
}
biasWrapper->createTensor(graph, tim::vx::TensorAttribute::CONSTANT, biasQuant);
bias_index = tvGraph->addWrapper(biasWrapper);
inputsIndex.push_back(bias_index);
// Output tensor
CV_Assert(outputsWrapper.size() == 1);
auto outputWrapper = outputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
Ptr<tim::vx::Quantization> outputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, output_sc, output_zp));
if (isLast)
{
// From OpenCV NCHW, to TimVX WHCN
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, outputQuant);
}
else
{
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, outputQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
std::shared_ptr<tim::vx::Operation> tvConv;
if (ksize == 2) // for conv2d
{
int multiplier = 0;
if(group == tvConvWeightShape[3] && group != 1)
multiplier = 1;
if (group == 1 || (group == tvConvWeightShape[3] && group != 1)) // Conv2D || DeConv2D
{
if (tvPadType == tim::vx::PadType::AUTO) {
tvConv = graph->CreateOperation<tim::vx::ops::Conv2d>(
tvConvWeightShape[3], tvPadType,
std::array<uint32_t, 2>({(uint32_t) kernel_size[1], (uint32_t) kernel_size[0]}),
std::array<uint32_t, 2>({(uint32_t) strides[1], (uint32_t) strides[0]}),
std::array<uint32_t, 2>({(uint32_t) dilations[1], (uint32_t) dilations[0]}),
std::array<uint32_t, 4>({(uint32_t) pads_begin[1], (uint32_t) pads_end[1],
(uint32_t) pads_begin[0], (uint32_t) pads_end[0]}),
multiplier);
}
else
{
tvConv = graph->CreateOperation<tim::vx::ops::Conv2d>(
tvPadType,
std::array<uint32_t, 2>({(uint32_t) strides[1], (uint32_t) strides[0]}),
std::array<uint32_t, 2>({(uint32_t) dilations[1], (uint32_t) dilations[0]}),
multiplier);
}
}
else
{
// GroupedConv2d
if (tvPadType == tim::vx::PadType::AUTO)
{
tvConv = graph->CreateOperation<tim::vx::ops::GroupedConv2d>(
std::array<uint32_t, 4>({(uint32_t) pads_begin[1], (uint32_t) pads_end[1],
(uint32_t) pads_begin[0], (uint32_t) pads_end[0]}),
std::array<uint32_t, 2>({(uint32_t)strides[1], (uint32_t)strides[0]}),
std::array<uint32_t, 2>({(uint32_t)dilations[1], (uint32_t)dilations[0]}),
group);
}
else
{
tvConv = graph->CreateOperation<tim::vx::ops::GroupedConv2d>(
tvPadType,
std::array<uint32_t, 2>({(uint32_t)strides[1], (uint32_t)strides[0]}),
std::array<uint32_t, 2>({(uint32_t)dilations[1], (uint32_t)dilations[0]}),
group);
}
}
}
else
{
// for Conv1d
if (group != 1)
CV_Error( CV_StsNotImplemented, " Grouped Conv1d or Depth-Wise Conv1d are not supported by "
"TimVX Backend. Please try OpenCV Backend.");
tvConv = graph->CreateOperation<tim::vx::ops::Conv1d>(
tvConvWeightShape[2], tvPadType, (uint32_t)kernel_size[0],
(uint32_t)strides[0],(uint32_t)dilations[0],
std::array<uint32_t, 2>({(uint32_t)pads_begin[0], (uint32_t)pads_end[0]}));
}
// Create TimVXBackendNode
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvConv, inputsIndex, outputsIndex);
return tvBackendNode;
#endif // HAVE_TIMVX
return Ptr<BackendNode>();
}
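
Two pieces of arithmetic above deserve unpacking. The bias loop appears to undo a fold done by OpenCV's own int8 convolution, which stores $b_i' = b_i - z_{\mathrm{in}} \sum_j w_{ij}$ so it can skip the input zero-point subtraction; TIM-VX applies the asymmetric input zero point itself, so $z_{\mathrm{in}} \sum_j w_{ij}$ is added back (a reading inferred from the code, not stated in the PR). The scale setup follows the standard requantization identity, assuming outputMultiplier holds the usual per-channel multiplier $M_i$:

$$M_i = \frac{s_{\mathrm{in}}\, s_{w,i}}{s_{\mathrm{out}}} \;\Rightarrow\; s_{\mathrm{bias},i} = s_{\mathrm{in}}\, s_{w,i} = M_i\, s_{\mathrm{out}}, \qquad s_{w,i} = \frac{s_{\mathrm{bias},i}}{s_{\mathrm{in}}},$$

which is exactly bias_scs[i] = outputMultiplier[i] * output_sc and weight_scs[i] = bias_scs[i] / input_sc.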
class ParallelConv : public cv::ParallelLoopBody
{
public:


@ -4,6 +4,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../op_timvx.hpp"
#include <opencv2/dnn/shape_utils.hpp>
#include <iostream>
@ -16,14 +17,45 @@ namespace dnn
class ActivationLayerInt8Impl CV_FINAL : public ActivationLayerInt8
{
public:
int input_zp, output_zp;
float input_sc, output_sc;
float slope = 0.0f;
#ifdef HAVE_TIMVX
tvActivationType tvActType;
#endif
ActivationLayerInt8Impl(const LayerParams &params)
{
setParamsFrom(params);
activationLUT = !blobs.empty() ? blobs[0] : Mat();
input_zp = params.get<int>("input_zeropoint");
input_sc = params.get<float>("input_scale");
output_zp = params.get<int>("zeropoints");
output_sc = params.get<float>("scales");
if (params.has("slope"))
{
slope = params.get<float>("slope");
}
#ifdef HAVE_TIMVX
tvActType = getTimVXActType(type);
#endif
}
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
if (backendId == DNN_BACKEND_TIMVX)
{
// TODO: Leaky ReLU will be supported in the future.
if (tvActType == tvActReLU && slope != 0.f)
return false;
return tvActType != tvActNotSupported;
}
#endif
return backendId == DNN_BACKEND_OPENCV;
}
@ -106,6 +138,112 @@ public:
}
};
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
// tvGraph Initialization.
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
std::vector<int> inputsIndex, outputsIndex;
int input_index, output_index;
CV_Assert(inputsWrapper.size() == 1);
// input Tensor
Ptr<TimVXBackendWrapper> inputWrapper = inputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
if (inputWrapper->isTensor())
{
input_index = tvGraph->getTensorIndex(inputWrapper->getTensor());
if(input_index == -1)
{
// Copy To New inputWrapper
Mat tmp = inputWrapper->getMat();
inputWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tmp));
}
}
if (!inputWrapper->isTensor())
{
Ptr<tim::vx::Quantization> tvInputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, input_sc, input_zp));
inputWrapper->createTensor(graph, tim::vx::TensorAttribute::INPUT, tvInputQuant);
input_index = tvGraph->addWrapper(inputWrapper);
}
inputsIndex.push_back(input_index);
// output tensor
CV_Assert(outputsWrapper.size() == 1);
Ptr<TimVXBackendWrapper> outputWrapper = outputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
Ptr<tim::vx::Quantization> outputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, output_sc, output_zp));
Ptr<tim::vx::Tensor> outputTensor;
if (isLast)
{
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, outputQuant);
}
else
{
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, outputQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
std::shared_ptr<tim::vx::Operation> tvAct;
switch(tvActType) {
case tvActReLU:
{
if (slope != 0.f)
tvAct = graph->CreateOperation<tim::vx::ops::LeakyRelu>(slope);
else
tvAct = graph->CreateOperation<tim::vx::ops::Relu>();
break;
}
case tvActReLU6:
tvAct = graph->CreateOperation<tim::vx::ops::Relu6>();
break;
case tvActTanH:
tvAct = graph->CreateOperation<tim::vx::ops::Tanh>();
break;
case tvActSwish:
tvAct = graph->CreateOperation<tim::vx::ops::Swish>();
break;
case tvActMish:
tvAct = graph->CreateOperation<tim::vx::ops::Mish>();
break;
case tvActSigmoid:
tvAct = graph->CreateOperation<tim::vx::ops::Sigmoid>();
break;
case tvActELU:
tvAct = graph->CreateOperation<tim::vx::ops::Elu>();
break;
default:
// TODO! check the default function.
tvAct = graph->CreateOperation<tim::vx::ops::Relu>();
break;
}
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvAct, inputsIndex, outputsIndex);
return tvBackendNode;
#endif // HAVE_TIMVX
return Ptr<BackendNode>();
}
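
One observation on the switch above: the default branch falls back to plain ReLU with a TODO, but supportBackend already rejects tvActNotSupported, so for graphs that reach this point the switch should always hit one of the explicit cases.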
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();


@ -4,6 +4,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../op_timvx.hpp"
#include <opencv2/dnn/shape_utils.hpp>
namespace cv
@ -22,6 +23,10 @@ public:
} op;
std::vector<float> coeffs;
std::vector<int> zeropoints;
std::vector<float> scales;
int output_zp;
float output_sc;
enum OutputChannelsMode
{
@ -84,6 +89,20 @@ public:
}
}
if (params.has("input_scales"))
{
DictValue sc = params.get("input_scales");
int i, n = sc.size();
scales.resize(n);
for (i = 0; i < n; i++)
{
scales[i] = sc.get<float>(i);
}
}
output_zp = params.get<int>("zeropoints");
output_sc = params.get<float>("scales");
channelsModeInput = ELTWISE_CHANNNELS_SAME;
if (params.has("output_channels_mode"))
{
@ -116,6 +135,9 @@ public:
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
// For the TimVX backend, only ELTWISE_CHANNNELS_SAME is supported.
if (backendId == DNN_BACKEND_TIMVX && haveTimVX())
return channelsModeInput == ELTWISE_CHANNNELS_SAME;
return backendId == DNN_BACKEND_OPENCV;
}
@ -219,6 +241,134 @@ public:
}
}
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
// tvGraph Initialization.
if (inputsWrapper.size() != 2)
return Ptr<BackendNode>();
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
bool isSub = false;
// TODO: support variable coeffs.
if (op == SUM)
{
CV_Assert(coeffs.size() == scales.size());
std::vector<float> originalCoeffs;
for (int i = 0; i < coeffs.size(); i++)
{
originalCoeffs.push_back(coeffs[i] * output_sc / scales[i]);
}
float eps = std::numeric_limits<float>::epsilon();
if (std::fabs(originalCoeffs[0] - 1.0f) <= eps * std::fabs(originalCoeffs[0] + 1.0f) &&
std::fabs(originalCoeffs[1] + 1.0f) <= eps * std::fabs(originalCoeffs[1] - 1.0f))
{
// Sub, if coeffs = {1., -1.}, isSub = true.
isSub = true;
}
else if (std::fabs(originalCoeffs[0] - 1.0f) <= eps * std::fabs(originalCoeffs[0] + 1.0f) &&
std::abs(originalCoeffs[1] - 1.0f) <= eps * std::abs(originalCoeffs[1] + 1.0f))
{
// Sum, if coeff = {1., 1.}, isSub = false.
isSub = false;
}
else
{
return Ptr<BackendNode>();
}
}
std::vector<int> inputsIndex, outputsIndex;
int input_index = -1, output_index = -1;
CV_Assert(channelsModeInput == ELTWISE_CHANNNELS_SAME);
// Input
Ptr<TimVXBackendWrapper> inputWrapper;
CV_Assert(!scales.empty() && !zeropoints.empty());
for (int i = 0; i<inputsWrapper.size(); i++){
inputWrapper = inputsWrapper[i].dynamicCast<TimVXBackendWrapper>();
if (inputWrapper->isTensor())
{
input_index = tvGraph->getTensorIndex(inputWrapper->getTensor());
if (input_index == -1)
{
// Copy To New inputWrapper
Mat tmp = inputWrapper->getMat();
inputWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tmp));
}
}
if (!inputWrapper->isTensor())
{
Ptr<tim::vx::Quantization> tvInputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, scales[i], zeropoints[i]));
inputWrapper->createTensor(graph,tim::vx::TensorAttribute::INPUT, tvInputQuant);
input_index = tvGraph->addWrapper(inputWrapper);
}
inputsIndex.push_back(input_index);
}
// Output
CV_Assert(outputsWrapper.size() == 1);
Ptr<TimVXBackendWrapper> outputWrapper = outputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
Ptr<tim::vx::Quantization> outputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, output_sc, output_zp));
if (isLast)
{
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, outputQuant);
}
else
{
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, outputQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
std::shared_ptr<tim::vx::Operation> tvEltwise;
switch (op) {
case SUM:
if (isSub)
tvEltwise = graph->CreateOperation<tim::vx::ops::Sub>();
else
tvEltwise = graph->CreateOperation<tim::vx::ops::Add>();
break;
case PROD:
tvEltwise = graph->CreateOperation<tim::vx::ops::Multiply>();
break;
case MAX:
tvEltwise = graph->CreateOperation<tim::vx::ops::Maximum>();
break;
default:
CV_Error(Error::StsNotImplemented, "Unsupported eltwise operation");
}
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvEltwise, inputsIndex, outputsIndex);
return tvBackendNode;
#endif // HAVE_TIMVX
return Ptr<BackendNode>();
}
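
The coefficient check above recovers the original floating-point coefficients. If quantization stored $c_i' = c_i\, s_i / s_{\mathrm{out}}$ (consistent with the eltwise tryQuantize change later in this commit), then $c_i = c_i'\, s_{\mathrm{out}} / s_i$. TIM-VX's Add and Sub requantize their operands internally but apply no extra scaling, so only $c = \{1, 1\}$ (Add) and $c = \{1, -1\}$ (Sub) can be mapped; any other coefficients return an empty node and fall back to the OpenCV path.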
class EltwiseInvoker : public ParallelLoopBody
{
EltwiseLayerInt8Impl& self;


@ -4,6 +4,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../op_timvx.hpp"
#include <opencv2/dnn/shape_utils.hpp>
@ -19,7 +20,11 @@ public:
FullyConnectedLayerInt8Impl(const LayerParams& params)
{
setParamsFrom(params);
input_sc = params.get<float>("input_scale");
input_zp = params.get<int>("input_zeropoint");
output_zp = params.get<int>("zeropoints");
output_sc = params.get<float>("scales");
axis = params.get<int>("axis", 1);
if (blobs.size() == 3)
{
@ -71,11 +76,25 @@ public:
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
if (backendId == DNN_BACKEND_TIMVX && haveTimVX())
{
if (biasMat.empty())
return true;
else
return false;
}
return backendId == DNN_BACKEND_OPENCV;
}
virtual bool setActivation(const Ptr<ActivationLayer>& layer) CV_OVERRIDE
{
// TODO! add activation in fully connected layers.
#ifdef HAVE_TIMVX
if(preferableTarget == DNN_TARGET_NPU)
return false;
#endif
Ptr<ActivationLayerInt8> activ_int8 = layer.dynamicCast<ActivationLayerInt8>();
if (!activ_int8.empty())
{
@ -87,6 +106,120 @@ public:
return false;
}
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
// tvGraph Initialization.
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
int numOutput = blobs[0].size[0];
Mat weightMat = blobs[0];
std::vector<int> inputsIndex;
std::vector<int> outputsIndex;
std::vector<float> weight_scs, bias_scs;
std::vector<int32_t> weight_zps;
bias_scs.resize(numOutput);
weight_scs.resize(numOutput);
for (int i = 0; i < numOutput; i++)
{
bias_scs[i] = outputMultiplier.at<float>(i) * output_sc;
weight_scs[i] = bias_scs[i] / input_sc;
}
weight_zps.assign(numOutput, 0);
// input Tensor
auto inputWrapper = inputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
int input_index = -1, weight_index = -1, output_index = -1;
if (inputWrapper->isTensor())
{
input_index = tvGraph->getTensorIndex(inputWrapper->getTensor());
if (input_index == -1)
{
// Copy To New inputWrapper
Mat tmp = inputWrapper->getMat();
inputWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tmp));
}
}
if (!inputWrapper->isTensor() || input_index == -1)
{
Ptr<tim::vx::Quantization> tvInputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, input_sc, input_zp));
inputWrapper->createTensor(graph,tim::vx::TensorAttribute::INPUT, tvInputQuant);
input_index = tvGraph->addWrapper(inputWrapper);
}
inputsIndex.push_back(input_index);
// weight tensor
Ptr<TimVXBackendWrapper> weightWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(weightMat));
Ptr<tim::vx::Quantization> weightQuant;
bool tvSymmetric;
tvSymmetric = getQuantType(weight_scs, numOutput);
if (tvSymmetric)
{
// TODO! fix the following issue.
// TimVX does not support the SYMMETRIC PER CHANNEL MatMul.
return Ptr<BackendNode>();
}
else
{
weightQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, weight_scs[0], 0));
}
weightWrapper->createTensor(graph,tim::vx::TensorAttribute::CONSTANT, weightQuant);
weight_index = tvGraph->addWrapper(weightWrapper);
inputsIndex.push_back(weight_index);
// Output tensor
CV_Assert(outputsWrapper.size() == 1);
Ptr<TimVXBackendWrapper> outputWrapper = outputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
Ptr<tim::vx::Quantization> outputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, output_sc, output_zp));
if (isLast)
{
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, outputQuant);
}
else
{
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, outputQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
std::shared_ptr<tim::vx::Operation> tvMatmul;
tvMatmul = graph->CreateOperation<tim::vx::ops::Matmul>(false, true);
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvMatmul, inputsIndex, outputsIndex);
return tvBackendNode;
#endif // HAVE_TIMVX
return Ptr<BackendNode>();
}
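
A note on the Matmul above: tim::vx::ops::Matmul takes transpose flags for its two inputs, so (false, true) multiplies the input by the transposed weight tensor, which OpenCV stores as numOutput x numInput. (Flag semantics per the TIM-VX operator headers; treat the exact signature as an assumption here.)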
class FullyConnected : public ParallelLoopBody
{
public:


@ -4,6 +4,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../op_timvx.hpp"
#include "opencv2/core/hal/intrin.hpp"
#include <float.h>
@ -26,9 +27,12 @@ public:
globalPooling = false;
isGlobalPooling = std::vector<bool>(3, false);
output_zp = params.get<int>("zeropoints");
input_zp = params.get<int>("input_zeropoint", 0);
input_zp = params.get<int>("input_zeropoint", output_zp);
multiplier = params.get<float>("multiplier", 1.f);
output_sc = params.get<float>("scales");
input_sc = multiplier * output_sc;
hasDynamicShapes = params.get<bool>("has_dynamic_shapes", false);
shapesInitialized = !hasDynamicShapes;
@ -103,6 +107,24 @@ public:
else
return false;
}
else if (backendId == DNN_BACKEND_TIMVX && haveTimVX())
{
// Only 2D and 1D pooling are supported.
if (kernel_size.size() == 3)
{
// fallback to CPU implementation.
preferableTarget = DNN_TARGET_CPU;
return false;
}
if (!avePoolPaddedArea) // TimVX does not support exclude padding.
return false;
if (globalPooling) // TODO support globalPooling in TimVX backend.
return false;
if (kernel_size.size() == 2)
return type == MAX || type == AVE;
return false;
}
return false;
}
@ -116,6 +138,139 @@ public:
return false;
}
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
// tvGraph Initialization.
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
tim::vx::PoolType tvPoolType;
tim::vx::RoundType tvRoundType;
size_t ksize = kernel_size.size();
if (ksize != 2)
return Ptr<BackendNode>();
// Convert the pooling type from OpenCV to TimVX; only MAX and AVG are supported.
switch (type) {
case MAX: {
tvPoolType = tim::vx::PoolType::MAX;
break;
}
case AVE:{
tvPoolType = tim::vx::PoolType::AVG;
break;
}
default:
CV_Error(Error::StsNotImplemented, "Not implemented Pooling type in TimVX Backend.");
}
// Padding Type
tim::vx::PadType tvPadType;
if (padMode.empty())
{
tvPadType = tim::vx::PadType::AUTO; // TODO! check the padding type.
}
else if(padMode == "VALID")
{
tvPadType = tim::vx::PadType::VALID;
}
else if (padMode == "SAME")
{
tvPadType = tim::vx::PadType::SAME;
}
else
{
CV_Error(Error::StsError, "Unsupported padding mode in TimVXBackend!");
}
if (ceilMode)
tvRoundType = tim::vx::RoundType::CEILING;
else
tvRoundType = tim::vx::RoundType::FLOOR;
auto input = inputsWrapper[0];
std::vector<int> inputsIndex;
std::vector<int> outputsIndex;
// input Tensor
auto inputWrapper = inputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
int input_index, output_index;
if (inputWrapper->isTensor())
{
input_index = tvGraph->getTensorIndex(inputWrapper->getTensor());
if (input_index == -1)
{
// Copy To New inputWrapper
Mat tmp = inputWrapper->getMat();
inputWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tmp));
}
}
if (!inputWrapper->isTensor())
{
Ptr<tim::vx::Quantization> tvInputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, input_sc, input_zp));
inputWrapper->createTensor(graph,tim::vx::TensorAttribute::INPUT, tvInputQuant);
input_index = tvGraph->addWrapper(inputWrapper);
}
inputsIndex.push_back(input_index);
// Output tensor
CV_Assert(outputsWrapper.size() == 1);
auto outputWrapper = outputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
Ptr<tim::vx::Quantization> outputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, output_sc, output_zp));
if (isLast)
{
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, outputQuant);
}
else
{
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, outputQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
std::shared_ptr<tim::vx::Operation> tvPool;
if (tvPadType == tim::vx::PadType::AUTO)
{
tvPool = graph->CreateOperation<tim::vx::ops::Pool2d>( tvPoolType,
std::array<uint32_t, 4>({(uint32_t) pads_begin[1], (uint32_t) pads_end[1],
(uint32_t) pads_begin[0], (uint32_t) pads_end[0]}),
std::array<uint32_t, 2>({(uint32_t)kernel_size[1], (uint32_t)kernel_size[0]}),
std::array<uint32_t, 2>({(uint32_t)strides[1], (uint32_t)strides[0]}),
tvRoundType);
}
else
{
tvPool = graph->CreateOperation<tim::vx::ops::Pool2d>(
tvPoolType, tvPadType,
std::array<uint32_t, 2>({(uint32_t)kernel_size[1], (uint32_t)kernel_size[0]}),
std::array<uint32_t, 2>({(uint32_t)strides[1], (uint32_t)strides[0]}),
tvRoundType);
}
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvPool, inputsIndex, outputsIndex);
return tvBackendNode;
#endif // HAVE_TIMVX
return Ptr<BackendNode>();
}
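
The four-element pad array above is ordered {left, right, top, bottom}, i.e. the W pads before the H pads, matching TIM-VX's WHCN layout; OpenCV's pads_begin/pads_end hold {top, left} and {bottom, right}, hence the index-swapped (pads_begin[1], pads_end[1], pads_begin[0], pads_end[0]).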
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();


@ -4,6 +4,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../op_timvx.hpp"
namespace cv
{
@ -149,15 +150,21 @@ public:
class RequantizeLayerImpl CV_FINAL : public RequantizeLayer
{
public:
bool isEltwise;
RequantizeLayerImpl(const LayerParams& params)
{
scale = params.get<float>("scale", 1.f);
shift = params.get<float>("shift", 0.f);
isEltwise = params.get<bool>("isEltwise", false);
setParamsFrom(params);
}
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
if (backendId == DNN_BACKEND_TIMVX && haveTimVX() && !isEltwise)
{
return true;
}
return backendId == DNN_BACKEND_OPENCV;
}
@ -178,6 +185,82 @@ public:
outputs_arr.getMatVector(outputs);
}
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
// preprocessing
// Check if data is 8-bit.
CV_Assert(inputsWrapper.size() == 1 && outputsWrapper.size() == 1);
Ptr<TimVXBackendWrapper> inputWrapper = inputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
if (!inputWrapper->isTensor())
{
return Ptr<BackendNode>();
}
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
std::vector<int> inputsIndex, outputsIndex;
int input_index = -1, output_index = -1;
// Input
std::shared_ptr<tim::vx::Tensor> inputTensor = inputWrapper->getTensor();
input_index = tvGraph->getTensorIndex(inputTensor);
if (input_index == -1)
return Ptr<BackendNode>();
inputsIndex.push_back(input_index);
Ptr<tim::vx::Quantization> inputQuant = inputWrapper->getTensorQuantization();
tim::vx::QuantType quanType = inputQuant->Type();
CV_Assert(quanType == tim::vx::QuantType::ASYMMETRIC);
std::vector<float> scales = inputQuant->Scales();
std::vector<int32_t> zeropoints = inputQuant->ZeroPoints();
CV_Assert(!scales.empty() && !zeropoints.empty());
int input_zp = int(zeropoints[0]);
float input_scale = scales[0];
float tmpOut_sc = input_scale/scale;
int tmpOut_zp = int(shift + scale * input_zp);
// Output
Ptr<TimVXBackendWrapper> outputWrapper = outputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
Ptr<tim::vx::Quantization> outputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, tmpOut_sc, tmpOut_zp));
if (isLast)
{
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, outputQuant);
}
else
{
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, outputQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
std::shared_ptr<tim::vx::Operation> tvRequantize = graph->CreateOperation<tim::vx::ops::DataConvert>();
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvRequantize, inputsIndex, outputsIndex);
return tvBackendNode;
#endif // HAVE_TIMVX
return Ptr<BackendNode>();
}
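
The temporary quantization parameters above make a plain DataConvert act as the requantization $q_{\mathrm{out}} = \alpha\, q_{\mathrm{in}} + \delta$ (with $\alpha$ = scale and $\delta$ = shift). DataConvert preserves the real value, so $s_{\mathrm{in}}(q_{\mathrm{in}} - z_{\mathrm{in}}) = s'(q_{\mathrm{out}} - z')$; substituting $q_{\mathrm{out}}$ and matching coefficients gives $s' = s_{\mathrm{in}}/\alpha$ and $z' = \delta + \alpha\, z_{\mathrm{in}}$, which are exactly tmpOut_sc and tmpOut_zp.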
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();


@ -4,6 +4,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../op_timvx.hpp"
#include <algorithm>
#include <stdlib.h>
@ -16,11 +17,17 @@ namespace dnn
class SoftMaxLayerInt8Impl CV_FINAL : public SoftmaxLayerInt8
{
public:
float input_sc;
int input_zp;
SoftMaxLayerInt8Impl(const LayerParams& params)
{
axisRaw = params.get<int>("axis", 1);
logSoftMax = params.get<bool>("log_softmax", false);
input_sc = params.get<float>("input_scale");
input_zp = params.get<int>("input_zeropoint");
output_sc = params.get<float>("scales");
output_zp = params.get<int>("zeropoints");
setParamsFrom(params);
@ -41,7 +48,8 @@ public:
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
return backendId == DNN_BACKEND_OPENCV ||
       (backendId == DNN_BACKEND_TIMVX && haveTimVX());
}
virtual bool tryFuse(Ptr<Layer>& top) CV_OVERRIDE
@ -50,6 +58,102 @@ public:
return !dequantize_layer.empty() && preferableTarget != DNN_TARGET_OPENCL_FP16;
}
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
// tvGraph Initialization.
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
std::vector<int> inputsIndex, outputsIndex;
int input_index, output_index;
// input Tensor
CV_Assert(inputsWrapper.size() == 1);
Ptr<TimVXBackendWrapper> inputWrapper = inputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
const Mat &src = inputWrapper->getMat();
// Convert axis from OpenCV NCHW to TimVX WHCN.
int axis = normalize_axis(axisRaw, src.dims);
int tvAxis = src.dims - 1 - axis;
if(tvAxis < 0)
tvAxis = 0; // default value is 0.
if (inputWrapper->isTensor())
{
input_index = tvGraph->getTensorIndex(inputWrapper->getTensor());
if (input_index == -1)
{
// Copy To New inputWrapper
Mat tmp = inputWrapper->getMat();
inputWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tmp));
}
}
if (!inputWrapper->isTensor())
{
Ptr<tim::vx::Quantization> tvInputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, input_sc, input_zp));
inputWrapper->createTensor(graph,tim::vx::TensorAttribute::INPUT, tvInputQuant);
input_index = tvGraph->addWrapper(inputWrapper);
}
inputsIndex.push_back(input_index);
// output tensor
CV_Assert(outputsWrapper.size() == 1);
Ptr<TimVXBackendWrapper> outputWrapper = outputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
Mat dstMat = outputWrapper->getMat();
Ptr<tim::vx::Quantization> outputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, output_sc, output_zp));
Ptr<tim::vx::Tensor> outputTensor;
if (isLast)
{
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
if (dstMat.type() == CV_32F)
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT);
else
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, outputQuant);
}
else
{
if (dstMat.type() == CV_32F)
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT);
else
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, outputQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
std::shared_ptr<tim::vx::Operation> tvSoftmax;
if (logSoftMax)
{
tvSoftmax = graph->CreateOperation<tim::vx::ops::LogSoftmax>(tvAxis);
}
else
{
tvSoftmax = graph->CreateOperation<tim::vx::ops::Softmax>(1.0f, tvAxis);
}
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvSoftmax, inputsIndex, outputsIndex);
return tvBackendNode;
#endif // HAVE_TIMVX
return Ptr<BackendNode>();
}
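
Because TimVX stores tensors in reversed (WHCN) order, an OpenCV axis $a$ in a $d$-dimensional blob maps to $d - 1 - a$; for example, the channel axis 1 of a 4-D NCHW tensor becomes WHCN axis 2. The clamp of negative values to 0 above is defensive, since a normalized axis already satisfies $0 \le a < d$.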
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();


@ -74,6 +74,16 @@ Ptr<BackendNode> Layer::initWebnn(const std::vector<Ptr<BackendWrapper>>& inputs
return Ptr<BackendNode>();
}
Ptr<BackendNode> Layer::initTimVX(void* timVxInfo,
const std::vector<Ptr<BackendWrapper> > & inputsWrapper,
const std::vector<Ptr<BackendWrapper> > & outputsWrapper,
bool isLast)
{
CV_Error(Error::StsNotImplemented, "TimVX pipeline of " + type +
" layers is not defined.");
return Ptr<BackendNode>();
}
Ptr<BackendNode> Layer::tryAttach(const Ptr<BackendNode>& node)
{
return Ptr<BackendNode>();


@ -409,6 +409,7 @@ public:
{
params.set("input_scale", scales[0][0]);
params.set("input_zeropoint", zeropoints[0][0]);
params.set("eps", epsilon);
params.blobs.clear();
params.blobs.push_back(origin_weights);


@ -48,6 +48,7 @@
#include "../ie_ngraph.hpp"
#include "../op_vkcom.hpp"
#include "../op_webnn.hpp"
#include "../op_timvx.hpp"
#ifdef HAVE_OPENCL
#include "opencl_kernels_dnn.hpp"
@ -72,6 +73,9 @@ public:
axis = params.get<int>("axis", 1);
padding = params.get<bool>("padding", false);
paddingValue = params.get<int>("padding_value", 0);
zeropoint = params.get<int>("zeropoints", 0);
scale = params.get<float>("scales", 1.0f);
}
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
@ -113,6 +117,21 @@ public:
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
if (backendId == DNN_BACKEND_TIMVX && haveTimVX() && !padding)
{
if (axis == -1)
return false;
int len = this->type.length();
if (len <= 4)
return false;
if (this->type.substr(len - 4) == "Int8")
return true;
else
return false;
}
#endif
#ifdef HAVE_INF_ENGINE
if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
return true;
@ -393,6 +412,86 @@ public:
}
#endif // HAVE_DNN_NGRAPH
#ifdef HAVE_TIMVX
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
// tvGraph Initialization.
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
Ptr<TimVXBackendWrapper> inputWrapper = inputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
// Convert axis from OpenCV NCHW to TimVX WHCN.
Mat blob0 = inputWrapper->getMat();
// TODO: support 5-dimensional tensors in TimVX in the future.
if (blob0.dims > 4)
return Ptr<TimVXBackendNode>();
int cAxis = normalize_axis(axis, blob0.dims);
int tvAxis = blob0.dims - 1 - cAxis;
CV_Assert(tvAxis>= 0);
std::vector<int> inputsIndex, outputsIndex;
int input_index = -1, output_index = -1;
// Input
Ptr<tim::vx::Quantization> tvQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, scale, zeropoint));
for (int i = 0; i<inputsWrapper.size(); i++)
{
inputWrapper = inputsWrapper[i].dynamicCast<TimVXBackendWrapper>();
if (inputWrapper->isTensor())
{
input_index = tvGraph->getTensorIndex(inputWrapper->getTensor());
if (input_index == -1)
{
// Copy To New inputWrapper
Mat tmp = inputWrapper->getMat();
inputWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tmp));
}
}
if (!inputWrapper->isTensor())
{
inputWrapper->createTensor(graph,tim::vx::TensorAttribute::INPUT, tvQuant);
input_index = tvGraph->addWrapper(inputWrapper);
}
inputsIndex.push_back(input_index);
}
//Output
CV_Assert(outputsWrapper.size() == 1);
Ptr<TimVXBackendWrapper> outputWrapper = outputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
if (isLast)
{
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, tvQuant);
}
else
{
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, tvQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
std::shared_ptr<tim::vx::Operation> tvConcate = graph->CreateOperation<tim::vx::ops::Concat>(tvAxis, inputsWrapper.size());
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvConcate, inputsIndex, outputsIndex);
return tvBackendNode;
}
#endif // HAVE_TIMVX
virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
{
@ -416,6 +515,8 @@ public:
}
#endif
int zeropoint;
float scale;
};
Ptr<ConcatLayer> ConcatLayer::create(const LayerParams& params)


@ -2168,6 +2168,7 @@ public:
float inputScale = scales[0][0], outputScale = scales[1][0];
int inputZp = zeropoints[0][0];
params.set("input_zeropoint", inputZp);
params.set("input_scale", inputScale);
Mat weightsQuantized(weightsMat.rows, weightsMat.cols, CV_8S);
Mat biasQuantized(1, numOutput, CV_32S);


@ -496,6 +496,9 @@ struct ReLUFunctor : public BaseFunctor
params.blobs.clear();
params.blobs.push_back(lookUpTable);
}
params.set("input_scale", scales[0][0]);
params.set("input_zeropoint", zeropoints[0][0]);
params.set("slope", slope);
return true;
}
@ -635,6 +638,8 @@ struct ReLU6Functor : public BaseFunctor
bool tryQuantize(const std::vector<std::vector<float> > &scales,
const std::vector<std::vector<int> > &zeropoints, LayerParams& params)
{
params.set("input_scale", scales[0][0]);
params.set("input_zeropoint", zeropoints[0][0]);
return true;
}
@ -704,6 +709,8 @@ struct BaseDefaultFunctor : public BaseFunctor
}
params.blobs.clear();
params.blobs.push_back(lookUpTable);
params.set("input_scale", scales[0][0]);
params.set("input_zeropoint", zeropoints[0][0]);
return true;
}


@ -875,6 +875,8 @@ public:
virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
{
params.set("input_scales", DictValue::arrayReal(scales[0].data(), scales[0].size()));
params.set("input_zeropoints", DictValue::arrayInt(zeropoints[0].data(), zeropoints[0].size()));
if (op == SUM)
{
std::vector<float> newCoeffs;
@ -897,7 +899,6 @@ public:
newCoeffs[0] /= scales[1][0];
params.set("coeff", DictValue::arrayReal(newCoeffs.data(), newCoeffs.size()));
params.set("offset", zeropoints[1][0]);
params.set("input_zeropoints", DictValue::arrayInt(zeropoints[0].data(), zeropoints[0].size()));
return true;
}
return op == MAX;


@ -642,6 +642,8 @@ public:
params.blobs.push_back(weightsQuantized.reshape(1, shape(blobs[0])));
params.blobs.push_back(biasQuantized);
params.blobs.push_back(outputMultiplier);
params.set("input_scale", inputScale);
params.set("input_zeropoint", inputZp);
return true;
}


@ -47,6 +47,7 @@
#include "../ie_ngraph.hpp"
#include "../op_vkcom.hpp"
#include "../op_webnn.hpp"
#include "../op_timvx.hpp"
#include <float.h>
#include <algorithm>
@ -108,6 +109,9 @@ public:
_order.push_back(currentOrder);
}
zeropoint = params.get<int>("zeropoints", 0);
scale = params.get<float>("scales", 1.0f);
setParamsFrom(params);
checkNeedForPermutation();
}
@ -122,6 +126,20 @@ public:
return true;
}
#endif
#ifdef HAVE_TIMVX
if (backendId == DNN_BACKEND_TIMVX && haveTimVX())
{
int len = this->type.length();
if (len <= 4)
return false;
if (this->type.substr(len - 4) == "Int8")
return true;
else
return false;
}
#endif
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_CUDA ||
backendId == DNN_BACKEND_WEBNN ||
@ -471,12 +489,120 @@ public:
}
#endif // HAVE_VULKAN
#ifdef HAVE_TIMVX
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
// tvGraph Initialization.
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
std::vector<int> inputsIndex, outputsIndex;
int input_index = -1, output_index = -1;
if (outputsWrapper.size() != 1) // only works for a single output blob
return Ptr<BackendNode>();
// Input
Ptr<TimVXBackendWrapper> inputWrapper = inputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
if (inputWrapper->isTensor())
{
input_index = tvGraph->getTensorIndex(inputWrapper->getTensor());
if (input_index == -1)
{
// Copy To New inputWrapper
Mat tmp = inputWrapper->getMat();
inputWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tmp));
}
}
if (!inputWrapper->isTensor())
{
Ptr<tim::vx::Quantization> tvInputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, scale, zeropoint));
inputWrapper->createTensor(graph,tim::vx::TensorAttribute::INPUT, tvInputQuant);
input_index = tvGraph->addWrapper(inputWrapper);
}
inputsIndex.push_back(input_index);
//Output
Ptr<TimVXBackendWrapper> outputWrapper = outputsWrapper[0].dynamicCast<TimVXBackendWrapper>();
// The output tensor shares the input tensor's quantization attributes.
Ptr<tim::vx::Quantization> outputQuant = inputWrapper->getTensorQuantization();
if (isLast)
{
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, outputQuant);
}
else
{
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, outputQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
std::vector<uint32_t> tvOrder;
if (getOrderWHCN(tvOrder))
{
std::shared_ptr<tim::vx::Operation> tvPermute = graph->CreateOperation<tim::vx::ops::Transpose>(tvOrder);
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvPermute, inputsIndex, outputsIndex);
return tvBackendNode;
}
else
{
return Ptr<BackendNode>();
}
}
#endif // HAVE_TIMVX
virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
{
return true;
}
// convert OpenCV NCHW order to WHCN order.
bool getOrderWHCN(std::vector<uint32_t>& orderWHCN)
{
std::map<int, int> lookup;
int orderLen = _order.size();
if (orderLen <2)
return false;
orderWHCN.assign(_order.begin(), _order.end());
if (orderLen == 2)
{
return true;
}
else if (orderLen >= 3)
{
for (int i = 0; i < orderLen; i++)
{
lookup[i] = orderLen - i - 1;
}
for (int i = 0; i < orderLen; i++)
{
orderWHCN[i] = lookup[_order[i]];
}
std::reverse(orderWHCN.begin(), orderWHCN.end());
return true;
}
else
return false;
}
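For reference, a minimal standalone sketch of the NCHW-to-WHCN order conversion implemented above; the permutation value is a hypothetical example, not part of the patch:

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> order = {0, 2, 3, 1};      // NCHW -> NHWC in OpenCV axis numbering
    int n = (int)order.size();
    std::vector<int> whcn(order.begin(), order.end());
    for (int i = 0; i < n; i++)
        whcn[i] = n - 1 - order[i];             // equivalent to the lookup table above
    std::reverse(whcn.begin(), whcn.end());     // flip the axis sequence for WHCN
    for (int v : whcn)
        printf("%d ", v);                       // prints: 2 0 1 3
    return 0;
}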
size_t _count;
std::vector<size_t> _order;
@ -492,6 +618,8 @@ public:
#endif
size_t _numAxes;
int zeropoint;
float scale;
};
Ptr<PermuteLayer> PermuteLayer::create(const LayerParams &params)

View File

@ -272,6 +272,17 @@ public:
return true;
}
}
else if (backendId == DNN_BACKEND_TIMVX)
{
#ifdef HAVE_TIMVX
if (kernel_size.size() == 3)
{
// fallback to CPU implementation.
preferableTarget = DNN_TARGET_CPU;
}
#endif
return false;
}
return false;
}

View File

@ -46,6 +46,7 @@
#include "../op_inf_engine.hpp"
#include "../ie_ngraph.hpp"
#include "../op_webnn.hpp"
#include "../op_timvx.hpp"
#include <opencv2/dnn/shape_utils.hpp>
@ -167,6 +168,9 @@ public:
hasDynamicShapes = params.get<bool>("has_dynamic_shapes", false);
shapesInitialized = !hasDynamicShapes;
zeropoint = params.get<int>("zeropoints", 0);
scale = params.get<float>("scales", 1.0f);
CV_Assert(numAxes >= -1);
newShapeRange = (numAxes == -1) ? Range(axis, INT_MAX) : Range(axis, axis + numAxes);
@ -202,6 +206,18 @@ public:
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
if (backendId == DNN_BACKEND_TIMVX && haveTimVX())
{
int len = this->type.length();
if (len <= 4)
return false;
if (this->type.substr(len - 4) == "Int8")
return true;
else
return false;
}
#ifdef HAVE_INF_ENGINE
if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
return true;
@ -348,6 +364,99 @@ public:
}
#endif
virtual Ptr<BackendNode> initTimVX(void* timVXInfo_,
const std::vector<Ptr<BackendWrapper> > &inputsWrapper,
const std::vector<Ptr<BackendWrapper> > &outputsWrapper,
bool isLast) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
// tvGraph Initialization.
auto timVxInfo = reinterpret_cast<TimVXInfo *>(timVXInfo_);
CV_Assert(timVxInfo);
Ptr<TimVXGraph> tvGraph = timVxInfo->getGraph();
CV_Assert(tvGraph);
Ptr<tim::vx::Graph> graph = tvGraph->graph;
std::vector<int> inputsIndex, outputsIndex;
int input_index = -1, output_index = -1;
int reshapeNum = 0;
Ptr<TimVXBackendWrapper> tmpWrapper, inputWrapper, outputWrapper;
for (size_t i = 0; i < outputsWrapper.size(); i++)
{
tmpWrapper = inputsWrapper[i].dynamicCast<TimVXBackendWrapper>();
Mat srcBlob = tmpWrapper->getMat();
tmpWrapper = outputsWrapper[i].dynamicCast<TimVXBackendWrapper>();
Mat dstBlob = tmpWrapper->getMat();
if (dstBlob.data != srcBlob.data)
{
reshapeNum++;
inputWrapper = inputsWrapper[i].dynamicCast<TimVXBackendWrapper>();
outputWrapper = outputsWrapper[i].dynamicCast<TimVXBackendWrapper>();
}
}
// Only works for a single reshaped Mat
if (reshapeNum != 1)
{
return Ptr<BackendNode>();
}
// Input
if (inputWrapper->isTensor())
{
input_index = tvGraph->getTensorIndex(inputWrapper->getTensor());
if (input_index == -1)
{
// Copy To New inputWrapper
Mat tmp = inputWrapper->getMat();
inputWrapper = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(tmp));
}
}
if (!inputWrapper->isTensor() || input_index == -1)
{
Ptr<tim::vx::Quantization> tvInputQuant = Ptr<tim::vx::Quantization>(
new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, scale, zeropoint));
inputWrapper->createTensor(graph,tim::vx::TensorAttribute::INPUT,tvInputQuant);
input_index = tvGraph->addWrapper(inputWrapper);
}
inputsIndex.push_back(input_index);
//Output
// The output tensor has the same quantization attributes as the input tensor.
Ptr<tim::vx::Quantization> outputQuant = inputWrapper->getTensorQuantization();
if (isLast)
{
auto shapeType = getShapeTypeFromMat(outputWrapper->getMat());
// For Graph Output tensor, we need to set tensor shape before createTensor().
outputWrapper->setTensorShape(shapeType);
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, outputQuant);
}
else
{
outputWrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, outputQuant);
}
output_index = tvGraph->addWrapper(outputWrapper);
outputsIndex.push_back(output_index);
// generate output shape.
MatShape outputShape = shape(outputWrapper->getMat());
// reverse shape, from NCHW to WHCN
std::reverse(outputShape.begin(), outputShape.end());
std::vector<uint32_t> tvShape(outputShape.begin(), outputShape.end());
std::shared_ptr<tim::vx::Operation> tvReshape = graph->CreateOperation<tim::vx::ops::Reshape>(tvShape);
Ptr<TimVXBackendNode> tvBackendNode = new TimVXBackendNode(tvGraph, tvReshape, inputsIndex, outputsIndex);
return tvBackendNode;
#endif // HAVE_TIMVX
return Ptr<BackendNode>();
}
virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
{
@ -360,6 +469,8 @@ private:
std::vector<int> inputIndices; // Which axes from input are needed to compute correct output shape
bool hasDynamicShapes;
bool shapesInitialized;
float scale;
int zeropoint;
};
Ptr<ReshapeLayer> ReshapeLayer::create(const LayerParams& params)

View File

@ -390,6 +390,8 @@ public:
}
params.blobs.clear();
params.blobs.push_back(lookUpTable);
params.set("input_scale", inpScale);
params.set("input_zeropoint", zeropoints[0][0]);
return true;
}

View File

@ -12,6 +12,7 @@
#include "op_vkcom.hpp"
#include "op_cuda.hpp"
#include "op_webnn.hpp"
#include "op_timvx.hpp"
namespace cv {
namespace dnn {
@ -110,6 +111,13 @@ Ptr<BackendWrapper> wrapMat(int backendId, int targetId, cv::Mat& m)
CV_Assert(IS_DNN_CUDA_TARGET(targetId));
}
#endif
}
else if (backendId == DNN_BACKEND_TIMVX)
{
CV_Assert(haveTimVX());
#ifdef HAVE_TIMVX
return Ptr<BackendWrapper>(new TimVXBackendWrapper(m));
#endif // HAVE_TIMVX
}
else
CV_Error(Error::StsNotImplemented, "Unknown backend identifier");

View File

@ -133,6 +133,9 @@ void Net::Impl::setUpNet(const std::vector<LayerPin>& blobsToKeep_)
preferableTarget == DNN_TARGET_VULKAN);
CV_Assert(preferableBackend != DNN_BACKEND_CUDA ||
IS_DNN_CUDA_TARGET(preferableTarget));
CV_Assert(preferableBackend != DNN_BACKEND_TIMVX ||
preferableTarget == DNN_TARGET_NPU);
if (!netWasAllocated || this->blobsToKeep != blobsToKeep_)
{
if (preferableBackend == DNN_BACKEND_OPENCV && IS_DNN_OPENCL_TARGET(preferableTarget))
@ -179,6 +182,12 @@ void Net::Impl::setUpNet(const std::vector<LayerPin>& blobsToKeep_)
preferableTarget = DNN_TARGET_CPU;
}
if (preferableBackend == DNN_BACKEND_TIMVX && !haveTimVX())
{
preferableBackend = DNN_BACKEND_OPENCV;
preferableTarget = DNN_TARGET_CPU;
}
clear();
if (hasDynamicShapes)
@ -515,7 +524,7 @@ void Net::Impl::allocateLayer(int lid, const LayersShapesMap& layersShapes)
ld.outputBlobsWrappers[i] = wrap(ld.outputBlobs[i]);
/* CUDA backend has its own system for internal blobs; we don't need these */
ld.internalBlobsWrappers.resize((preferableBackend == DNN_BACKEND_CUDA) ? 0 : ld.internals.size());
ld.internalBlobsWrappers.resize((preferableBackend == DNN_BACKEND_CUDA || preferableBackend == DNN_BACKEND_TIMVX) ? 0 : ld.internals.size());
for (int i = 0; i < ld.internalBlobsWrappers.size(); ++i)
ld.internalBlobsWrappers[i] = wrap(ld.internals[i]);
@ -814,6 +823,10 @@ void Net::Impl::forwardLayer(LayerData& ld)
{
forwardWebnn(ld.outputBlobsWrappers, node, isAsync);
}
else if (preferableBackend == DNN_BACKEND_TIMVX)
{
forwardTimVX(ld.outputBlobsWrappers, node);
}
#ifdef HAVE_VULKAN
else if (preferableBackend == DNN_BACKEND_VKCOM)
{
@ -1568,7 +1581,7 @@ string Net::Impl::dump(bool forceAllocation) const
prevNode = itBackend->second;
}
}
std::vector<string> colors = { "#ffffb3", "#fccde5", "#8dd3c7", "#bebada", "#80b1d3", "#fdb462", "#ff4848", "#b35151", "#b266ff" };
std::vector<string> colors = { "#ffffb3", "#fccde5", "#8dd3c7", "#bebada", "#80b1d3", "#fdb462", "#ff4848", "#b35151", "#b266ff", "#b266ff", "#3cb371"};
string backend;
switch (prefBackend)
{
@ -1580,9 +1593,8 @@ string Net::Impl::dump(bool forceAllocation) const
case DNN_BACKEND_OPENCV: backend = "OCV/"; break;
case DNN_BACKEND_VKCOM: backend = "VULKAN/"; break;
case DNN_BACKEND_CUDA: backend = "CUDA/"; break;
case DNN_BACKEND_WEBNN:
backend = "WEBNN/";
break;
case DNN_BACKEND_WEBNN: backend = "WEBNN/"; break;
case DNN_BACKEND_TIMVX: backend = "TIMVX/"; break;
// don't use default:
}
out << "digraph G {\n";
@ -1767,6 +1779,10 @@ string Net::Impl::dump(bool forceAllocation) const
out << "CUDA_FP16";
colorId = 6;
break;
case DNN_TARGET_NPU:
out << "NPU";
colorId = 9;
break;
// don't use default:
}
CV_Assert(colorId < colors.size());

View File

@ -11,6 +11,7 @@
#include "op_vkcom.hpp"
#include "op_cuda.hpp"
#include "op_webnn.hpp"
#include "op_timvx.hpp"
#include <opencv2/dnn/shape_utils.hpp>
#include <opencv2/imgproc.hpp>
@ -152,6 +153,14 @@ struct Net::Impl : public detail::NetImplBase
void initVkComBackend();
#endif
#ifdef HAVE_TIMVX
// timVxInfo holds the list of TimVX graphs (tvGraphList) for this network.
TimVXInfo timVxInfo = TimVXInfo();
void tvUpdateConflictMap(int graphIndex, LayerData& ld, std::vector<std::vector<int> >& graphConflictMap);
void tvConvertToOutputNode(const LayerData& ld, Ptr<TimVXBackendWrapper>& targetWrap);
void initTimVXBackend();
#endif
#ifdef HAVE_CUDA
struct CudaInfo_t
{

View File

@ -74,6 +74,12 @@ Ptr<BackendWrapper> Net::Impl::wrap(Mat& host)
default:
CV_Assert(IS_DNN_CUDA_TARGET(preferableTarget));
}
#endif
}
else if (preferableBackend == DNN_BACKEND_TIMVX)
{
#ifdef HAVE_TIMVX
return Ptr<BackendWrapper>(new TimVXBackendWrapper(baseBuffer, host));
#endif
}
else
@ -131,6 +137,14 @@ void Net::Impl::initBackend(const std::vector<LayerPin>& blobsToKeep_)
initCUDABackend(blobsToKeep_);
#else
CV_Error(Error::StsNotImplemented, "This OpenCV version is built without support of CUDA/CUDNN");
#endif
}
else if (preferableBackend == DNN_BACKEND_TIMVX)
{
#ifdef HAVE_TIMVX
initTimVXBackend();
#else
CV_Error(Error::StsNotImplemented, "This OpenCV version is built without support of TimVX");
#endif
}
else
@ -145,9 +159,9 @@ void Net::Impl::setPreferableBackend(int backendId)
if (backendId == DNN_BACKEND_DEFAULT)
backendId = (Backend)getParam_DNN_BACKEND_DEFAULT();
if (netWasQuantized && backendId != DNN_BACKEND_OPENCV)
if (netWasQuantized && backendId != DNN_BACKEND_OPENCV && backendId != DNN_BACKEND_TIMVX)
{
CV_LOG_WARNING(NULL, "DNN: Only default backend supports quantized networks");
CV_LOG_WARNING(NULL, "DNN: Only default and TIMVX backends support quantized networks");
backendId = DNN_BACKEND_OPENCV;
}
@ -166,9 +180,9 @@ void Net::Impl::setPreferableBackend(int backendId)
void Net::Impl::setPreferableTarget(int targetId)
{
if (netWasQuantized && targetId != DNN_TARGET_CPU &&
targetId != DNN_TARGET_OPENCL && targetId != DNN_TARGET_OPENCL_FP16)
targetId != DNN_TARGET_OPENCL && targetId != DNN_TARGET_OPENCL_FP16 && targetId != DNN_TARGET_NPU)
{
CV_LOG_WARNING(NULL, "DNN: Only CPU and OpenCL/OpenCL FP16 target is supported by quantized networks");
CV_LOG_WARNING(NULL, "DNN: Only CPU, OpenCL/OpenCL FP16 and NPU targets are supported by quantized networks");
targetId = DNN_TARGET_CPU;
}

View File

@ -38,7 +38,8 @@ void Net::Impl::fuseLayers(const std::vector<LayerPin>& blobsToKeep_)
if(!fusion || (preferableBackend != DNN_BACKEND_OPENCV &&
preferableBackend != DNN_BACKEND_CUDA &&
preferableBackend != DNN_BACKEND_INFERENCE_ENGINE_NGRAPH))
preferableBackend != DNN_BACKEND_INFERENCE_ENGINE_NGRAPH &&
preferableBackend != DNN_BACKEND_TIMVX))
return;
#if 0 // FIXIT mode without fusion is broken due to unsupported layers and handling of "custom" nodes

View File

@ -3271,6 +3271,7 @@ void ONNXImporter::parseQConv(LayerParams& layerParams, const opencv_onnx::NodeP
layerParams.type = "ConvolutionInt8";
layerParams.set("num_output", outCn);
layerParams.set("input_zeropoint", inp_zp.at<int8_t>(0));
layerParams.set("input_scale",inp_sc.at<float>(0));
layerParams.blobs.push_back(weights);
layerParams.blobs.push_back(biasFused);
layerParams.blobs.push_back(outputMultiplier);
@ -3310,6 +3311,9 @@ void ONNXImporter::parseQMatMul(LayerParams& layerParams, const opencv_onnx::Nod
layerParams.type = "InnerProductInt8";
layerParams.set("num_output", outCn);
layerParams.set("axis", firstInpDims - secondInpDims + 1);
layerParams.set("input_scale", inp_sc.at<float>(0));
layerParams.set("input_zeropoint", inp_zp.at<int8_t>(0));
layerParams.blobs.push_back(weights);
layerParams.blobs.push_back(bias);
layerParams.blobs.push_back(outputMultiplier);
@ -3380,6 +3384,7 @@ void ONNXImporter::parseQEltwise(LayerParams& layerParams, const opencv_onnx::No
rescaleParams.set("depth", CV_8S);
rescaleParams.set("scale", scale);
rescaleParams.set("shift", shift);
rescaleParams.set("isEltwise", true);
addLayer(rescaleParams, node_proto);
return;
}
@ -3428,7 +3433,6 @@ void ONNXImporter::parseQEltwise(LayerParams& layerParams, const opencv_onnx::No
Mat blob_dequantized;
blob.convertTo(blob_dequantized, CV_32F, inp_scales[1], -(inp_scales[1] * inp_zps[1]));
layerParams.blobs.push_back(blob_dequantized);
layerParams.set("input_scales", DictValue::arrayReal(inp_scales.data(), inp_scales.size()));
}
}
}
@ -3443,9 +3447,9 @@ void ONNXImporter::parseQEltwise(LayerParams& layerParams, const opencv_onnx::No
{
layerParams.type = "ScaleInt8";
layerParams.set("bias_term", op == "sum");
layerParams.set("input_scales", DictValue::arrayReal(inp_scales.data(), inp_scales.size()));
}
layerParams.set("input_scales", DictValue::arrayReal(inp_scales.data(), inp_scales.size()));
layerParams.set("input_zeropoints", DictValue::arrayInt(inp_zps.data(), inp_zps.size()));
addLayer(layerParams, node_proto);
}
@ -3471,6 +3475,9 @@ void ONNXImporter::parseQLeakyRelu(LayerParams& layerParams, const opencv_onnx::
}
layerParams.type = "ReLUInt8";
layerParams.set("input_scale", inp_sc);
layerParams.set("input_zeropoint", inp_zp);
layerParams.set("slope", slope);
layerParams.blobs.push_back(lookUpTable);
addLayer(layerParams, node_proto);
}
@ -3495,6 +3502,8 @@ void ONNXImporter::parseQSigmoid(LayerParams& layerParams, const opencv_onnx::No
}
layerParams.type = "SigmoidInt8";
layerParams.set("input_scale", inp_sc);
layerParams.set("input_zeropoint", inp_zp);
layerParams.blobs.push_back(lookUpTable);
addLayer(layerParams, node_proto);
}
@ -3548,6 +3557,7 @@ void ONNXImporter::parseQConcat(LayerParams& layerParams, const opencv_onnx::Nod
rescaleParams.set("depth", CV_8S);
rescaleParams.set("scale", scale);
rescaleParams.set("shift", shift);
rescaleParams.set("isEltwise", false);
opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(i));

View File

@ -0,0 +1,931 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2019-2021, Shenzhen Institute of Artificial Intelligence and
// Robotics for Society, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp>
#include "op_timvx.hpp"
#include "net_impl.hpp"
namespace cv
{
namespace dnn
{
#ifdef HAVE_TIMVX
CV__DNN_INLINE_NS_BEGIN
// Recursively update the conflict map for all consumers of this layer.
void Net::Impl::tvUpdateConflictMap(int graphIndex, LayerData& ld, std::vector<std::vector<int> >& graphConflictMap)
{
if (ld.consumers.empty())
return;
for (int i = 0; i < ld.consumers.size(); i++)
{
LayerData &consumerld = layers[ld.consumers[i].lid];
std::vector<int>::iterator it = std::find(graphConflictMap[ld.consumers[i].lid].begin(),
graphConflictMap[ld.consumers[i].lid].end(), graphIndex);
if (it == graphConflictMap[ld.consumers[i].lid].end())
{
graphConflictMap[ld.consumers[i].lid].push_back(graphIndex);
tvUpdateConflictMap(graphIndex, consumerld, graphConflictMap);
}
else
continue;
}
}
// Convert TRANSIENT to OUTPUT
void Net::Impl::tvConvertToOutputNode(const LayerData& ld, Ptr<TimVXBackendWrapper>& targetWrap)
{
// find right layer.
for (auto& inputLayerId : ld.inputLayersId)
{
LayerData &inputld = layers[inputLayerId];
auto itWrap = std::find(inputld.outputBlobsWrappers.begin(),
inputld.outputBlobsWrappers.end(), targetWrap);
if (itWrap != inputld.outputBlobsWrappers.end())
{
auto outputWrap = (*itWrap).dynamicCast<TimVXBackendWrapper>();
if (!outputWrap->isTensor())
continue;
auto inputNode = inputld.backendNodes[DNN_BACKEND_TIMVX].dynamicCast<TimVXBackendNode>();
if (!inputNode->isLast && inputNode->opIndex != -1)
{
CV_Assert(outputWrap->getTensorAttr() == tim::vx::TRANSIENT);
// set last
inputNode->isLast = true;
auto shapeType = getShapeTypeFromMat(outputWrap->getMat());
auto outQuant = outputWrap->getTensorQuantization();
outputWrap->setTensorShape(shapeType);
outputWrap->createTensor(inputNode->tvGraph->graph,
tim::vx::TensorAttribute::OUTPUT, outQuant);
int outIndex = inputNode->tvGraph->addWrapper(outputWrap);
inputNode->outputIndexList.clear();
inputNode->outputIndexList.push_back(outIndex);
}
}
}
}
void Net::Impl::initTimVXBackend()
{
CV_TRACE_FUNCTION();
CV_Assert(preferableBackend == DNN_BACKEND_TIMVX);
// Build TimVX graphs from the sets of layers that support the TimVX backend.
// Split the model into several TimVX graphs if some layers are not implemented by the TimVX backend.
if (!haveTimVX())
return;
// Allocate graphConflictMap
if (timVxInfo.graphConflictMap.empty())
timVxInfo.graphConflictMap.resize(layers.size());
auto it = layers.begin();
bool isLast = false; // Whether the node is the last node in the current tvGraph.
for (; it != layers.end(); it++)
{
isLast = false;
LayerData &ld = it->second;
if(ld.skip)
continue;
Ptr<Layer> layer = ld.layerInstance;
if (!layer->supportBackend(preferableBackend))
{
continue;
}
// If the layer has more than one consumer, set isLast to true.
// For now, the TimVX backend splits multiple branches into multiple tvGraphs.
if (ld.consumers.size() == 0)
{
isLast = true;
}
else if(ld.consumers.size() == 1)
{
LayerData* consumerld = &layers[ld.consumers[0].lid];
while (consumerld)
{
if (consumerld->skip)
{
if (consumerld->consumers.size() == 1)
{
int nextLayerId = consumerld->consumers[0].lid;
consumerld = &layers[nextLayerId];
}
else
{
isLast = true;
break;
}
}
else
{
break;
}
}
Ptr<Layer>& consumerLayer = consumerld->layerInstance;
if (!isLast && !consumerLayer->supportBackend(preferableBackend))
{
isLast = true;
}
}
else
{
// Multiple consumers: continue the current graph only if exactly one of them supports TimVX.
int tvSupportNum = 0;
for (int i = 0; i<ld.consumers.size(); i++)
{
LayerData* consumerld = &layers[ld.consumers[i].lid];
while (consumerld)
{
if (consumerld->skip)
{
if (consumerld->consumers.size() == 1)
{
int nextLayerId = consumerld->consumers[0].lid;
consumerld = &layers[nextLayerId];
}
else
{
isLast = true;
break;
}
}
else
{
break;
}
}
Ptr<Layer>& consumerLayer = consumerld->layerInstance;
if (consumerLayer->supportBackend(preferableBackend))
{
tvSupportNum++;
}
}
if (tvSupportNum != 1)
isLast = true;
}
int graphIndex = -1;
bool needRecorrect = !timVxInfo.findGraphIndex(ld.inputBlobsWrappers, graphIndex);
if (graphIndex != -1 && !needRecorrect)
{
needRecorrect = timVxInfo.isConflict(ld.id, graphIndex);
}
// Re-correct the input layers if needed.
if (needRecorrect)
{
// Mark all input layers as last layers, and convert TRANSIENT tensors to OUTPUT.
for (int i = 0; i < ld.inputBlobsWrappers.size(); i++)
{
auto inputWrap = ld.inputBlobsWrappers[i];
auto tvInputWrap = inputWrap.dynamicCast<TimVXBackendWrapper>();
if (!tvInputWrap->isTensor())
continue;
auto attr = tvInputWrap->getTensorAttr();
if (attr == tim::vx::TensorAttribute::OUTPUT)
{
continue;
}
else if (attr == tim::vx::TensorAttribute::INPUT)
{
Mat matTmp = tvInputWrap->getMat();
tvInputWrap = Ptr<TimVXBackendWrapper>(new TimVXBackendWrapper(matTmp));
}
else if (attr == tim::vx::TensorAttribute::TRANSIENT)
{
tvConvertToOutputNode(ld, tvInputWrap);
// update the conflict map
tvUpdateConflictMap(graphIndex, ld, timVxInfo.graphConflictMap);
}
}
graphIndex = -1;
}
if (graphIndex == -1)
{
graphIndex = timVxInfo.createGraph();
}
timVxInfo.setTmpGraphIndex(graphIndex);
ld.backendNodes[DNN_BACKEND_TIMVX] =
layer->initTimVX(&timVxInfo, ld.inputBlobsWrappers, ld.outputBlobsWrappers, isLast);
// Post-process: mark the last node of the graph correctly.
if (isLast && ld.backendNodes[DNN_BACKEND_TIMVX])
{
auto tmpNode = ld.backendNodes[DNN_BACKEND_TIMVX].dynamicCast<TimVXBackendNode>();
tmpNode->isLast = true;
// update graphConflictMap
tvUpdateConflictMap(graphIndex, ld, timVxInfo.graphConflictMap);
}
// Post-process when the TimVX node could not be created.
if (!ld.backendNodes[DNN_BACKEND_TIMVX])
{
for (int i = 0; i < ld.inputBlobsWrappers.size(); i++)
{
auto inputWrap = ld.inputBlobsWrappers[i];
auto tvInputWrap = inputWrap.dynamicCast<TimVXBackendWrapper>();
if (!tvInputWrap->isTensor())
continue;
auto attr = tvInputWrap->getTensorAttr();
if (attr == tim::vx::TensorAttribute::TRANSIENT)
{
tvConvertToOutputNode(ld, tvInputWrap);
}
}
}
}
// Op Binding
it = layers.begin();
Ptr<TimVXBackendNode> node;
for (; it != layers.end(); it++)
{
LayerData &ld = it->second;
if (ld.backendNodes[DNN_BACKEND_TIMVX])
node = ld.backendNodes[DNN_BACKEND_TIMVX].dynamicCast<TimVXBackendNode>();
else
continue;
// Binding tvTensor and tvOp
if (node->opIndex >= 0)
node->opBinding();
}
}
CV__DNN_INLINE_NS_END
// from CPU to NPU
bool copyToTensor(std::shared_ptr<tim::vx::Tensor> &dst, const Mat &src)
{
CV_Assert(src.isContinuous() && (src.type() == CV_8S || src.type() == CV_32F));
return dst->CopyDataToTensor(src.data, src.total());
}
// from NPU to CPU
bool copyToMat(const Mat &dst, std::shared_ptr<tim::vx::Tensor> &src)
{
CV_Assert(dst.isContinuous() && (dst.type() == CV_8S || dst.type() == CV_32F));
return src->CopyDataFromTensor(dst.data);
}
tvActivationType getTimVXActType(String & actString)
{
if (actString == "ReLUInt8") return tvActReLU;
if (actString == "ReLU6Int8") return tvActReLU6;
if (actString == "TanHInt8") return tvActTanH;
if (actString == "SwishInt8") return tvActSwish;
if (actString == "MishInt8") return tvActMish;
if (actString == "SigmoidInt8") return tvActSigmoid;
if (actString == "ELUInt8") return tvActELU;
return tvActNotSupported;
}
tim::vx::ShapeType getShapeTypeFromMat(const Mat& mat, bool ifConst)
{
/* Convert Mat shape to TimVX Tensor shape.
DataLayout in TimVX is WHCN, while NCHW in OpenCV.
So we do vector reverse.
*/
CV_Assert(!mat.empty());
tim::vx::ShapeType tvInputShape;
auto matShape = shape(mat);
tvInputShape.assign(matShape.begin(), matShape.end());
if (matShape.size() > 1) // TODO: check when we need to reverse the shape vector.
{
if (ifConst && tvInputShape.size() == 2 && tvInputShape[1] == 1)
{ // if bias vector, shape [n, 1] to [n].
tvInputShape.resize(1);
}
else
std::reverse(tvInputShape.begin(), tvInputShape.end());
}
return tvInputShape;
}
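A small illustration of the shape rule above; the shapes are hypothetical examples:

// {1, 3, 224, 224} (NCHW Mat)          -> {224, 224, 3, 1} (WHCN tensor)
// {256, 1} with ifConst == true        -> {256}            (bias vector case)
// {5} (1-D Mat)                        -> {5}              (no reverse needed)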
bool getQuantType(const std::vector<float>& scales, int numOutput)
{
CV_Assert(!scales.empty());
if (numOutput == -1)
{
numOutput = scales.size();
}
bool tvSymmetric = false;
for (int i =1; i < numOutput; i++)
{
if (std::abs(scales[0] - scales[i]) > std::numeric_limits<float>::epsilon())
{
tvSymmetric = true;
break;
}
}
return tvSymmetric;
}
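A hypothetical illustration of the rule above: identical scales keep per-tensor quantization, while any differing scale selects the per-channel path (tvSymmetric in the code):

// getQuantType({0.5f, 0.5f, 0.5f}, 3)   -> false (one scale fits all output channels)
// getQuantType({0.5f, 0.25f, 0.5f}, 3)  -> true  (channel 1 differs, per-channel needed)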
// convert mat Depth to tensorDataType
tim::vx::DataType dataTypeConvert(int matDepth)
{
tim::vx::DataType tensorDataType;
switch(matDepth)
{
case CV_8U:
{
tensorDataType = tim::vx::DataType::UINT8;
break;
}
case CV_8S:
{
tensorDataType = tim::vx::DataType::INT8;
break;
}
case CV_16U:
{
tensorDataType = tim::vx::DataType::UINT16;
break;
}
case CV_16S:
{
tensorDataType = tim::vx::DataType::INT16;
break;
}
case CV_32S:
{
tensorDataType = tim::vx::DataType::INT32;
break;
}
case CV_32F:
{
tensorDataType = tim::vx::DataType::FLOAT32;
break;
}
case CV_16F:
{
tensorDataType = tim::vx::DataType::FLOAT16;
break;
}
default:
{
tensorDataType = tim::vx::DataType::UNKNOWN;
break;
}
}
return tensorDataType;
}
std::vector<Ptr<TimVXBackendWrapper> > getWrappers(const std::vector<int>& wrappersIndex,
Ptr<TimVXGraph> tvGraph)
{
std::vector<Ptr<TimVXBackendWrapper> > wrappers;
for (int i = 0; i<wrappersIndex.size(); i++)
{
auto wrapper = tvGraph->getWrapper(wrappersIndex[i]);
if (wrapper)
wrappers.push_back(wrapper);
}
return wrappers;
}
// *********************** TimVXGraph ********************
TimVXGraph::TimVXGraph()
{
// new TimVX Graph
context = tim::vx::Context::Create();
graph = context->CreateGraph();
isCompiled = false;
}
TimVXGraph::~TimVXGraph()
{
// release tensorList
for (auto& tensor: tensorList)
tensor.reset();
// release opList
for (auto& op: opList)
op.reset();
// release graph
graph.reset();
// release context
context.reset();
}
std::shared_ptr<tim::vx::Operation> TimVXGraph::getOp(const int opIndex)
{
CV_Assert(0 <= opIndex && !opList.empty() && opIndex < opList.size());
return opList[opIndex];
}
int TimVXGraph::addWrapper(Ptr<TimVXBackendWrapper>& tensorWrapper)
{
CV_Assert(tensorWrapper->isTensor());
tim::vx::TensorAttribute tensorAttr = tensorWrapper->getTensorAttr();
wrapperList.push_back(tensorWrapper);
tensorList.push_back(tensorWrapper->getTensor());
int wrapperIndex = wrapperList.size() -1;
if (tensorAttr == tim::vx::TensorAttribute::INPUT)
{
inputWrappersIndex.push_back(wrapperIndex);
}
if (tensorAttr == tim::vx::TensorAttribute::OUTPUT)
{
outputWrappersIndex.push_back(wrapperIndex);
}
return wrapperIndex;
}
Ptr<TimVXBackendWrapper> TimVXGraph::getWrapper(int wrapperIndex)
{
CV_Assert(wrapperIndex>=0 && wrapperIndex < wrapperList.size());
return wrapperList[wrapperIndex];
}
int TimVXGraph::addOp(const std::shared_ptr<tim::vx::Operation>& op)
{
CV_Assert(op);
opList.emplace_back(op);
return opList.size()-1;
}
int TimVXGraph::getTensorIndex(const std::shared_ptr<tim::vx::Tensor>& tensor)
{
auto it = find(tensorList.begin(), tensorList.end(), tensor);
if (it != tensorList.end())
return it - tensorList.begin();
else
return -1;
}
void TimVXGraph::forward()
{
CV_Assert(!inputWrappersIndex.empty() && !outputWrappersIndex.empty());
// Every TimVXGraph instance is compiled only once.
if (!this->isCompiled)
{
if (!graph->Compile())
CV_Error(cv::Error::StsBadArg, "Failed to compile TimVX graph!");
this->isCompiled = true;
}
if (!graph->Run())
CV_Error(cv::Error::StsBadArg, "Failed to run TimVX graph!");
}
// *********************** TimVXBackendNode ********************
TimVXBackendNode::TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph_): BackendNode(DNN_BACKEND_TIMVX)
{
opIndex = -1;
tvGraph = tvGraph_;
isLast = false;
}
TimVXBackendNode::TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph_,
const std::shared_ptr<tim::vx::Operation>& op_): BackendNode(DNN_BACKEND_TIMVX)
{
tvGraph = tvGraph_;
opIndex = tvGraph->addOp(op_);
isLast = false;
}
TimVXBackendNode::TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph_, std::shared_ptr<tim::vx::Operation>& op_,
std::vector<int>& inputsIndex, std::vector<int>& outputsIndex)
:BackendNode(DNN_BACKEND_TIMVX)
{
tvGraph = tvGraph_;
opIndex = tvGraph->addOp(op_);
isLast = false;
if (!inputsIndex.empty())
inputIndexList.assign(inputsIndex.begin(), inputsIndex.end());
if (!outputsIndex.empty())
outputIndexList.assign(outputsIndex.begin(), outputsIndex.end());
}
bool TimVXBackendNode::opBinding()
{
if (!tvGraph || tvGraph->isCompiled || opIndex == -1)
return false;
std::shared_ptr<tim::vx::Operation> op = tvGraph->getOp(opIndex);
if (!inputIndexList.empty())
{
std::vector<Ptr<TimVXBackendWrapper> > inputsWrapper = getWrappers(inputIndexList, tvGraph);
// Binding input Tensor.
for (auto& wrapper: inputsWrapper)
{
op->BindInput(wrapper->getTensor());
}
}
if (!outputIndexList.empty())
{
std::vector<Ptr<TimVXBackendWrapper> > outputsWrapper = getWrappers(outputIndexList, tvGraph);
for (auto& wrapper: outputsWrapper)
{
op->BindOutput(wrapper->getTensor());
}
}
return true;
}
void TimVXBackendNode::setInputTensor()
{
if (!tvGraph || opIndex == -1)
return;
if (!inputIndexList.empty())
{
std::vector<Ptr<TimVXBackendWrapper> > inputsWrapper = getWrappers(inputIndexList, tvGraph);
// Binding input Tensor.
for (auto& wrapper: inputsWrapper)
{
if (wrapper->getTensorAttr() == tim::vx::TensorAttribute::INPUT)
{
wrapper->setHostDirty();
wrapper->copyToDevice();
}
}
}
}
// *********************** TimVXBackendWrapper ********************
// Default Constructor
TimVXBackendWrapper::TimVXBackendWrapper() : BackendWrapper(DNN_BACKEND_TIMVX, DNN_TARGET_NPU)
{
isTensor_ = false;
deviceDirty = false;
hostDirty = false;
tensorType = tim::vx::DataType::UNKNOWN;
tensorShape = {};
tensorIndex = -1;
tensorAttr = tim::vx::TensorAttribute::CONSTANT;
}
TimVXBackendWrapper::TimVXBackendWrapper(Mat& m) : BackendWrapper(DNN_BACKEND_TIMVX,
DNN_TARGET_NPU)
{
host = m;
isTensor_ = false;
deviceDirty = false;
hostDirty = true;
tensorType = dataTypeConvert(m.type());
tensorShape = {};
tensorIndex = -1;
tensorAttr = tim::vx::TensorAttribute::CONSTANT;
// TODO: data types unsupported by TimVX should be converted first.
CV_Assert(tensorType != tim::vx::DataType::UNKNOWN);
}
TimVXBackendWrapper::TimVXBackendWrapper(const Ptr<BackendWrapper>& baseBuffer, Mat& m)
:BackendWrapper(DNN_BACKEND_TIMVX, DNN_TARGET_NPU)
{
Ptr<TimVXBackendWrapper> base = baseBuffer.dynamicCast<TimVXBackendWrapper>();
CV_Assert(!base.empty());
tensor = base->tensor;
isTensor_ = base->isTensor_;
tensorIndex = base->tensorIndex;
tensorType = base->tensorType;
tensorAttr = base->tensorAttr;
tensorShape = base->tensorShape;
deviceDirty = base->deviceDirty;
hostDirty = base->hostDirty;
host = m;
}
TimVXBackendWrapper::TimVXBackendWrapper(std::shared_ptr<tim::vx::Tensor>& tensor_)
:BackendWrapper(DNN_BACKEND_TIMVX, DNN_TARGET_NPU)
{
tensor = tensor_;
isTensor_ = true;
deviceDirty = true;
hostDirty = false;
tensorType = tensor_->GetDataType(); // getTensor DataType.
tensorAttr = tensor_->GetSpec().attr_; // getTensor Attribution.
tensorShape = tensor_->GetShape();
tensorIndex = -1;
}
void TimVXBackendWrapper::setTensorShape(const tim::vx::ShapeType & matShape)
{
CV_Assert(!matShape.empty());
tensorShape.assign(matShape.begin(), matShape.end());
}
int TimVXBackendWrapper::getTensorIndex()
{
CV_Assert(isTensor_);
return tensorIndex;
}
tim::vx::TensorAttribute TimVXBackendWrapper::getTensorAttr()
{
CV_Assert(isTensor_);
return tensorAttr;
}
// Create tensor
void TimVXBackendWrapper::createTensor(std::shared_ptr<tim::vx::Graph>& graph,
tim::vx::TensorAttribute tensorAttribute)
{
Ptr<tim::vx::Quantization> emptyQuant = nullptr;
return this->createTensor(graph, tensorAttribute, emptyQuant);
}
// Create tensor
void TimVXBackendWrapper::createTensor(std::shared_ptr<tim::vx::Graph>& graph,
tim::vx::TensorAttribute tensorAttribute, Ptr<tim::vx::Quantization>& tvQuant)
{
CV_Assert(graph);
tim::vx::TensorSpec tensorSpec;
if (tensorAttribute == tim::vx::INPUT)
{
CV_Assert(!host.empty());
tensorShape = getShapeTypeFromMat(host);
}
else if (tensorAttribute == tim::vx::OUTPUT)
{
CV_Assert(!tensorShape.empty() && !host.empty());
tensorShape = getShapeTypeFromMat(host);
}
else if (tensorAttribute == tim::vx::CONSTANT)
{
if (!host.empty())
tensorShape = getShapeTypeFromMat(host, true);
}
else
{
if (!host.empty())
tensorShape = getShapeTypeFromMat(host);
}
// Build the tensor spec, with quantization parameters if provided.
if (tvQuant)
{
tensorSpec = tim::vx::TensorSpec(tensorType, tensorShape, tensorAttribute, *tvQuant);
}
else
{
tensorSpec = tim::vx::TensorSpec(tensorType, tensorShape, tensorAttribute);
}
if (!host.empty() && tensorAttribute != tim::vx::INPUT && tensorAttribute != tim::vx::OUTPUT && tensorAttribute != tim::vx::TRANSIENT)
{
tensor = graph->CreateTensor(tensorSpec, (void *)(host.data));
}
else
{
tensor = graph->CreateTensor(tensorSpec);
}
isTensor_ = true;
// set the tensor attribute
tensorAttr = tensorAttribute;
}
Ptr<tim::vx::Quantization> TimVXBackendWrapper::getTensorQuantization()
{
CV_Assert(isTensor_ && tensor);
auto quantize = tensor->GetQuantization();
return makePtr<tim::vx::Quantization>(quantize);
}
std::shared_ptr<tim::vx::Tensor> TimVXBackendWrapper::getTensor()
{
CV_Assert(isTensor_);
return tensor;
}
Mat TimVXBackendWrapper::getMat()
{
if (host.empty())
return {};
return host;
}
bool TimVXBackendWrapper::isTensor()
{
return isTensor_;
}
void TimVXBackendWrapper::copyToHost()
{
if (deviceDirty && !host.empty())
{
copyToMat(host, tensor);
deviceDirty = false;
}
}
void TimVXBackendWrapper::setHostDirty()
{
hostDirty = true;
}
void TimVXBackendWrapper::setDeviceDirty()
{
deviceDirty = true;
}
void TimVXBackendWrapper::copyToDevice()
{
if (isTensor_ && hostDirty && !host.empty())
{
copyToTensor(tensor, host);
hostDirty = false;
}
}
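Taken together, the flags and copy methods above implement a simple dirty-flag synchronization protocol between the host Mat and the NPU tensor. A hedged sketch of the intended call order; wrapper, graph, and quant stand in for objects created elsewhere in this file:

// TimVXBackendWrapper wrapper(mat);    // wraps a CPU Mat, hostDirty is set
// wrapper.createTensor(graph, tim::vx::TensorAttribute::INPUT, quant);
// wrapper.setHostDirty();
// wrapper.copyToDevice();              // Mat -> tensor, clears hostDirty
// ... run the graph ...
// wrapper.setDeviceDirty();
// wrapper.copyToHost();                // tensor -> Mat, clears deviceDirty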
// *********************** TimVXInfo ********************
TimVXInfo::TimVXInfo()
{
graphIndex = -1;
}
TimVXInfo::~TimVXInfo()
{}
int TimVXInfo::createGraph()
{
Ptr<TimVXGraph> tmpGraph = Ptr<TimVXGraph>(new TimVXGraph());
this->tvGraphList.push_back(tmpGraph);
return this->tvGraphList.size() - 1;
}
bool TimVXInfo::findGraphIndex(const std::vector<Ptr<BackendWrapper> > &inputsWrapper, int& graphIndex)
{
graphIndex = -1;
int wrapperSize = inputsWrapper.size();
int graphSize = tvGraphList.size();
if (wrapperSize != 0 && graphSize == 0)
{
return true;
}
int tensorIndex = -1;
Ptr<TimVXBackendWrapper> wrapper;
Ptr<TimVXGraph> tvGraph;
for (int i = 0; i < graphSize; i++)
{
tvGraph = tvGraphList[i];
for (int j = 0; j < wrapperSize; j++ )
{
wrapper = inputsWrapper[j].dynamicCast<TimVXBackendWrapper>();
if (!wrapper->isTensor()) // Skip wrapper without Tensor.
continue;
tensorIndex = tvGraph->getTensorIndex(wrapper->getTensor());
if (tensorIndex != -1 && wrapper->getTensorAttr() == tim::vx::TensorAttribute::TRANSIENT)
{
if (graphIndex == -1)
graphIndex = i;
else if (graphIndex != i) // inputs of the same layer come from different tvGraphs.
{
graphIndex = -1;
return false;
}
}
}
}
return true;
}
void TimVXInfo::setTmpGraphIndex(int graphIndex)
{
this->graphIndex = graphIndex;
}
int TimVXInfo::getTmpGraphIndex()
{
int res = -1;
if (graphIndex != -1)
{
res = graphIndex;
graphIndex = -1;
}
return res;
}
bool TimVXInfo::isConflict(int layerId, int graphIndex)
{
if (graphConflictMap[layerId].empty())
return false;
std::vector<int>::iterator it = std::find(graphConflictMap[layerId].begin(),
graphConflictMap[layerId].end(), graphIndex);
return it != graphConflictMap[layerId].end();
}
Ptr<TimVXGraph> TimVXInfo::getGraph()
{
int index = getTmpGraphIndex();
if (0 <= index && index < tvGraphList.size())
return tvGraphList[index];
else
return {};
}
#endif
void forwardTimVX(std::vector<Ptr<BackendWrapper> >& outputs, const Ptr<BackendNode>& node_)
{
#ifdef HAVE_TIMVX
CV_Assert(!node_.empty());
Ptr<TimVXBackendNode> node = node_.dynamicCast<TimVXBackendNode>();
if (node)
{
// set input
node->setInputTensor();
// graph Forward
if (node->isLast)
{
node->tvGraph->forward();
}
}
else
return;
// set output
Ptr<TimVXBackendWrapper> outWrapper;
for (int i = 0; i < outputs.size(); i++)
{
outWrapper = outputs[i].dynamicCast<TimVXBackendWrapper>();
if (outWrapper->isTensor() && outWrapper->getTensorAttr() == tim::vx::TensorAttribute::OUTPUT)
{
outWrapper->setDeviceDirty();
outWrapper->copyToHost();
}
}
#endif
}
bool haveTimVX()
{
#ifdef HAVE_TIMVX
return true;
#else
return false;
#endif
}
} // namespace dnn
} // namespace cv

View File

@ -0,0 +1,187 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2019-2021, Shenzhen Institute of Artificial Intelligence and
// Robotics for Society, all rights reserved.
// Third party copyrights are property of their respective owners.
#ifndef OPENCV_DNN_OP_TIMVX_HPP
#define OPENCV_DNN_OP_TIMVX_HPP
#include <opencv2/dnn/shape_utils.hpp>
// TimVX header files.
#ifdef HAVE_TIMVX
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/operation.h"
#include "tim/vx/ops.h"
#include "tim/vx/tensor.h"
#endif // HAVE_TIMVX
namespace cv
{
namespace dnn
{
#ifdef HAVE_TIMVX
enum tvActivationType{
tvActNotSupported = -1,
tvActReLU,
tvActReLU6,
tvActTanH,
tvActSwish,
tvActMish,
tvActSigmoid,
tvActELU
};
// Copy data between Mat and Tensor. The caller must ensure matching sizes
// and a supported element type (CV_8S or CV_32F).
bool copyToTensor(std::shared_ptr<tim::vx::Tensor> &dst, const Mat &src);
bool copyToMat(const Mat &dst, std::shared_ptr<tim::vx::Tensor> &src);
tvActivationType getTimVXActType(String & actString);
// Convert Mat shape to TimVX TensorShape
tim::vx::ShapeType getShapeTypeFromMat(const Mat& mat, bool ifConst = false);
// Returns true if the weight scales differ across output channels,
// i.e. per-channel quantization is required.
bool getQuantType(const std::vector<float>& scales, int numOutput = -1);
class TimVXInfo;
class TimVXGraph;
class TimVXBackendNode;
class TimVXBackendWrapper;
// Maintain the tvGraph and tvTensor lists. For now, every tvGraph has only one
// output node, and each node in a tvGraph has only one output too. This could be
// optimized in the future.
// TODO: support multiple output nodes per tvGraph.
class TimVXGraph
{
public:
TimVXGraph();
~TimVXGraph();
std::shared_ptr<tim::vx::Operation> getOp(const int opIndex);
// Add tensorWrapper to wrapperList and its tensor to tensorList;
// return the index of the wrapper.
int addWrapper(Ptr<TimVXBackendWrapper>& tensorWrapper);
void forward();
// Add new op to opList, and return the index.
int addOp(const std::shared_ptr<tim::vx::Operation>& op);
// If the tensor exists in tensorList, return its index; otherwise return -1.
int getTensorIndex(const std::shared_ptr<tim::vx::Tensor>& tensor);
Ptr<TimVXBackendWrapper> getWrapper(int wrapperIndex);
std::shared_ptr<tim::vx::Graph> graph;
bool isCompiled; // Every tvGraph can only be compiled once.
private:
std::shared_ptr<tim::vx::Context> context;
std::vector<int> inputWrappersIndex;
std::vector<int> outputWrappersIndex;
std::vector<Ptr<TimVXBackendWrapper> > wrapperList;
std::vector<std::shared_ptr<tim::vx::Tensor> > tensorList;
std::vector<std::shared_ptr<tim::vx::Operation> > opList;
};
class TimVXBackendNode : public BackendNode
{
public:
TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph);
TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph, const std::shared_ptr<tim::vx::Operation>& op);
TimVXBackendNode(const Ptr<TimVXGraph>& tvGraph, std::shared_ptr<tim::vx::Operation>& op,
std::vector<int>& inputsIndex, std::vector<int>& outputsIndex);
void setInputTensor();
bool opBinding();
// Flag marking whether this node is the last node, i.e. an output node, of its TimVX graph.
bool isLast;
int opIndex;
// Indices of this node's input/output tensors and wrappers within tvGraph.
std::vector<int> inputIndexList;
std::vector<int> outputIndexList;
Ptr<TimVXGraph> tvGraph;
};
class TimVXBackendWrapper : public BackendWrapper
{
public:
TimVXBackendWrapper();
TimVXBackendWrapper(Mat& m);
TimVXBackendWrapper(const Ptr<BackendWrapper>& baseBuffer, Mat& m);
TimVXBackendWrapper(std::shared_ptr<tim::vx::Tensor>& tensor);
// Create the tensor, optionally with quantization parameters.
void createTensor(std::shared_ptr<tim::vx::Graph>& graph, tim::vx::TensorAttribute tensorAttribute);
void createTensor(std::shared_ptr<tim::vx::Graph>& graph, tim::vx::TensorAttribute tensorAttribute,
Ptr<tim::vx::Quantization>& tvQuant);
std::shared_ptr<tim::vx::Tensor> getTensor();
Mat getMat();
// An OUTPUT tensor in TimVX has no host Mat; its shape must be set explicitly before createTensor().
void setTensorShape(const tim::vx::ShapeType & matShape);
int getTensorIndex();
Ptr<tim::vx::Quantization> getTensorQuantization();
tim::vx::TensorAttribute getTensorAttr();
bool isTensor();
// Data Copy, CPU <==> NPU
virtual void copyToHost() CV_OVERRIDE;
virtual void setHostDirty() CV_OVERRIDE;
void setDeviceDirty();
void copyToDevice();
private:
tim::vx::DataType tensorType;
bool deviceDirty;
bool hostDirty;
int tensorIndex; // index into the tensorList of a specific TimVXGraph.
bool isTensor_;
Mat host;
tim::vx::ShapeType tensorShape;
std::shared_ptr<tim::vx::Tensor> tensor;
tim::vx::TensorAttribute tensorAttr;
};
// Maintains all created tvGraphs; used throughout TimVX backend initialization and inference.
class TimVXInfo{
public:
TimVXInfo();
~TimVXInfo();
// Return the graph at the current temporary graphIndex; if it cannot be found, return an empty Ptr.
Ptr<TimVXGraph> getGraph();
bool findGraphIndex(const std::vector<Ptr<BackendWrapper> > &inputsWrapper, int& graphIndex);
void setTmpGraphIndex(int graphIndex);
bool isConflict(int layerId, int graphIndex);
// create a TimVXGraph, add it to tvGraphList, and return the index in tvGraphList.
int createGraph();
// graphConflictMap[layerId] stores the indices of graphs that conflict with the layer and must be excluded.
std::vector<std::vector<int> > graphConflictMap;
private:
int getTmpGraphIndex();
std::vector<Ptr<TimVXGraph> > tvGraphList;
int graphIndex;
};
#endif
void forwardTimVX(std::vector<Ptr<BackendWrapper> > &outputs, const Ptr<BackendNode>& node);
bool haveTimVX();
} // namespace dnn
} // namespace cv
#endif // OPENCV_DNN_OP_TIMVX_HPP

View File

@ -10,6 +10,7 @@
#include "op_vkcom.hpp"
#include "op_cuda.hpp"
#include "op_webnn.hpp"
#include "op_timvx.hpp"
#include "halide_scheduler.hpp"
@ -109,6 +110,13 @@ private:
backends.push_back(std::make_pair(DNN_BACKEND_CUDA, DNN_TARGET_CUDA_FP16));
}
#endif
#ifdef HAVE_TIMVX
if (haveTimVX())
{
backends.push_back(std::make_pair(DNN_BACKEND_TIMVX, DNN_TARGET_NPU));
}
#endif
}
BackendsList backends;
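Since the backend/target pair is registered here, user code can discover it at runtime through the existing getAvailableBackends() API. A small check sketch, assuming an OpenCV build with HAVE_TIMVX:

#include <opencv2/dnn.hpp>
#include <cstdio>

int main()
{
    for (const auto& bt : cv::dnn::getAvailableBackends())
        if (bt.first == cv::dnn::DNN_BACKEND_TIMVX && bt.second == cv::dnn::DNN_TARGET_NPU)
            std::puts("TIM-VX NPU backend is available");
    return 0;
}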

View File

@ -48,6 +48,7 @@
#define CV_TEST_TAG_DNN_SKIP_ONNX_CONFORMANCE "dnn_skip_onnx_conformance"
#define CV_TEST_TAG_DNN_SKIP_PARSER "dnn_skip_parser"
#define CV_TEST_TAG_DNN_SKIP_TIMVX "dnn_skip_timvx"
#ifdef HAVE_INF_ENGINE
#if INF_ENGINE_VER_MAJOR_EQ(2018050000)

View File

@ -30,6 +30,7 @@ void PrintTo(const cv::dnn::Backend& v, std::ostream* os)
case DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019: *os << "DLIE"; return;
case DNN_BACKEND_INFERENCE_ENGINE_NGRAPH: *os << "NGRAPH"; return;
case DNN_BACKEND_WEBNN: *os << "WEBNN"; return;
case DNN_BACKEND_TIMVX: *os << "TIMVX"; return;
} // don't use "default:" to emit compiler warnings
*os << "DNN_BACKEND_UNKNOWN(" << (int)v << ")";
}
@ -46,6 +47,7 @@ void PrintTo(const cv::dnn::Target& v, std::ostream* os)
case DNN_TARGET_FPGA: *os << "FPGA"; return;
case DNN_TARGET_CUDA: *os << "CUDA"; return;
case DNN_TARGET_CUDA_FP16: *os << "CUDA_FP16"; return;
case DNN_TARGET_NPU: *os << "NPU"; return;
} // don't use "default:" to emit compiler warnings
*os << "DNN_TARGET_UNKNOWN(" << (int)v << ")";
}
@ -478,6 +480,11 @@ void initDNNTests()
registerGlobalSkipTag(
CV_TEST_TAG_DNN_SKIP_CUDA, CV_TEST_TAG_DNN_SKIP_CUDA_FP32, CV_TEST_TAG_DNN_SKIP_CUDA_FP16
);
#endif
#ifdef HAVE_TIMVX
registerGlobalSkipTag(
CV_TEST_TAG_DNN_SKIP_TIMVX
);
#endif
registerGlobalSkipTag(
CV_TEST_TAG_DNN_SKIP_ONNX_CONFORMANCE,

View File

@ -12,6 +12,9 @@ testing::internal::ParamGenerator< tuple<Backend, Target> > dnnBackendsAndTarget
{
std::vector< tuple<Backend, Target> > targets;
targets.push_back(make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU));
#ifdef HAVE_TIMVX
targets.push_back(make_tuple(DNN_BACKEND_TIMVX, DNN_TARGET_NPU));
#endif
return testing::ValuesIn(targets);
}
@ -104,14 +107,29 @@ TEST_P(Test_Int8_layers, Convolution1D)
TEST_P(Test_Int8_layers, Convolution2D)
{
testLayer("layer_convolution", "Caffe", 0.0174, 0.0758, 1, 1, true);
testLayer("single_conv", "TensorFlow", 0.00413, 0.02201);
testLayer("depthwise_conv2d", "TensorFlow", 0.0388, 0.169);
if(backend == DNN_BACKEND_TIMVX)
testLayer("single_conv", "TensorFlow", 0.00424, 0.02201);
else
testLayer("single_conv", "TensorFlow", 0.00413, 0.02201);
testLayer("atrous_conv2d_valid", "TensorFlow", 0.0193, 0.0633);
testLayer("atrous_conv2d_same", "TensorFlow", 0.0185, 0.1322);
testLayer("keras_atrous_conv2d_same", "TensorFlow", 0.0056, 0.0244);
testLayer("convolution", "ONNX", 0.0052, 0.01516);
testLayer("two_convolution", "ONNX", 0.00295, 0.00840);
if(backend == DNN_BACKEND_TIMVX)
testLayer("convolution", "ONNX", 0.00534, 0.01516);
else
testLayer("convolution", "ONNX", 0.0052, 0.01516);
if(backend == DNN_BACKEND_TIMVX)
testLayer("two_convolution", "ONNX", 0.0033, 0.01);
else
testLayer("two_convolution", "ONNX", 0.00295, 0.00840);
if(backend == DNN_BACKEND_TIMVX)
applyTestTag(CV_TEST_TAG_DNN_SKIP_TIMVX);
testLayer("layer_convolution", "Caffe", 0.0174, 0.0758, 1, 1, true);
testLayer("depthwise_conv2d", "TensorFlow", 0.0388, 0.169);
}
TEST_P(Test_Int8_layers, Convolution3D)
@ -130,9 +148,21 @@ TEST_P(Test_Int8_layers, Flatten)
TEST_P(Test_Int8_layers, Padding)
{
testLayer("padding_valid", "TensorFlow", 0.0026, 0.0064);
testLayer("padding_same", "TensorFlow", 0.0081, 0.032);
testLayer("spatial_padding", "TensorFlow", 0.0078, 0.028);
if (backend == DNN_BACKEND_TIMVX)
testLayer("padding_valid", "TensorFlow", 0.0292, 0.0105);
else
testLayer("padding_valid", "TensorFlow", 0.0026, 0.0064);
if (backend == DNN_BACKEND_TIMVX)
testLayer("padding_same", "TensorFlow", 0.0085, 0.032);
else
testLayer("padding_same", "TensorFlow", 0.0081, 0.032);
if (backend == DNN_BACKEND_TIMVX)
testLayer("spatial_padding", "TensorFlow", 0.0079, 0.028);
else
testLayer("spatial_padding", "TensorFlow", 0.0078, 0.028);
testLayer("mirror_pad", "TensorFlow", 0.0064, 0.013);
testLayer("pad_and_concat", "TensorFlow", 0.0021, 0.0098);
testLayer("padding", "ONNX", 0.0005, 0.0069);
@ -283,20 +313,35 @@ TEST_P(Test_Int8_layers, InnerProduct)
{
testLayer("layer_inner_product", "Caffe", 0.005, 0.02, 1, 1, true);
testLayer("matmul", "TensorFlow", 0.0061, 0.019);
testLayer("nhwc_transpose_reshape_matmul", "TensorFlow", 0.0009, 0.0091);
if (backend == DNN_BACKEND_TIMVX)
testLayer("nhwc_transpose_reshape_matmul", "TensorFlow", 0.0018, 0.0175);
else
testLayer("nhwc_transpose_reshape_matmul", "TensorFlow", 0.0009, 0.0091);
testLayer("nhwc_reshape_matmul", "TensorFlow", 0.03, 0.071);
testLayer("matmul_layout", "TensorFlow", 0.035, 0.06);
testLayer("tf2_dense", "TensorFlow", 0, 0);
testLayer("matmul_add", "ONNX", 0.041, 0.082);
testLayer("linear", "ONNX", 0.0018, 0.0029);
testLayer("constant", "ONNX", 0.00021, 0.0006);
if (backend == DNN_BACKEND_TIMVX)
testLayer("constant", "ONNX", 0.00048, 0.0013);
else
testLayer("constant", "ONNX", 0.00021, 0.0006);
testLayer("lin_with_constant", "ONNX", 0.0011, 0.0016);
}
TEST_P(Test_Int8_layers, Reshape)
{
testLayer("reshape_layer", "TensorFlow", 0.0032, 0.0082);
testLayer("reshape_nchw", "TensorFlow", 0.0089, 0.029);
if (backend == DNN_BACKEND_TIMVX)
testLayer("reshape_nchw", "TensorFlow", 0.0092, 0.0495);
else
testLayer("reshape_nchw", "TensorFlow", 0.0089, 0.029);
testLayer("reshape_conv", "TensorFlow", 0.035, 0.054);
testLayer("reshape_reduce", "TensorFlow", 0.0042, 0.0078);
testLayer("reshape_as_shape", "TensorFlow", 0.0014, 0.0028);
@ -307,7 +352,12 @@ TEST_P(Test_Int8_layers, Reshape)
testLayer("flatten_by_prod", "ONNX", 0.0048, 0.0081);
testLayer("squeeze", "ONNX", 0.0048, 0.0081);
testLayer("unsqueeze", "ONNX", 0.0033, 0.0053);
testLayer("squeeze_and_conv_dynamic_axes", "ONNX", 0.0054, 0.0154);
if (backend == DNN_BACKEND_TIMVX)
testLayer("squeeze_and_conv_dynamic_axes", "ONNX", 0.006, 0.0212);
else
testLayer("squeeze_and_conv_dynamic_axes", "ONNX", 0.0054, 0.0154);
testLayer("unsqueeze_and_conv_dynamic_axes", "ONNX", 0.0037, 0.0151);
}
@ -378,6 +428,10 @@ TEST_P(Test_Int8_layers, Dropout)
TEST_P(Test_Int8_layers, Eltwise)
{
testLayer("layer_eltwise", "Caffe", 0.062, 0.15);
if (backend == DNN_BACKEND_TIMVX)
applyTestTag(CV_TEST_TAG_DNN_SKIP_TIMVX);
testLayer("conv_2_inps", "Caffe", 0.0086, 0.0232, 2, 1, true, false);
testLayer("eltwise_sub", "TensorFlow", 0.015, 0.047);
testLayer("eltwise_add_vec", "TensorFlow", 0.037, 0.21); // tflite 0.0095, 0.0365
@ -862,6 +916,8 @@ TEST_P(Test_Int8_nets, EfficientDet)
applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
if (target == DNN_TARGET_OPENCL && !ocl::Device::getDefault().isIntel())
applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL);
if (backend == DNN_BACKEND_TIMVX)
applyTestTag(CV_TEST_TAG_DNN_SKIP_TIMVX);
if (target != DNN_TARGET_CPU)
{