Merge pull request #24411 from alexlyulkov:al/dnn-type-inference

Added int32, int64 support and type inference to dnn #24411

**Added a type inference to dnn similar to the shape inference, added int32 and int64 support.**

- Added getTypes method for layers that calculates layer outputs types and internals types from inputs types (Similar to getMemoryShapes). By default outputs and internals types = input[0] type
- Added type inference pipeline similar to shape inference pipeline. LayersShapes struct (that is used in shape inference pipeline) now contains both shapes and types
- All layers output blobs are now allocated using the calculated types from the type inference.
- Inputs and constants with int32 and int64 types are not automatically converted into float32 now.
- Added int32 and int64 support for all the layers with indexing and for all the layers required in tests.

Added  int32 and int64 support for CUDA:
- Added host<->device data moving for int32 and int64
- Added int32 and int64 support for several layers (just slightly modified CUDA C++ templates)

Passed all the accuracy tests on CPU, OCL, OCL_FP16, CUDA, CUDA_FP16. (except RAFT model)

**CURRENT PROBLEMS**:
-  ONNX parser always converts int64 constants and layers attributes to int32, so some models with int64 constants doesn't work (e.g. RAFT). The solution is to disable int64->int32 conversion and fix attributes reading in a lot of ONNX layers parsers (https://github.com/opencv/opencv/issues/25102)
- I didn't add type inference and int support to VULCAN, so it doesn't work at all now.
- Some layers don't support int yet, so some unknown models may not work.

**CURRENT WORKAROUNDS**:
- CPU arg_layer indides are implemented in int32 followed by a int32->int64 conversion (the master branch has the same workaround with int32->float conversion)
- CPU and OCL pooling_layer indices are implemented in float followed by a float->int64 conversion
- CPU gather_layer indices are implemented in int32, so int64 indices are converted to int32 (the master branch has the same workaround with float->int32 conversion)

**DISABLED TESTS**:
- RAFT model

**REMOVED TESTS**:
- Greater_input_dtype_int64 (because it doesn't fit ONNX rules, the whole test is just comparing float tensor with int constant)

**TODO IN NEXT PULL REQUESTS**:
- Add int64 support for ONNX parser
- Add int support for more layers
- Add int support for OCL (currently int layers just run on CPU)
- Add int tests
- Add int support for other backends
This commit is contained in:
alexlyulkov 2024-03-01 17:07:38 +03:00 committed by GitHub
parent 81956ad83e
commit 1d1faaabef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
53 changed files with 1113 additions and 286 deletions

View File

@ -25,14 +25,16 @@ jobs:
Windows10-x64:
uses: opencv/ci-gha-workflow/.github/workflows/OCV-PR-5.x-W10.yaml@main
Windows10-x64-Vulkan:
uses: opencv/ci-gha-workflow/.github/workflows/OCV-PR-5.x-W10-Vulkan.yaml@main
# Vulkan configuration disabled as Vulkan backend for DNN does not support int/int64 for now
# Details: https://github.com/opencv/opencv/issues/25110
# Windows10-x64-Vulkan:
# uses: opencv/ci-gha-workflow/.github/workflows/OCV-PR-5.x-W10-Vulkan.yaml@main
macOS-ARM64:
uses: opencv/ci-gha-workflow/.github/workflows/OCV-PR-5.x-macOS-ARM64.yaml@main
macOS-ARM64-Vulkan:
uses: opencv/ci-gha-workflow/.github/workflows/OCV-PR-5.x-macOS-ARM64-Vulkan.yaml@main
# macOS-ARM64-Vulkan:
# uses: opencv/ci-gha-workflow/.github/workflows/OCV-PR-5.x-macOS-ARM64-Vulkan.yaml@main
macOS-x64:
uses: opencv/ci-gha-workflow/.github/workflows/OCV-PR-5.x-macOS-x86_64.yaml@main

View File

@ -62,6 +62,7 @@ CV__DNN_INLINE_NS_BEGIN
//! @{
typedef std::vector<int> MatShape;
typedef int MatType;
/**
* @brief Enum of computation backends supported by layers.
@ -205,8 +206,16 @@ CV__DNN_INLINE_NS_BEGIN
*/
virtual void setHostDirty() = 0;
int getHostMatDepth() {
CV_Assert(hostMatDepth != -1);
return hostMatDepth;
}
int backendId; //!< Backend identifier.
int targetId; //!< Target identifier.
protected:
int hostMatDepth = -1;
};
class CV_EXPORTS ActivationLayer;
@ -397,6 +406,12 @@ CV__DNN_INLINE_NS_BEGIN
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const;
virtual void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>&outputs,
std::vector<MatType>&internals) const;
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const {CV_UNUSED(inputs); CV_UNUSED(outputs); return 0;}
@ -675,6 +690,7 @@ CV__DNN_INLINE_NS_BEGIN
/** @brief Returns input and output shapes for all layers in loaded model;
* preliminary inferencing isn't necessary.
* @param netInputShapes shapes for all input blobs in net input layer.
* @param netInputTypes types for all input blobs in net input layer.
* @param layersIds output parameter for layer IDs.
* @param inLayersShapes output parameter for input layers shapes;
* order is the same as in layersIds
@ -682,12 +698,14 @@ CV__DNN_INLINE_NS_BEGIN
* order is the same as in layersIds
*/
CV_WRAP void getLayersShapes(const std::vector<MatShape>& netInputShapes,
const std::vector<int>& netInputTypes,
CV_OUT std::vector<int>& layersIds,
CV_OUT std::vector<std::vector<MatShape> >& inLayersShapes,
CV_OUT std::vector<std::vector<MatShape> >& outLayersShapes) const;
/** @overload */
CV_WRAP void getLayersShapes(const MatShape& netInputShape,
const int& netInputType,
CV_OUT std::vector<int>& layersIds,
CV_OUT std::vector<std::vector<MatShape> >& inLayersShapes,
CV_OUT std::vector<std::vector<MatShape> >& outLayersShapes) const;
@ -695,6 +713,7 @@ CV__DNN_INLINE_NS_BEGIN
/** @brief Returns input and output shapes for layer with specified
* id in loaded model; preliminary inferencing isn't necessary.
* @param netInputShape shape input blob in net input layer.
* @param netInputType input type in net input layer.
* @param layerId id for layer.
* @param inLayerShapes output parameter for input layers shapes;
* order is the same as in layersIds
@ -702,29 +721,36 @@ CV__DNN_INLINE_NS_BEGIN
* order is the same as in layersIds
*/
void getLayerShapes(const MatShape& netInputShape,
const int& netInputType,
const int layerId,
CV_OUT std::vector<MatShape>& inLayerShapes,
CV_OUT std::vector<MatShape>& outLayerShapes) const; // FIXIT: CV_WRAP
/** @overload */
void getLayerShapes(const std::vector<MatShape>& netInputShapes,
const std::vector<int>& netInputTypes,
const int layerId,
CV_OUT std::vector<MatShape>& inLayerShapes,
CV_OUT std::vector<MatShape>& outLayerShapes) const; // FIXIT: CV_WRAP
/** @brief Computes FLOP for whole loaded model with specified input shapes.
* @param netInputShapes vector of shapes for all net inputs.
* @param netInputTypes vector of types for all net inputs.
* @returns computed FLOP.
*/
CV_WRAP int64 getFLOPS(const std::vector<MatShape>& netInputShapes) const;
CV_WRAP int64 getFLOPS(const std::vector<MatShape>& netInputShapes,
const std::vector<int>& netInputTypes) const;
/** @overload */
CV_WRAP int64 getFLOPS(const MatShape& netInputShape) const;
CV_WRAP int64 getFLOPS(const MatShape& netInputShape,
const int& netInputType) const;
/** @overload */
CV_WRAP int64 getFLOPS(const int layerId,
const std::vector<MatShape>& netInputShapes) const;
const std::vector<MatShape>& netInputShapes,
const std::vector<int>& netInputTypes) const;
/** @overload */
CV_WRAP int64 getFLOPS(const int layerId,
const MatShape& netInputShape) const;
const MatShape& netInputShape,
const int& netInputType) const;
/** @brief Returns list of types for layer used in model.
* @param layersTypes output parameter for returning types.
@ -740,36 +766,44 @@ CV__DNN_INLINE_NS_BEGIN
/** @brief Computes bytes number which are required to store
* all weights and intermediate blobs for model.
* @param netInputShapes vector of shapes for all net inputs.
* @param netInputTypes vector of types for all net inputs.
* @param weights output parameter to store resulting bytes for weights.
* @param blobs output parameter to store resulting bytes for intermediate blobs.
*/
void getMemoryConsumption(const std::vector<MatShape>& netInputShapes,
const std::vector<int>& netInputTypes,
CV_OUT size_t& weights, CV_OUT size_t& blobs) const; // FIXIT: CV_WRAP
/** @overload */
CV_WRAP void getMemoryConsumption(const MatShape& netInputShape,
const int& netInputType,
CV_OUT size_t& weights, CV_OUT size_t& blobs) const;
/** @overload */
CV_WRAP void getMemoryConsumption(const int layerId,
const std::vector<MatShape>& netInputShapes,
const std::vector<int>& netInputTypes,
CV_OUT size_t& weights, CV_OUT size_t& blobs) const;
/** @overload */
CV_WRAP void getMemoryConsumption(const int layerId,
const MatShape& netInputShape,
const int& netInputType,
CV_OUT size_t& weights, CV_OUT size_t& blobs) const;
/** @brief Computes bytes number which are required to store
* all weights and intermediate blobs for each layer.
* @param netInputShapes vector of shapes for all net inputs.
* @param netInputTypes vector of types for all net inputs.
* @param layerIds output vector to save layer IDs.
* @param weights output parameter to store resulting bytes for weights.
* @param blobs output parameter to store resulting bytes for intermediate blobs.
*/
void getMemoryConsumption(const std::vector<MatShape>& netInputShapes,
const std::vector<int>& netInputTypes,
CV_OUT std::vector<int>& layerIds,
CV_OUT std::vector<size_t>& weights,
CV_OUT std::vector<size_t>& blobs) const; // FIXIT: CV_WRAP
/** @overload */
void getMemoryConsumption(const MatShape& netInputShape,
const int& netInputType,
CV_OUT std::vector<int>& layerIds,
CV_OUT std::vector<size_t>& weights,
CV_OUT std::vector<size_t>& blobs) const; // FIXIT: CV_WRAP

View File

@ -97,10 +97,11 @@ public class DnnListRegressionTest extends OpenCVTestCase {
int layerId = 1;
List<MatOfInt> netInputShapes = new ArrayList();
netInputShapes.add(new MatOfInt(1, 3, 224, 224));
MatOfInt netInputTypes = new MatOfInt(5);
long[] weights=null;
long[] blobs=null;
try {
net.getMemoryConsumption(layerId, netInputShapes, weights, blobs);
net.getMemoryConsumption(layerId, netInputShapes, netInputTypes, weights, blobs);
} catch(Exception e) {
fail("Net getMemoryConsumption failed: " + e.getMessage());
}
@ -110,8 +111,9 @@ public class DnnListRegressionTest extends OpenCVTestCase {
int layerId = 1;
List<MatOfInt> netInputShapes = new ArrayList();
netInputShapes.add(new MatOfInt(1, 3, 224, 224));
MatOfInt netInputTypes = new MatOfInt(5);
try {
net.getFLOPS(layerId, netInputShapes);
net.getFLOPS(layerId, netInputShapes, netInputTypes);
} catch(Exception e) {
fail("Net getFLOPS failed: " + e.getMessage());
}

View File

@ -886,9 +886,17 @@ Net build_net(
Mat output = net.forward();
MatShape netInputShape = shape(input);
cv::dnn::MatType netInputType = input.depth();
bool fp16 = false;
#ifdef HAVE_OPENCL
fp16 = ocl::Device::getDefault().isExtensionSupported("cl_khr_fp16");
#endif
if (netInputType == CV_32F && fp16 && targetId == DNN_TARGET_OPENCL_FP16)
netInputType = CV_16F;
size_t weightsMemory = 0, blobsMemory = 0;
net.getMemoryConsumption(netInputShape, weightsMemory, blobsMemory);
int64 flops = net.getFLOPS(netInputShape);
net.getMemoryConsumption(netInputShape, netInputType, weightsMemory, blobsMemory);
int64 flops = net.getFLOPS(netInputShape, netInputType);
CV_Assert(flops > 0);
std::cout

View File

@ -136,9 +136,17 @@ PERF_TEST_P_(Conv1D, conv1d)
Mat output = net.forward();
MatShape netInputShape = shape(input);
cv::dnn::MatType netInputType = input.depth();
bool fp16 = false;
#ifdef HAVE_OPENCL
fp16 = ocl::Device::getDefault().isExtensionSupported("cl_khr_fp16");
#endif
if (netInputType == CV_32F && fp16 && targetId == DNN_TARGET_OPENCL_FP16)
netInputType = CV_16F;
size_t weightsMemory = 0, blobsMemory = 0;
net.getMemoryConsumption(netInputShape, weightsMemory, blobsMemory);
int64 flops = net.getFLOPS(netInputShape);
net.getMemoryConsumption(netInputShape, netInputType, weightsMemory, blobsMemory);
int64 flops = net.getFLOPS(netInputShape, netInputType);
CV_Assert(flops > 0);
std::cout

View File

@ -155,9 +155,17 @@ PERF_TEST_P_(Conv3D, conv3d)
Mat output = net.forward();
MatShape netInputShape = shape(input);
cv::dnn::MatType netInputType = input.depth();
bool fp16 = false;
#ifdef HAVE_OPENCL
fp16 = ocl::Device::getDefault().isExtensionSupported("cl_khr_fp16");
#endif
if (netInputType == CV_32F && fp16 && targetId == DNN_TARGET_OPENCL_FP16)
netInputType = CV_16F;
size_t weightsMemory = 0, blobsMemory = 0;
net.getMemoryConsumption(netInputShape, weightsMemory, blobsMemory);
int64 flops = net.getFLOPS(netInputShape);
net.getMemoryConsumption(netInputShape, netInputType, weightsMemory, blobsMemory);
int64 flops = net.getFLOPS(netInputShape, netInputType);
CV_Assert(flops > 0);
std::cout

View File

@ -267,15 +267,13 @@ PERF_TEST_P_(Layer_Scatter, scatter) {
int target_id = get<1>(get<3>(GetParam()));
Mat data(shape, CV_32FC1);
Mat indices(shape, CV_32FC1);
Mat indices(shape, CV_64SC1);
Mat updates(shape, CV_32FC1);
randn(data, 0.f, 1.f);
randu(indices, 0, shape[axis]);
randn(updates, 0.f, 1.f);
indices.convertTo(indices, CV_32SC1, 1, -1);
Net net;
LayerParams lp;
lp.type = "Scatter";
@ -334,7 +332,7 @@ PERF_TEST_P_(Layer_ScatterND, scatterND) {
std::vector<int> indices_shape(shape);
indices_shape.push_back(int(shape.size()));
Mat data(shape, CV_32FC1);
Mat indices(indices_shape, CV_32FC1);
Mat indices(indices_shape, CV_32SC1);
Mat updates(shape, CV_32FC1);
randn(data, 0.f, 1.f);
@ -346,11 +344,11 @@ PERF_TEST_P_(Layer_ScatterND, scatterND) {
std::vector<int> indices_step;
for (int i = 0; i < indices.dims; i++)
{
int step = indices.step.p[i] / sizeof(float);
int step = indices.step.p[i] / sizeof(int32_t);
indices_step.push_back(step);
}
int t, j, idx, offset_at_idx, offset;
auto *indices_ptr = indices.ptr<float>();
auto *indices_ptr = indices.ptr<int32_t>();
for (int i = 0; i < total; i++)
{
t = i;
@ -629,7 +627,7 @@ struct Layer_GatherElements : public TestBaseWithParam<tuple<Backend, Target> >
int targetId = get<1>(GetParam());
Mat data(data_shape, CV_32FC1);
Mat indices(indices_shape, CV_32FC1);
Mat indices(indices_shape, CV_64SC1);
randu(data, 0.f, 1.f);
randu(indices, 0, data_shape[axis]);

View File

@ -47,13 +47,25 @@ public:
for(auto &inp: inputs){
netMatShapes.push_back(shape(std::get<0>(inp)));
}
size_t weightsMemory = 0, blobsMemory = 0;
net.getMemoryConsumption(netMatShapes, weightsMemory, blobsMemory);
int64 flops = net.getFLOPS(netMatShapes);
CV_Assert(flops > 0);
bool fp16 = false;
#ifdef HAVE_OPENCL
fp16 = ocl::Device::getDefault().isExtensionSupported("cl_khr_fp16");
#endif
std::vector<cv::dnn::MatType> netMatTypes;
for (auto& inp : inputs) {
cv::dnn::MatType t = std::get<0>(inp).depth();
if (t == CV_32F && fp16 && target == DNN_TARGET_OPENCL_FP16)
t = CV_16F;
netMatTypes.push_back(t);
}
net.forward(outputLayer); // warmup
size_t weightsMemory = 0, blobsMemory = 0;
net.getMemoryConsumption(netMatShapes, netMatTypes, weightsMemory, blobsMemory);
int64 flops = net.getFLOPS(netMatShapes, netMatTypes);
CV_Assert(flops > 0);
std::cout << "Memory consumption:" << std::endl;
std::cout << " Weights(parameters): " << divUp(weightsMemory, 1u<<20) << " Mb" << std::endl;
std::cout << " Blobs: " << divUp(blobsMemory, 1u<<20) << " Mb" << std::endl;

View File

@ -152,6 +152,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template void concat<__half>(const Stream&, TensorSpan<__half>, std::size_t, TensorView<__half>, std::size_t);
#endif
template void concat<float>(const Stream&, TensorSpan<float>, std::size_t, TensorView<float>, std::size_t);
template void concat<int32_t>(const Stream&, TensorSpan<int32_t>, std::size_t, TensorView<int32_t>, std::size_t);
template void concat<int64_t>(const Stream&, TensorSpan<int64_t>, std::size_t, TensorView<int64_t>, std::size_t);
template <class T, std::size_t Rank> static
void launch_concat_with_offsets(
@ -271,7 +273,11 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
concat_with_offsets_dispatcher<T, 1, CSL_MAX_TENSOR_RANK>(rank, stream, output, outStride, offsets, input, inStride);
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void concat_with_offsets(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector<std::size_t>);
#endif
template void concat_with_offsets(const Stream&, TensorSpan<float>, TensorView<float>, std::vector<std::size_t>);
template void concat_with_offsets(const Stream&, TensorSpan<int32_t>, TensorView<int32_t>, std::vector<std::size_t>);
template void concat_with_offsets(const Stream&, TensorSpan<int64_t>, TensorView<int64_t>, std::vector<std::size_t>);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */

View File

@ -371,4 +371,25 @@ void eltwise_fmod_2(const Stream& stream, TensorSpan<T> output, TensorView<T> x,
template void eltwise_max_2(const Stream& stream, TensorSpan<float> output, TensorView<float> x, TensorView<float> y);
template void eltwise_min_2(const Stream& stream, TensorSpan<float> output, TensorView<float> x, TensorView<float> y);
template void eltwise_mod_2(const Stream& stream, TensorSpan<int32_t> output, TensorView<int32_t> x, TensorView<int32_t> y);
template void eltwise_fmod_2(const Stream& stream, TensorSpan<int32_t> output, TensorView<int32_t> x, TensorView<int32_t> y);
template void eltwise_sub_2(const Stream& stream, TensorSpan<int32_t> output, TensorView<int32_t> x, TensorView<int32_t> y);
template void eltwise_div_2(const Stream& stream, TensorSpan<int32_t> output, TensorView<int32_t> x, TensorView<int32_t> y);
template void eltwise_prod_2(const Stream& stream, TensorSpan<int32_t> output, TensorView<int32_t> x, TensorView<int32_t> y);
template void eltwise_sum_coeff_2(const Stream&, TensorSpan<int32_t>, int32_t, TensorView<int32_t>, int32_t, TensorView<int32_t>);
template void eltwise_sum_2(const Stream& stream, TensorSpan<int32_t> output, TensorView<int32_t> x, TensorView<int32_t> y);
template void eltwise_max_2(const Stream& stream, TensorSpan<int32_t> output, TensorView<int32_t> x, TensorView<int32_t> y);
template void eltwise_min_2(const Stream& stream, TensorSpan<int32_t> output, TensorView<int32_t> x, TensorView<int32_t> y);
template void eltwise_mod_2(const Stream& stream, TensorSpan<int64_t> output, TensorView<int64_t> x, TensorView<int64_t> y);
template void eltwise_fmod_2(const Stream& stream, TensorSpan<int64_t> output, TensorView<int64_t> x, TensorView<int64_t> y);
template void eltwise_sub_2(const Stream& stream, TensorSpan<int64_t> output, TensorView<int64_t> x, TensorView<int64_t> y);
template void eltwise_div_2(const Stream& stream, TensorSpan<int64_t> output, TensorView<int64_t> x, TensorView<int64_t> y);
template void eltwise_prod_2(const Stream& stream, TensorSpan<int64_t> output, TensorView<int64_t> x, TensorView<int64_t> y);
template void eltwise_sum_coeff_2(const Stream&, TensorSpan<int64_t>, int64_t, TensorView<int64_t>, int64_t, TensorView<int64_t>);
template void eltwise_sum_2(const Stream& stream, TensorSpan<int64_t> output, TensorView<int64_t> x, TensorView<int64_t> y);
template void eltwise_max_2(const Stream& stream, TensorSpan<int64_t> output, TensorView<int64_t> x, TensorView<int64_t> y);
template void eltwise_min_2(const Stream& stream, TensorSpan<int64_t> output, TensorView<int64_t> x, TensorView<int64_t> y);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */

View File

@ -68,6 +68,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
#endif
template void fill(const Stream&, Span<float>, float);
template void fill(const Stream&, Span<int>, int);
template void fill(const Stream&, Span<int64_t>, int64_t);
template <class T, std::size_t N> static
void launch_vectorized_copy(const Stream& stream, Span<T> output, View<T> input) {
@ -94,5 +95,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template void copy(const Stream&, Span<__half>, View<__half>);
#endif
template void copy(const Stream&, Span<float>, View<float>);
template void copy(const Stream&, Span<int32_t>, View<int32_t>);
template void copy(const Stream&, Span<int64_t>, View<int64_t>);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */

View File

@ -31,6 +31,20 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
__device__ static float lowest() { return -FLT_MAX; }
};
template <>
struct numeric_limits<int32_t> {
__device__ static int32_t min() { return 1; }
__device__ static int32_t max() { return INT_MAX; }
__device__ static int32_t lowest() { return INT_MIN; }
};
template <>
struct numeric_limits<int64_t> {
__device__ static int64_t min() { return 1; }
__device__ static int64_t max() { return LLONG_MAX; }
__device__ static int64_t lowest() { return LLONG_MIN; }
};
}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
#endif /* OPENCV_DNN_SRC_CUDA_LIMITS_HPP */

View File

@ -30,10 +30,10 @@ using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, std::size_t Order,
template <class T, class T_INDEX, std::size_t Order,
typename std::enable_if<Order == 1 || Order == 2 || Order == 3, bool>::type = true> /* Order has been hardcoded; see code */
__global__ void max_pooling_with_indices(
Span<T> output, Span<T> indices, View<T> input, size_type channels,
Span<T> output, Span<T_INDEX> indices, View<T> input, size_type channels,
array<size_type, Order> out_spatial_dims, array<size_type, Order> in_spatial_dims,
array<size_type, Order> window_size, array<size_type, Order> strides, array<size_type, Order> padding_left)
{
@ -130,9 +130,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
}
}
template <class T, std::size_t Order>
template <class T, class T_INDEX, std::size_t Order>
__global__ void max_unpooling(
Span<T> output, View<T> input, View<T> indices, size_type channels,
Span<T> output, View<T> input, View<T_INDEX> indices, size_type channels,
array<size_type, Order> out_spatial_dims, array<size_type, Order> in_spatial_dims,
array<size_type, Order> window_size, array<size_type, Order> strides, array<size_type, Order> padding_left)
{
@ -164,15 +164,15 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
out_spatial_size *= out_spatial_dims[i];
index_type outer_offset = (n * channels + c) * out_spatial_size;
output[outer_offset + static_cast<index_type>(indices[idx])] = input[idx];
output[outer_offset + indices[idx]] = input[idx];
}
}
}
template <class T, std::size_t Order> static
template <class T, class T_INDEX, std::size_t Order> static
void launch_max_pooling_kernel(
const Stream& stream,
Span<T> output, Span<T> indices, View<T> input, std::size_t channels,
Span<T> output, Span<T_INDEX> indices, View<T> input, std::size_t channels,
const std::vector<std::size_t>& out_spatial_dims, const std::vector<std::size_t>& in_spatial_dims,
const std::vector<std::size_t>& window_size,
const std::vector<std::size_t>& strides, const std::vector<std::size_t>& padding_left)
@ -193,16 +193,16 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
strides_k.assign(std::begin(strides), std::end(strides));
padding_left_k.assign(std::begin(padding_left), std::end(padding_left));
auto kernel = raw::max_pooling_with_indices<T, Order>;
auto kernel = raw::max_pooling_with_indices<T, T_INDEX, Order>;
auto policy = make_policy(kernel, output.size(), 0, stream);
launch_kernel(kernel, policy, output, indices, input, channels,
out_spatial_dims_k, in_spatial_dims_k, window_size_k, strides_k, padding_left_k);
}
template <class T>
template <class T, class T_INDEX>
void max_pooling_with_indices(
const Stream& stream,
TensorSpan<T> output, TensorSpan<T> indices, TensorView<T> input,
TensorSpan<T> output, TensorSpan<T_INDEX> indices, TensorView<T> input,
const std::vector<std::size_t>& window_size, const std::vector<std::size_t>& strides,
const std::vector<std::size_t>& padding_left)
{
@ -224,33 +224,63 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
CV_Assert(1 <= order && order <= 3);
std::size_t channels = input.get_axis_size(1);
if (order == 3) {
launch_max_pooling_kernel<T, 3>(stream, output, indices, input, channels,
launch_max_pooling_kernel<T, T_INDEX, 3>(stream, output, indices, input, channels,
out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
} else if (order == 2) {
launch_max_pooling_kernel<T, 2>(stream, output, indices, input, channels,
launch_max_pooling_kernel<T, T_INDEX, 2>(stream, output, indices, input, channels,
out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
} else if (order == 1) {
launch_max_pooling_kernel<T, 1>(stream, output, indices, input, channels,
launch_max_pooling_kernel<T, T_INDEX, 1>(stream, output, indices, input, channels,
out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
}
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void max_pooling_with_indices(const Stream&,
TensorSpan<__half>, TensorSpan<__half>, TensorView<__half>,
TensorSpan<__half>, TensorSpan<int32_t>, TensorView<__half>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_pooling_with_indices(const Stream&,
TensorSpan<__half>, TensorSpan<int64_t>, TensorView<__half>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
#endif
template void max_pooling_with_indices(const Stream&,
TensorSpan<float>, TensorSpan<float>, TensorView<float>,
TensorSpan<float>, TensorSpan<int32_t>, TensorView<float>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template <class T, std::size_t Order> static
template void max_pooling_with_indices(const Stream&,
TensorSpan<float>, TensorSpan<int64_t>, TensorView<float>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_pooling_with_indices(const Stream&,
TensorSpan<int32_t>, TensorSpan<int32_t>, TensorView<int32_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_pooling_with_indices(const Stream&,
TensorSpan<int32_t>, TensorSpan<int64_t>, TensorView<int32_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_pooling_with_indices(const Stream&,
TensorSpan<int64_t>, TensorSpan<int32_t>, TensorView<int64_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_pooling_with_indices(const Stream&,
TensorSpan<int64_t>, TensorSpan<int64_t>, TensorView<int64_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template <class T, class T_INDEX, std::size_t Order> static
void launch_max_unpooling_kernel(
const Stream& stream,
Span<T> output, View<T> input, View<T> indices, std::size_t channels,
Span<T> output, View<T> input, View<T_INDEX> indices, std::size_t channels,
const std::vector<std::size_t>& out_spatial_dims, const std::vector<std::size_t>& in_spatial_dims,
const std::vector<std::size_t>& window_size,
const std::vector<std::size_t>& strides, const std::vector<std::size_t>& padding_left)
@ -271,16 +301,16 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
strides_k.assign(std::begin(strides), std::end(strides));
padding_left_k.assign(std::begin(padding_left), std::end(padding_left));
auto kernel = raw::max_unpooling<T, Order>;
auto kernel = raw::max_unpooling<T, T_INDEX, Order>;
auto policy = make_policy(kernel, input.size(), 0, stream);
launch_kernel(kernel, policy, output, input, indices, channels,
out_spatial_dims_k, in_spatial_dims_k, window_size_k, strides_k, padding_left_k);
}
template <class T>
template <class T, class T_INDEX>
void max_unpooling(
const Stream& stream,
TensorSpan<T> output, TensorView<T> input, TensorView<T> indices,
TensorSpan<T> output, TensorView<T> input, TensorView<T_INDEX> indices,
const std::vector<std::size_t>& window_size, const std::vector<std::size_t>& strides,
const std::vector<std::size_t>& padding_left)
{
@ -305,23 +335,53 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
CV_Assert(2 <= order && order <= 3);
std::size_t channels = input.get_axis_size(1);
if (order == 3) {
launch_max_unpooling_kernel<T, 3>(stream, output, input, indices, channels,
launch_max_unpooling_kernel<T, T_INDEX, 3>(stream, output, input, indices, channels,
out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
} else if (order == 2) {
launch_max_unpooling_kernel<T, 2>(stream, output, input, indices, channels,
launch_max_unpooling_kernel<T, T_INDEX, 2>(stream, output, input, indices, channels,
out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
}
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void max_unpooling(const Stream&,
TensorSpan<__half>, TensorView<__half>, TensorView<__half>,
TensorSpan<__half>, TensorView<__half>, TensorView<int32_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_unpooling(const Stream&,
TensorSpan<__half>, TensorView<__half>, TensorView<int64_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
#endif
template void max_unpooling(const Stream&,
TensorSpan<float>, TensorView<float>, TensorView<float>,
TensorSpan<float>, TensorView<float>, TensorView<int32_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_unpooling(const Stream&,
TensorSpan<float>, TensorView<float>, TensorView<int64_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_unpooling(const Stream&,
TensorSpan<int32_t>, TensorView<int32_t>, TensorView<int32_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_unpooling(const Stream&,
TensorSpan<int32_t>, TensorView<int32_t>, TensorView<int64_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_unpooling(const Stream&,
TensorSpan<int64_t>, TensorView<int64_t>, TensorView<int32_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);
template void max_unpooling(const Stream&,
TensorSpan<int64_t>, TensorView<int64_t>, TensorView<int64_t>,
const std::vector<std::size_t>&, const std::vector<std::size_t>&,
const std::vector<std::size_t>&);

View File

@ -13,17 +13,17 @@
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template <class T>
template <class T, class T_INDEX>
void max_pooling_with_indices(
const csl::Stream& stream,
csl::TensorSpan<T> output, csl::TensorSpan<T> indices, csl::TensorView<T> input,
csl::TensorSpan<T> output, csl::TensorSpan<T_INDEX> indices, csl::TensorView<T> input,
const std::vector<std::size_t>& kernel_size, const std::vector<std::size_t>& strides,
const std::vector<std::size_t>& padding_left);
template <class T>
template <class T, class T_INDEX>
void max_unpooling(
const csl::Stream& stream,
csl::TensorSpan<T> output, csl::TensorView<T> input, csl::TensorView<T> indices,
csl::TensorSpan<T> output, csl::TensorView<T> input, csl::TensorView<T_INDEX> indices,
const std::vector<std::size_t>& window_size, const std::vector<std::size_t>& strides,
const std::vector<std::size_t>& padding_left);

View File

@ -39,7 +39,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
std::vector<std::size_t> input_shape;
};
template <class T>
template <class T, class T_INDEX>
class MaxPoolingOp final : public CUDABackendNode {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
@ -103,10 +103,10 @@ namespace cv { namespace dnn { namespace cuda4dnn {
auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
auto output_data = output_wrapper->getSpan();
auto indices_wrapper = outputs[1].dynamicCast<wrapper_type>();
auto indices_wrapper = outputs[1].dynamicCast<GetCUDABackendWrapperType<T_INDEX>>();
auto output_indices = indices_wrapper->getSpan();
kernels::max_pooling_with_indices<T>(
kernels::max_pooling_with_indices<T, T_INDEX>(
stream, output_data, output_indices, input_data, window_size, strides, padding_left
);
}
@ -124,7 +124,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
std::vector<std::size_t> pads_begin;
};
template <class T>
template <class T, class T_INDEX>
class MaxUnpoolingOp final : public CUDABackendNode {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
@ -160,13 +160,13 @@ namespace cv { namespace dnn { namespace cuda4dnn {
auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
auto input_data = input_wrapper->getView();
auto indices_wrapper = inputs[1].dynamicCast<wrapper_type>();
auto indices_wrapper = inputs[1].dynamicCast<GetCUDABackendWrapperType<T_INDEX>>();
auto input_indices = indices_wrapper->getView();
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output_data = output_wrapper->getSpan();
kernels::max_unpooling<T>(stream, output_data, input_data, input_indices, window_size, strides, padding_left);
kernels::max_unpooling<T, T_INDEX>(stream, output_data, input_data, input_indices, window_size, strides, padding_left);
}
}

View File

@ -46,10 +46,12 @@ bool getParam_DNN_CHECK_NAN_INF_RAISE_ERROR();
inline namespace detail {
typedef std::vector<MatShape> ShapesVec;
typedef std::vector<MatType> TypesVec;
struct LayerShapes
{
ShapesVec in, out, internal;
TypesVec inTypes, outTypes, internalTypes;
// No guarantees that layer which support in-place computations
// will be computed in-place (input.data_ptr == output.data_ptr).
// If layer said that it could work in-place and layers after it

View File

@ -113,6 +113,16 @@ public:
return false;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
outputs.assign(requiredOutputs, CV_8S);
}
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
{
std::vector<Mat> inputs, outputs;
@ -239,6 +249,19 @@ public:
return false;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
if (preferableTarget == DNN_TARGET_OPENCL_FP16)
outputs.assign(requiredOutputs, CV_16F);
else
outputs.assign(requiredOutputs, CV_32F);
}
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
{
std::vector<Mat> inputs, outputs;

View File

@ -181,20 +181,32 @@ void Layer::forward_fallback(InputArrayOfArrays inputs_arr, OutputArrayOfArrays
inputs.resize(orig_inputs.size());
for (size_t i = 0; i < orig_inputs.size(); i++)
orig_inputs[i].convertTo(inputs[i], CV_32F);
if (orig_inputs[i].depth() == CV_16F)
orig_inputs[i].convertTo(inputs[i], CV_32F);
else
inputs[i] = orig_inputs[i];
outputs.resize(orig_outputs.size());
for (size_t i = 0; i < orig_outputs.size(); i++)
outputs[i].create(shape(orig_outputs[i]), CV_32F);
if (orig_outputs[i].depth() == CV_16F)
outputs[i].create(shape(orig_outputs[i]), CV_32F);
else
outputs[i] = orig_outputs[i];
internals.resize(orig_internals.size());
for (size_t i = 0; i < orig_internals.size(); i++)
internals[i].create(shape(orig_internals[i]), CV_32F);
if (orig_internals[i].depth() == CV_16F)
internals[i].create(shape(orig_internals[i]), CV_32F);
else
internals[i] = orig_internals[i];
forward(inputs, outputs, internals);
for (size_t i = 0; i < outputs.size(); i++)
outputs[i].convertTo(orig_outputs[i], CV_16F);
if (orig_outputs[i].depth() == CV_16F)
outputs[i].convertTo(orig_outputs[i], CV_16F);
else
outputs[i] = orig_outputs[i];
// sync results back
outputs_arr.assign(orig_outputs);
@ -240,6 +252,25 @@ bool Layer::getMemoryShapes(const std::vector<MatShape>& inputs,
return false;
}
void Layer::getTypes(const std::vector<MatType>&inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>&outputs,
std::vector<MatType>&internals) const
{
CV_Assert(inputs.size());
for (auto input : inputs)
if (preferableTarget == DNN_TARGET_CUDA_FP16 || preferableTarget == DNN_TARGET_CUDA)
CV_CheckTypeEQ(input, CV_32F, "");
else if (preferableTarget == DNN_TARGET_OPENCL_FP16)
CV_CheckType(input, input == CV_16F || input == CV_8S, "");
else
CV_CheckType(input, input == CV_32F || input == CV_8S, "");
outputs.assign(requiredOutputs, inputs[0]);
internals.assign(requiredInternals, inputs[0]);
}
bool Layer::updateMemoryShapes(const std::vector<MatShape>& inputs)
{
return true;

View File

@ -146,14 +146,19 @@ struct DataLayer : public Layer
CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
forward_ocl(inputs_arr, outputs_arr, internals_arr))
bool isFP16 = outputs_arr.depth() == CV_16F;
std::vector<Mat> outputs, internals;
outputs_arr.getMatVector(outputs);
internals_arr.getMatVector(internals);
for (int i = 0; i < inputsData.size(); ++i)
{
bool isFP16 = outputs[i].depth() == CV_16F;
if (inputsData[i].type() == CV_32S || inputsData[i].type() == CV_64S) {
CV_CheckTypeEQ(outputs[i].type(), inputsData[i].type(), "");
CV_Assert(means[i] == Scalar() && scaleFactors[i] == 1.0);
inputsData[i].copyTo(outputs[i]);
continue;
}
double scale = scaleFactors[i];
Scalar& mean = means[i];
@ -209,13 +214,18 @@ struct DataLayer : public Layer
#ifdef HAVE_OPENCL
bool forward_ocl(InputArrayOfArrays, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_)
{
bool isFP16 = outputs_.depth() == CV_16F;
std::vector<UMat> outputs;
outputs_.getUMatVector(outputs);
for (int i = 0; i < inputsData.size(); ++i)
{
bool isFP16 = outputs[i].depth() == CV_16F;
if (inputsData[i].type() == CV_32S || inputsData[i].type() == CV_64S) {
CV_CheckTypeEQ(outputs[i].type(), inputsData[i].type(), "");
CV_Assert(means[i] == Scalar() && scaleFactors[i] == 1.0);
inputsData[i].copyTo(outputs[i]);
continue;
}
Mat inputData = inputsData[i];
double scale = scaleFactors[i];
@ -228,9 +238,12 @@ struct DataLayer : public Layer
CV_CheckTypeEQ(outputs[i].type(), CV_32FC1, "");
bool singleMean = true;
for (int j = 1; j < std::min(4, inputData.size[1]) && singleMean; ++j)
if (mean != Scalar())
{
singleMean = mean[j] == mean[j - 1];
for (int j = 1; j < std::min(4, inputData.size[1]) && singleMean; ++j)
{
singleMean = mean[j] == mean[j - 1];
}
}
if (singleMean)
@ -311,6 +324,16 @@ struct DataLayer : public Layer
return false;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_Assert(inputs.size());
outputs = inputs;
}
virtual void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
{
std::vector<Mat> outputs;

View File

@ -72,6 +72,15 @@ public:
return false;
}
virtual void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
outputs.assign(1, CV_64S);
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
@ -98,7 +107,7 @@ public:
}
output = output.reshape(1, outShape);
output.convertTo(outputs[0], CV_32FC1);
output.convertTo(outputs[0], CV_64SC1);
}
private:

View File

@ -82,6 +82,17 @@ public:
return true;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_Assert(inputs.size());
outputs = inputs;
}
#ifdef HAVE_OPENCL
bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_)
{
@ -165,7 +176,7 @@ public:
) override
{
auto context = reinterpret_cast<csl::CSLContext*>(context_);
return make_cuda_node<cuda4dnn::ReshapeOp>(preferableTarget, std::move(context->stream));
return make_cuda_node_with_type<cuda4dnn::ReshapeOp>(preferableTarget, inputs[0]->getHostMatDepth(), std::move(context->stream));
}
#endif
};

View File

@ -115,6 +115,19 @@ public:
return false;
}
virtual void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_Assert(inputs.size());
for (int i = 1; i < inputs.size(); i++)
CV_CheckTypeEQ(inputs[i], inputs[0], "All input types should be equal");
outputs.assign(1, inputs[0]);
}
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
#ifdef HAVE_TIMVX
@ -273,7 +286,7 @@ public:
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
inputs_arr.depth() != CV_8S,
(inputs_arr.depth() == CV_32F || inputs_arr.depth() == CV_16F),
forward_ocl(inputs_arr, outputs_arr, internals_arr))
std::vector<Mat> inputs, outputs;
@ -286,7 +299,7 @@ public:
if (padding)
outMat.setTo(paddingValue);
if( cAxis == 1 && outMat.dims == 4 && !padding)
if(cAxis == 1 && outMat.dims == 4 && !padding && (inputs[0].depth() == CV_32F || inputs[0].depth() == CV_8S))
{
int nstripes = getNumThreads();
if (outMat.type() == CV_8S)
@ -325,7 +338,7 @@ public:
auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
auto concat_axis = normalize_axis(axis, input_wrapper->getRank());
return make_cuda_node<cuda4dnn::ConcatOp>(preferableTarget, std::move(context->stream), concat_axis, padding);
return make_cuda_node_with_type<cuda4dnn::ConcatOp>(preferableTarget, inputs[0]->getHostMatDepth(), std::move(context->stream), concat_axis, padding);
}
#endif

View File

@ -57,6 +57,20 @@ public:
return false;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
if (preferableTarget == DNN_TARGET_OPENCL_FP16
&& blobs[0].type() == CV_32F)
outputs.assign(1, CV_16F);
else
outputs.assign(1, blobs[0].depth());
}
#ifdef HAVE_OPENCL
bool forward_ocl(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
{
@ -171,10 +185,7 @@ public:
CV_Assert(blobs.size() == 1);
Mat blob = blobs[0];
if (blob.type() != CV_32F) {
blob.convertTo(blob, CV_32F);
}
return make_cuda_node<cuda4dnn::ConstOp>(preferableTarget, std::move(context->stream), blob);
return make_cuda_node_with_type<cuda4dnn::ConstOp>(preferableTarget, blob.type(), std::move(context->stream), blob);
}
#endif
};

View File

@ -57,6 +57,18 @@ public:
return false;
}
virtual void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_CheckEQ(inputs.size(), (size_t)2, "");
CV_CheckType(inputs[0], inputs[0] == CV_32F || inputs[0] == CV_32S || inputs[0] == CV_16F || inputs[0] == CV_8U, "");
CV_CheckType(inputs[1], inputs[1] == CV_64S || inputs[1] == CV_32S, "");
outputs.assign(1, inputs[0]);
}
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
std::vector<Mat> inputs;
inputs_arr.getMatVector(inputs);
@ -70,12 +82,6 @@ public:
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
if (inputs_arr.depth() == CV_16F)
{
forward_fallback(inputs_arr, outputs_arr, internals_arr);
return;
}
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
@ -84,14 +90,14 @@ public:
const Mat& indices = inputs[1];
Mat& out = outputs[0];
typeDispatch(outputs[0].type(), data, indices, out);
indexTypeDispatch(out.type(), indices.type(), data, indices, out);
}
template <typename T>
template <typename T, typename T_INDEX>
void forward_impl(const Mat& data_, const Mat& indices_, Mat& out_)
{
const auto *ptr_data = data_.ptr<const T>();
const auto *ptr_indices = indices_.ptr<const T>();
const auto *ptr_indices = indices_.ptr<const T_INDEX>();
auto *ptr_out = out_.ptr<T>();
const auto shape_data = shape(data_);
@ -112,12 +118,12 @@ public:
if (innermost_axis) {
for (int j = 0; j < inner_most_dim; j++) {
int index = static_cast<int>((indices[j] + axis_dim)) % axis_dim; // TODO: Check out-of-range index
int index = (indices[j] + axis_dim) % axis_dim; // TODO: Check out-of-range index
out[j] = data[index];
}
} else {
for (int j = 0; j < inner_most_dim; j++) {
int index = static_cast<int>(indices[j] + axis_dim) % axis_dim; // TODO: Check out-of-range index
int index = (indices[j] + axis_dim) % axis_dim; // TODO: Check out-of-range index
out[j] = data[index * axis_step + j];
}
}
@ -130,18 +136,37 @@ public:
}
template<typename... Args>
inline void indexTypeDispatch(const int type, const int index_type, Args&&... args)
{
switch (index_type)
{
case CV_32S:
typeDispatch<int32_t>(type, std::forward<Args>(args)...);
break;
case CV_64S:
typeDispatch<int64_t>(type, std::forward<Args>(args)...);
break;
default:
CV_Error(cv::Error::BadDepth, "Unsupported type.");
};
}
template<typename T_INDEX, typename... Args>
inline void typeDispatch(const int type, Args&&... args)
{
switch (type)
{
case CV_8U:
forward_impl<uint8_t>(std::forward<Args>(args)...);
forward_impl<uint8_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_16F:
forward_impl<int16_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32S:
forward_impl<int32_t>(std::forward<Args>(args)...);
forward_impl<int32_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F:
forward_impl<float>(std::forward<Args>(args)...);
forward_impl<float, T_INDEX>(std::forward<Args>(args)...);
break;
default:
CV_Error(cv::Error::BadDepth, "DNN/GatherElements: Unsupported type.");

View File

@ -40,6 +40,19 @@ public:
return false;
}
virtual void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_CheckEQ(inputs.size(), (size_t)2, "");
CV_CheckType(inputs[0], inputs[0] == CV_32F || inputs[0] == CV_32S || inputs[0] == CV_16F || inputs[0] == CV_8U, "");
CV_CheckType(inputs[1], inputs[1] == CV_64S || inputs[1] == CV_32S, "");
outputs.assign(1, inputs[0]);
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
@ -57,17 +70,15 @@ public:
const Mat& inp = inputs[0];
int indicesType = inputs[1].type();
CV_CheckType(indicesType, indicesType == CV_32FC1 || indicesType == CV_16FC1, "");
CV_CheckType(indicesType, indicesType == CV_32SC1 || indicesType == CV_64SC1, "");
Mat indices32S;
if (indicesType == CV_16F/*FP16*/)
if (indicesType == CV_64SC1)
{
Mat indicesF32;
inputs[1].convertTo(indicesF32, CV_32F);
indicesF32.convertTo(indices32S, CV_32S);
inputs[1].convertTo(indices32S, CV_32S);
}
else
{
inputs[1].convertTo(indices32S, CV_32S);
indices32S = inputs[1];
}
const size_t indices_total = indices32S.total();
indices32S = indices32S.reshape(1, indices_total);

View File

@ -68,17 +68,24 @@ public:
return false;
}
virtual void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_CheckGE(inputs.size(), (size_t)2, "");
CV_CheckType(inputs[0], inputs[0] == CV_32F || inputs[0] == CV_16F, "");
CV_CheckType(inputs[1], inputs[1] == CV_64S || inputs[1] == CV_32S, "");
outputs.assign(1, inputs[0]);
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
if (inputs_arr.depth() == CV_16F)
{
forward_fallback(inputs_arr, outputs_arr, internals_arr);
return;
}
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
@ -87,6 +94,19 @@ public:
Mat& input = inputs[0];
Mat& indices = inputs[1];
if (input.type() == CV_32F && indices.type() == CV_32S)
run<float, int32_t>(input, indices, outputs);
else if (input.type() == CV_32F && indices.type() == CV_64S)
run<float, int64_t>(input, indices, outputs);
else if (input.type() == CV_16F && indices.type() == CV_32S)
run<int16_t, int32_t>(input, indices, outputs);
else if (input.type() == CV_16F && indices.type() == CV_64S)
run<int16_t, int64_t>(input, indices, outputs);
}
template<typename T, typename INDEX_TYPE>
void run(cv::Mat& input, cv::Mat& indices, std::vector<cv::Mat>& outputs)
{
CV_Assert(input.total() == indices.total());
CV_Assert(input.size[0] == 1);
CV_Assert(input.isContinuous());
@ -102,9 +122,9 @@ public:
{
Mat outPlane = getPlane(outBlob, 0, i_c);
int wh_area = input.size[2]*input.size[3];
const float* inptr = input.ptr<float>(0, i_c);
const float* idxptr = indices.ptr<float>(0, i_c);
float* outptr = outPlane.ptr<float>();
const T* inptr = input.ptr<T>(0, i_c);
const INDEX_TYPE* idxptr = indices.ptr<INDEX_TYPE>(0, i_c);
T* outptr = outPlane.ptr<T>();
for(int i_wh = 0; i_wh < wh_area; i_wh++)
{
@ -112,8 +132,8 @@ public:
if (!(0 <= index && index < outPlaneTotal))
{
CV_LOG_ERROR(NULL, cv::format(
"i_n=%d\ni_c=%d\ni_wh=%d\nindex=%d\nmaxval=%lf\noutPlaneTotal=%d\n",
i_n, i_c, i_wh, index, inptr[i_wh], outPlaneTotal));
"i_n=%d\ni_c=%d\ni_wh=%d\nindex=%d\noutPlaneTotal=%d\n",
i_n, i_c, i_wh, index, outPlaneTotal));
CV_LOG_ERROR(NULL, "input.size=" << input.size);
CV_LOG_ERROR(NULL, "indices.size=" << indices.size);
CV_LOG_ERROR(NULL, "outBlob=" << outBlob.size);
@ -125,6 +145,7 @@ public:
}
}
#ifdef HAVE_CUDA
Ptr<BackendNode> initCUDA(
void *context_,
@ -150,7 +171,16 @@ public:
pads_begin[0] = poolPad.height;
pads_begin[1] = poolPad.width;
return make_cuda_node<cuda4dnn::MaxUnpoolingOp>(preferableTarget, std::move(context->stream), config);
int indicesType = inputs[1]->getHostMatDepth();
CV_CheckType(indicesType, indicesType == CV_32S || indicesType == CV_64S, "Unsupported indices type");
if (indicesType == CV_32S)
return make_cuda_node_with_indices<cuda4dnn::MaxUnpoolingOp, int32_t>(preferableTarget, inputs[0]->getHostMatDepth(), std::move(context->stream), config);
else if (indicesType == CV_64S)
return make_cuda_node_with_indices<cuda4dnn::MaxUnpoolingOp, int64_t>(preferableTarget, inputs[0]->getHostMatDepth(), std::move(context->stream), config);
CV_Error(Error::BadDepth, "Unsupported indices type");
return Ptr<BackendNode>();
}
#endif

View File

@ -349,6 +349,28 @@ public:
return false;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_Assert(inputs.size());
for (auto input : inputs)
{
CV_CheckTypeEQ(inputs[0], input, "All inputs should have equal types");
if (preferableTarget == DNN_TARGET_CUDA_FP16 || preferableTarget == DNN_TARGET_CUDA)
CV_CheckType(input, input == CV_32F || input == CV_32S || input == CV_64S, "Unsupported type");
else if (preferableTarget == DNN_TARGET_OPENCL_FP16)
CV_CheckType(input, input == CV_16F || input == CV_8S || input == CV_8U || input == CV_32S || input == CV_64S, "");
else
CV_CheckType(input, input == CV_32F || input == CV_8S || input == CV_8U || input == CV_32S || input == CV_64S, "");
}
outputs.assign(requiredOutputs, inputs[0]);
}
template <typename T, typename Functor>
void binary_forward_impl(
int ndims, const std::vector<int>& shape,
@ -773,11 +795,17 @@ public:
helper.reInit(sizeof(uint8_t));
opDispatch<uint8_t>(std::forward<Args>(args)...);
break;
case CV_8S:
opDispatch<int8_t>(std::forward<Args>(args)...);
break;
case CV_32S:
// TODO: integrate with type inference
helper.reInit(sizeof(int32_t));
opDispatch<int32_t>(std::forward<Args>(args)...);
break;
case CV_64S:
opDispatch<int64_t>(std::forward<Args>(args)...);
break;
case CV_32F:
CV_Assert(op != OPERATION::BITSHIFT && op != OPERATION::AND &&
op != OPERATION::OR && op != OPERATION::XOR);
@ -829,7 +857,7 @@ public:
default: return Ptr<BackendNode>(); // return empty cuda_node if the EltwiseOpType is unsupported type.
};
return make_cuda_node<cuda4dnn::EltwiseOp>(preferableTarget, std::move(context->stream), op_, std::vector<float>());
return make_cuda_node_with_type<cuda4dnn::EltwiseOp>(preferableTarget, inputs[0]->getHostMatDepth(), std::move(context->stream), op_, std::vector<float>());
}
#endif

View File

@ -178,6 +178,24 @@ public:
return false;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_Assert(inputs.size());
for (auto input : inputs)
if (preferableTarget == DNN_TARGET_CUDA_FP16 || preferableTarget == DNN_TARGET_CUDA)
CV_CheckTypeEQ(input, CV_32F, "Unsupported type");
else if (preferableTarget == DNN_TARGET_OPENCL_FP16)
CV_CheckType(input, input == CV_16F || input == CV_8S || input == CV_32S, "");
else
CV_CheckType(input, input == CV_32F || input == CV_8S || input == CV_32S, "");
outputs.assign(requiredOutputs, inputs[0]);
}
void computeStrides(const MatShape &shapeBefore, const MatShape &shapeAfter)
{
_oldStride.resize(_numAxes);
@ -347,7 +365,7 @@ public:
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
inputs_arr.depth() != CV_8S,
inputs_arr.depth() != CV_8S && inputs_arr.depth() != CV_32S,
forward_ocl(inputs_arr, outputs_arr, internals_arr))
if (inputs_arr.depth() == CV_16F)

View File

@ -320,11 +320,25 @@ public:
CV_Assert_N(inputs.size() == 1, !outputs.empty(), !computeMaxIdx || outputs.size() == 2);
UMat& inpMat = inputs[0];
UMat& outMat = outputs[0];
UMat maskMat = computeMaxIdx ? outputs[1] : UMat();
UMat maskMat;
if (computeMaxIdx)
maskMat.create(shape(outputs[1]), use_half ? CV_16F : CV_32F);
CV_Assert(inpMat.offset == 0 && outMat.offset == 0);
return poolOp->Forward(inpMat, outMat, maskMat);
bool result = poolOp->Forward(inpMat, outMat, maskMat);
if (computeMaxIdx) {
if (use_half) {
UMat maskMat32F;
maskMat.convertTo(maskMat32F, CV_32F);
maskMat32F.convertTo(outputs[1], CV_64S);
}
else
maskMat.convertTo(outputs[1], CV_64S);
}
return result;
}
#endif
@ -353,8 +367,12 @@ public:
case MAX:
{
CV_Assert_N(inputs.size() == 1, !computeMaxIdx || outputs.size() == 2);
Mat mask = computeMaxIdx ? outputs[1] : Mat();
Mat mask;
if (computeMaxIdx)
mask.create(shape(outputs[1]), CV_32F);
maxPooling(inputs[0], outputs[0], mask);
if (computeMaxIdx)
mask.convertTo(outputs[1], CV_64S);
break;
}
case AVE: case SUM:
@ -413,7 +431,16 @@ public:
config.input_shape.assign(std::begin(input_shape), std::end(input_shape));
return make_cuda_node<cuda4dnn::MaxPoolingOp>(preferableTarget, std::move(context->stream), config);
int indicesType = outputs[1]->getHostMatDepth();
CV_CheckType(indicesType, indicesType == CV_32S || indicesType == CV_64S, "Unsupported indices type");
if (indicesType == CV_32S)
return make_cuda_node_with_indices<cuda4dnn::MaxPoolingOp, int32_t>(preferableTarget, inputs[0]->getHostMatDepth(), std::move(context->stream), config);
else if (indicesType == CV_64S)
return make_cuda_node_with_indices<cuda4dnn::MaxPoolingOp, int64_t>(preferableTarget, inputs[0]->getHostMatDepth(), std::move(context->stream), config);
CV_Error(Error::BadDepth, "Unsupported indices type");
return Ptr<BackendNode>();
}
if (input_shape.size() == 3)
@ -1251,6 +1278,26 @@ public:
return false;
}
virtual void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_Assert(inputs.size());
if (preferableTarget == DNN_TARGET_CUDA_FP16 || preferableTarget == DNN_TARGET_CUDA)
CV_CheckTypeEQ(inputs[0], CV_32F, "Unsupported type");
else if (preferableTarget == DNN_TARGET_OPENCL_FP16)
CV_CheckType(inputs[0], inputs[0] == CV_16F || inputs[0] == CV_8S, "");
else
CV_CheckType(inputs[0], inputs[0] == CV_32F || inputs[0] == CV_8S, "");
outputs.push_back(inputs[0]);
if (type == MAX && requiredOutputs == 2) {
outputs.push_back(CV_64S);
}
}
bool updateMemoryShapes(const std::vector<MatShape> &inputs) CV_OVERRIDE
{
int dims = inputs[0].size();

View File

@ -101,6 +101,24 @@ public:
return false;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_Assert(inputs.size());
for (auto input : inputs)
if (preferableTarget == DNN_TARGET_CUDA_FP16 || preferableTarget == DNN_TARGET_CUDA)
CV_CheckTypeEQ(input, CV_32F, "Unsupported type for CUDA");
else if (preferableTarget == DNN_TARGET_OPENCL_FP16)
CV_CheckType(input, input == CV_16F || input == CV_8S || input == CV_32S || input == CV_64S, "");
else
CV_CheckType(input, input == CV_32F || input == CV_8S || input == CV_32S || input == CV_64S, "");
outputs.assign(requiredOutputs, inputs[0]);
}
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
{
std::vector<Mat> inputs, outputs;
@ -181,7 +199,7 @@ public:
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) && inputs_arr.depth() != CV_32S && inputs_arr.depth() != CV_64S,
forward_ocl(inputs_arr, outputs_arr, internals_arr))
if (inputs_arr.depth() == CV_16F)

View File

@ -259,6 +259,25 @@ public:
return true;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_Assert(inputs.size());
for (auto input : inputs)
if (preferableTarget == DNN_TARGET_CUDA_FP16 || preferableTarget == DNN_TARGET_CUDA)
CV_CheckTypeEQ(input, CV_32F, "Unsupported type");
else if (preferableTarget == DNN_TARGET_OPENCL_FP16)
CV_CheckType(input, input == CV_16F || input == CV_8S || input == CV_32S || input == CV_64S, "");
else
CV_CheckType(input, input == CV_32F || input == CV_8S || input == CV_32S || input == CV_64S, "");
outputs.assign(requiredOutputs, inputs[0]);
}
bool updateMemoryShapes(const std::vector<MatShape> &inputs) CV_OVERRIDE
{
if (hasDynamicShapes)
@ -312,7 +331,7 @@ public:
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) && inputs_arr.depth() != CV_32S && inputs_arr.depth() != CV_64S,
forward_ocl(inputs_arr, outputs_arr, internals_arr))
std::vector<Mat> inputs, outputs;

View File

@ -69,6 +69,19 @@ public:
return false;
}
virtual void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_CheckEQ(inputs.size(), (size_t)3, "");
CV_CheckType(inputs[0], inputs[0] == CV_32F || inputs[0] == CV_32S || inputs[0] == CV_16F || inputs[0] == CV_8U, "");
CV_CheckType(inputs[1], inputs[1] == CV_64S || inputs[1] == CV_32S, "");
CV_CheckTypeEQ(inputs[2], inputs[0], "");
outputs.assign(1, inputs[0]);
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
@ -88,12 +101,12 @@ public:
const Mat& updates = inputs[2];
Mat& out = outputs[0];
typeDispatch(outputs[0].type(), data, indices, updates, out);
indexTypeDispatch(outputs[0].type(), indices.type(), data, indices, updates, out);
}
// NOTE: This impl does not check whether indices have duplicate entries.
// The last duplicate entry will overwrite the previous.
template<typename T, typename Functor>
template<typename T, typename T_INDEX, typename Functor>
void forward_impl(const Functor &reduce_operation, const Mat &input_mat, const Mat &indices_mat, const Mat &updates_mat, Mat& output_mat) {
input_mat.copyTo(output_mat);
@ -120,14 +133,14 @@ public:
indices_offset = r.start * indices_last_dim,
updates_offset = r.start * updates_size;
for (int i = r.start; i < r.end; i++) {
const T* indices = indices_mat.ptr<const T>();
const T_INDEX* indices = indices_mat.ptr<const T_INDEX>();
const T* updates = updates_mat.ptr<const T>();
T* output = output_mat.ptr<T>();
input_offset = 0;
indices += indices_offset;
for (int j = 0; j < indices_last_dim; j++) {
int index = static_cast<int>(*(indices + j));
int index = *(indices + j);
index = (index + input_mat_shape[j]) % input_mat_shape[j];
CV_Assert(index < input_mat_shape[j] && index >= 0);
input_offset += index * input_mat_step[j];
@ -150,25 +163,42 @@ public:
}
template<typename... Args>
inline void indexTypeDispatch(const int type, const int index_type, Args&&... args)
{
switch (index_type)
{
case CV_32S:
typeDispatch<int32_t>(type, std::forward<Args>(args)...);
break;
case CV_64S:
typeDispatch<int64_t>(type, std::forward<Args>(args)...);
break;
default:
CV_Error(cv::Error::BadDepth, "Unsupported type.");
};
}
template<typename T_INDEX, typename... Args>
inline void typeDispatch(const int type, Args&&... args)
{
switch (type)
{
case CV_8U:
reductionDispatch<uint8_t>(std::forward<Args>(args)...);
reductionDispatch<uint8_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32S:
reductionDispatch<int32_t>(std::forward<Args>(args)...);
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F:
reductionDispatch<float>(std::forward<Args>(args)...);
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
break;
default:
CV_Error(cv::Error::BadDepth, "Unsupported type.");
};
}
template<typename T, typename... Args>
template<typename T, typename T_INDEX, typename... Args>
inline void reductionDispatch(Args&&... args)
{
switch (reduction)
@ -176,31 +206,31 @@ public:
case REDUCTION::NONE:
{
auto rd = [](const T& a, const T& b) { return b; }; // a from input data, b from updates
forward_impl<T>(rd, std::forward<Args>(args)...);
forward_impl<T, T_INDEX>(rd, std::forward<Args>(args)...);
break;
}
case REDUCTION::ADD:
{
auto rd = [](const T& a, const T& b) { return a + b; };
forward_impl<T>(rd, std::forward<Args>(args)...);
forward_impl<T, T_INDEX>(rd, std::forward<Args>(args)...);
break;
}
case REDUCTION::MUL:
{
auto rd = [](const T& a, const T& b) { return a * b; };
forward_impl<T>(rd, std::forward<Args>(args)...);
forward_impl<T, T_INDEX>(rd, std::forward<Args>(args)...);
break;
}
case REDUCTION::MAX:
{
auto rd = [](const T& a, const T& b) { return std::max(a, b); };
forward_impl<T>(rd, std::forward<Args>(args)...);
forward_impl<T, T_INDEX>(rd, std::forward<Args>(args)...);
break;
}
case REDUCTION::MIN:
{
auto rd = [](const T& a, const T& b) { return std::min(a, b); };
forward_impl<T>(rd, std::forward<Args>(args)...);
forward_impl<T, T_INDEX>(rd, std::forward<Args>(args)...);
break;
}
default:

View File

@ -63,6 +63,19 @@ public:
return false;
}
virtual void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_CheckEQ(inputs.size(), (size_t)3, "");
CV_CheckType(inputs[0], inputs[0] == CV_32F || inputs[0] == CV_32S || inputs[0] == CV_16F || inputs[0] == CV_8U, "");
CV_CheckType(inputs[1], inputs[1] == CV_64S || inputs[1] == CV_32S, "");
CV_CheckTypeEQ(inputs[2], inputs[0], "");
outputs.assign(1, inputs[0]);
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
@ -82,10 +95,10 @@ public:
const Mat& updates = inputs[2];
Mat& out = outputs[0];
typeDispatch(outputs[0].type(), data, indices, updates, out);
indexTypeDispatch(outputs[0].type(), indices.type(), data, indices, updates, out);
}
template<typename T, typename Functor>
template<typename T, typename T_INDEX, typename Functor>
void forward_impl(const Functor &reduce_operation, const Mat &input_mat, const Mat &indices_mat, const Mat &updates_mat, Mat &output_mat) {
input_mat.copyTo(output_mat);
@ -99,7 +112,7 @@ public:
for (int i = 0; i < ndims; i++) {
input_mat_step[i] = static_cast<size_t>(input_mat.step.p[i] / sizeof(T));
indices_mat_step[i] = static_cast<size_t>(indices_mat.step.p[i] / sizeof(T));
indices_mat_step[i] = static_cast<size_t>(indices_mat.step.p[i] / sizeof(T_INDEX));
}
auto fn = [&](const Range &r) {
@ -108,7 +121,7 @@ public:
int indices_index, index;
size_t axis_offset, tmp_index, j_index;
for (int i = r.start; i < r.end; i++) {
const T* indices = indices_mat.ptr<const T>();
const T_INDEX* indices = indices_mat.ptr<const T_INDEX>();
const T* updates = updates_mat.ptr<const T>();
T* output = output_mat.ptr<T>();
@ -128,7 +141,7 @@ public:
}
// get index and overwrite current indices
index = static_cast<int>(*(indices + indices_offset));
index = *(indices + indices_offset);
index = (index + input_mat_shape[axis]) % input_mat_shape[axis];
CV_Assert(index < input_mat_shape[axis] && index >= 0);
input_offset = input_offset - axis_offset + index * input_mat_step[axis];
@ -145,25 +158,42 @@ public:
}
template<typename... Args>
inline void indexTypeDispatch(const int type, const int index_type, Args&&... args)
{
switch (index_type)
{
case CV_32S:
typeDispatch<int32_t>(type, std::forward<Args>(args)...);
break;
case CV_64S:
typeDispatch<int64_t>(type, std::forward<Args>(args)...);
break;
default:
CV_Error(cv::Error::BadDepth, "Unsupported type.");
};
}
template<typename T_INDEX, typename... Args>
inline void typeDispatch(const int type, Args&&... args)
{
switch (type)
{
case CV_8U:
reductionDispatch<uint8_t>(std::forward<Args>(args)...);
reductionDispatch<uint8_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32S:
reductionDispatch<int32_t>(std::forward<Args>(args)...);
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F:
reductionDispatch<float>(std::forward<Args>(args)...);
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
break;
default:
CV_Error(cv::Error::BadDepth, "Unsupported type.");
};
}
template<typename T, typename... Args>
template<typename T, typename T_INDEX, typename... Args>
inline void reductionDispatch(Args&&... args)
{
switch (reduction)
@ -171,31 +201,31 @@ public:
case REDUCTION::NONE:
{
auto rd = [](const T& a, const T& b) { return b; }; // a from input data, b from updates
forward_impl<T>(rd, std::forward<Args>(args)...);
forward_impl<T, T_INDEX>(rd, std::forward<Args>(args)...);
break;
}
case REDUCTION::ADD:
{
auto rd = [](const T& a, const T& b) { return a + b; };
forward_impl<T>(rd, std::forward<Args>(args)...);
forward_impl<T, T_INDEX>(rd, std::forward<Args>(args)...);
break;
}
case REDUCTION::MUL:
{
auto rd = [](const T& a, const T& b) { return a * b; };
forward_impl<T>(rd, std::forward<Args>(args)...);
forward_impl<T, T_INDEX>(rd, std::forward<Args>(args)...);
break;
}
case REDUCTION::MAX:
{
auto rd = [](const T& a, const T& b) { return std::max(a, b); };
forward_impl<T>(rd, std::forward<Args>(args)...);
forward_impl<T, T_INDEX>(rd, std::forward<Args>(args)...);
break;
}
case REDUCTION::MIN:
{
auto rd = [](const T& a, const T& b) { return std::min(a, b); };
forward_impl<T>(rd, std::forward<Args>(args)...);
forward_impl<T, T_INDEX>(rd, std::forward<Args>(args)...);
break;
}
default:

View File

@ -278,6 +278,25 @@ public:
return false;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_CheckEQ(inputs.size(), (size_t)1, "");
for (auto input : inputs)
if (preferableTarget == DNN_TARGET_CUDA_FP16 || preferableTarget == DNN_TARGET_CUDA)
CV_CheckEQ(input, CV_32F, "Unsupported type");
else if (preferableTarget == DNN_TARGET_OPENCL_FP16)
CV_CheckType(input, input == CV_16F || input == CV_8S || input == CV_32S || input == CV_64S, "");
else
CV_CheckType(input, input == CV_32F || input == CV_8S || input == CV_32S || input == CV_64S, "");
outputs.assign(requiredOutputs, inputs[0]);
}
bool updateMemoryShapes(const std::vector<MatShape> &inputs) CV_OVERRIDE
{
shapesInitialized = true;
@ -596,13 +615,14 @@ public:
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
forward_ocl(inputs_arr, outputs_arr, internals_arr))
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
CV_OCL_RUN((IS_DNN_OPENCL_TARGET(preferableTarget) &&
(outputs[0].type() != CV_32S && outputs[0].type() != CV_64S)),
forward_ocl(inputs_arr, outputs_arr, internals_arr))
const Mat& inpMat = inputs[0];
CV_Assert(outputs.size() == finalSliceRanges.size());
@ -621,7 +641,11 @@ public:
{
std::vector<int> inpIdx(dimsNum, 0);
std::vector<int> outIdx(dimsNum, 0);
if (inpMat.type() == CV_16F)
if (inpMat.type() == CV_32S)
getSliceRecursive<int32_t>(inpMat, inpIdx, finalSliceRanges[i], sliceSteps[i], 0, dimsNum, outputs[i], outIdx);
else if (inpMat.type() == CV_64S)
getSliceRecursive<int64_t>(inpMat, inpIdx, finalSliceRanges[i], sliceSteps[i], 0, dimsNum, outputs[i], outIdx);
else if (inpMat.type() == CV_16F)
getSliceRecursive<int16_t>(inpMat, inpIdx, finalSliceRanges[i], sliceSteps[i], 0, dimsNum, outputs[i], outIdx);
else if (inpMat.type() == CV_8S)
getSliceRecursive<int8_t>(inpMat, inpIdx, finalSliceRanges[i], sliceSteps[i], 0, dimsNum, outputs[i], outIdx);
@ -876,6 +900,25 @@ public:
return false;
}
void getTypes(const std::vector<MatType>& inputs,
const int requiredOutputs,
const int requiredInternals,
std::vector<MatType>& outputs,
std::vector<MatType>& internals) const CV_OVERRIDE
{
CV_CheckEQ(inputs.size(), (size_t)2, "");
for (auto input : inputs)
if (preferableTarget == DNN_TARGET_CUDA_FP16 || preferableTarget == DNN_TARGET_CUDA)
CV_CheckTypeEQ(input, CV_32F, "Unsupported type");
else if (preferableTarget == DNN_TARGET_OPENCL_FP16)
CV_CheckType(input, input == CV_16F || input == CV_8S || input == CV_32S || input == CV_64S, "");
else
CV_CheckType(input, input == CV_32F || input == CV_8S || input == CV_32S || input == CV_64S, "");
outputs.assign(requiredOutputs, inputs[0]);
}
void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE
{
std::vector<Mat> inputs;

View File

@ -90,14 +90,21 @@ Ptr<BackendWrapper> wrapMat(int backendId, int targetId, cv::Mat& m)
CV_Assert(haveCUDA());
#ifdef HAVE_CUDA
switch (targetId)
CV_CheckType(m.depth(), m.depth() == CV_32F || m.depth() == CV_32S || m.depth() == CV_64S, "Unsupported type for CUDA");
CV_Assert(IS_DNN_CUDA_TARGET(targetId));
switch (m.depth())
{
case DNN_TARGET_CUDA:
return CUDABackendWrapperFP32::create(m);
case DNN_TARGET_CUDA_FP16:
return CUDABackendWrapperFP16::create(m);
case CV_32F:
if (targetId == DNN_TARGET_CUDA_FP16)
return CUDABackendWrapperFP16::create(m);
else
return CUDABackendWrapperFP32::create(m);
case CV_32S:
return CUDABackendWrapperINT32::create(m);
case CV_64S:
return CUDABackendWrapperINT64::create(m);
default:
CV_Assert(IS_DNN_CUDA_TARGET(targetId));
CV_Error(Error::BadDepth, "Unsupported mat type for CUDA");
}
#endif
}

View File

@ -237,6 +237,10 @@ public:
const ShapesVec &outShapes = layerShapes.out,
internalShapes = layerShapes.internal;
const TypesVec &outTypes = layerShapes.outTypes,
&internalTypes = layerShapes.internalTypes;
CV_CheckEQ(outShapes.size(), outTypes.size(), "Numbers shapes and types shoud be equal");
CV_CheckEQ(internalShapes.size(), internalTypes.size(), "Numbers shapes and types shoud be equal");
outputBlobs.resize(std::max((size_t)1, outShapes.size())); // layer produce at least one output blob
internalBlobs.resize(internalShapes.size());
@ -257,7 +261,9 @@ public:
}
ShapesVec shapes(outShapes);
TypesVec types(outTypes);
shapes.insert(shapes.end(), internalShapes.begin(), internalShapes.end());
types.insert(types.end(), internalTypes.begin(), internalTypes.end());
std::vector<Mat*> blobs;
for (int i = 0; i < outputBlobs.size(); i++)
{
@ -292,12 +298,13 @@ public:
LayerPin blobPin(ld.id, index);
if (index < outShapes.size() && inPlace)
{
CV_Assert(ld.inputBlobs[0]->total() == total(shapes[index]));
CV_CheckEQ((int)ld.inputBlobs[0]->total(), total(shapes[index]), "");
CV_CheckTypeEQ(ld.inputBlobs[0]->type(), types[index], "blob can't be reused if it has different type");
ld.outputBlobs[index] = ld.inputBlobs[0]->reshape(1, shapes[index]);
reuse(ld.inputBlobsId[0], blobPin);
}
else
reuseOrCreate(shapes[index], blobPin, *blobs[index], ld.dtype);
reuseOrCreate(shapes[index], blobPin, *blobs[index], types[index]);
}
}
}

View File

@ -48,7 +48,7 @@ public:
outNames = net.getUnconnectedOutLayersNames();
std::vector<MatShape> inLayerShapes;
std::vector<MatShape> outLayerShapes;
net.getLayerShapes(MatShape(), 0, inLayerShapes, outLayerShapes);
net.getLayerShapes(MatShape(), CV_32F, 0, inLayerShapes, outLayerShapes);
if (!inLayerShapes.empty() && inLayerShapes[0].size() == 4)
size = Size(inLayerShapes[0][3], inLayerShapes[0][2]);
else

View File

@ -234,68 +234,77 @@ std::vector<String> Net::getUnconnectedOutLayersNames() const
}
void Net::getLayersShapes(const ShapesVec& netInputShapes,
const TypesVec& netInputTypes,
std::vector<int>& layersIds,
std::vector<ShapesVec>& inLayersShapes,
std::vector<ShapesVec>& outLayersShapes) const
{
CV_Assert(impl);
return impl->getLayersShapes(netInputShapes, layersIds, inLayersShapes, outLayersShapes);
return impl->getLayersShapes(netInputShapes, netInputTypes, layersIds, inLayersShapes, outLayersShapes);
}
void Net::getLayersShapes(const MatShape& netInputShape,
const MatType& netInputType,
std::vector<int>& layerIds,
std::vector<ShapesVec>& inLayersShapes,
std::vector<ShapesVec>& outLayersShapes) const
{
getLayersShapes(ShapesVec(1, netInputShape),
TypesVec(1, netInputType),
layerIds, inLayersShapes, outLayersShapes);
}
void Net::getLayerShapes(const MatShape& netInputShape,
const MatType& netInputType,
const int layerId,
ShapesVec& inLayerShapes,
ShapesVec& outLayerShapes) const
{
getLayerShapes(ShapesVec(1, netInputShape),
getLayerShapes(ShapesVec(1, netInputShape), TypesVec(1, netInputType),
layerId, inLayerShapes, outLayerShapes);
}
void Net::getLayerShapes(const ShapesVec& netInputShapes,
const TypesVec& netInputTypes,
const int layerId,
ShapesVec& inLayerShapes,
ShapesVec& outLayerShapes) const
{
CV_Assert(impl);
LayerShapes shapes;
impl->getLayerShapes(netInputShapes, layerId, shapes);
impl->getLayerShapes(netInputShapes, netInputTypes, layerId, shapes);
inLayerShapes = shapes.in;
outLayerShapes = shapes.out;
}
int64 Net::getFLOPS(const std::vector<MatShape>& netInputShapes) const
int64 Net::getFLOPS(const std::vector<MatShape>& netInputShapes, const std::vector<MatType>& netInputTypes) const
{
CV_TRACE_FUNCTION();
CV_Assert(impl);
return impl->getFLOPS(netInputShapes);
return impl->getFLOPS(netInputShapes, netInputTypes);
}
int64 Net::getFLOPS(const MatShape& netInputShape) const
int64 Net::getFLOPS(const MatShape& netInputShape, const MatType& netInputType) const
{
return getFLOPS(std::vector<MatShape>(1, netInputShape));
return getFLOPS(std::vector<MatShape>(1, netInputShape),
std::vector<MatType>(1, netInputType));
}
int64 Net::getFLOPS(const int layerId,
const std::vector<MatShape>& netInputShapes) const
const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes) const
{
CV_TRACE_FUNCTION();
CV_Assert(impl);
return impl->getFLOPS(layerId, netInputShapes);
return impl->getFLOPS(layerId, netInputShapes, netInputTypes);
}
int64 Net::getFLOPS(const int layerId,
const MatShape& netInputShape) const
const MatShape& netInputShape,
const MatType& netInputType) const
{
return getFLOPS(layerId, std::vector<MatShape>(1, netInputShape));
return getFLOPS(layerId, std::vector<MatShape>(1, netInputShape),
std::vector<MatType>(1, netInputType));
}
void Net::getLayerTypes(std::vector<String>& layersTypes) const
@ -314,50 +323,59 @@ int Net::getLayersCount(const String& layerType) const
void Net::getMemoryConsumption(const int layerId,
const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes,
size_t& weights, size_t& blobs) const
{
CV_TRACE_FUNCTION();
CV_Assert(impl);
return impl->getMemoryConsumption(layerId, netInputShapes, weights, blobs);
return impl->getMemoryConsumption(layerId, netInputShapes, netInputTypes, weights, blobs);
}
void Net::getMemoryConsumption(const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes,
size_t& weights, size_t& blobs) const
{
CV_TRACE_FUNCTION();
CV_Assert(impl);
return impl->getMemoryConsumption(netInputShapes, weights, blobs);
return impl->getMemoryConsumption(netInputShapes, netInputTypes, weights, blobs);
}
void Net::getMemoryConsumption(const int layerId,
const MatShape& netInputShape,
const MatType& netInputType,
size_t& weights, size_t& blobs) const
{
getMemoryConsumption(layerId, std::vector<MatShape>(1, netInputShape),
weights, blobs);
std::vector<MatType>(1, netInputType),
weights, blobs);
}
void Net::getMemoryConsumption(const MatShape& netInputShape,
const MatType& netInputType,
size_t& weights, size_t& blobs) const
{
getMemoryConsumption(std::vector<MatShape>(1, netInputShape),
std::vector<MatType>(1, netInputType),
weights, blobs);
}
void Net::getMemoryConsumption(const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes,
std::vector<int>& layerIds, std::vector<size_t>& weights,
std::vector<size_t>& blobs) const
{
CV_TRACE_FUNCTION();
CV_Assert(impl);
return impl->getMemoryConsumption(netInputShapes, layerIds, weights, blobs);
return impl->getMemoryConsumption(netInputShapes, netInputTypes, layerIds, weights, blobs);
}
void Net::getMemoryConsumption(const MatShape& netInputShape, std::vector<int>& layerIds,
void Net::getMemoryConsumption(const MatShape& netInputShape, const MatType& netInputType,
std::vector<int>& layerIds,
std::vector<size_t>& weights, std::vector<size_t>& blobs) const
{
getMemoryConsumption(std::vector<MatShape>(1, netInputShape), layerIds,
weights, blobs);
getMemoryConsumption(std::vector<MatShape>(1, netInputShape),
std::vector<MatType>(1, netInputType),
layerIds, weights, blobs);
}
// FIXIT return old value or add get method

View File

@ -186,11 +186,6 @@ void Net::Impl::setUpNet(const std::vector<LayerPin>& blobsToKeep_)
clear();
if (hasDynamicShapes)
{
updateLayersShapes();
}
this->blobsToKeep = blobsToKeep_;
allocateLayers(blobsToKeep_);
@ -475,7 +470,7 @@ void Net::Impl::allocateLayer(int lid, const LayersShapesMap& layersShapes)
allocateLayer(*i, layersShapes);
// bind inputs
if (ld.id == 0) // DataLayer
if (ld.id == 0 && netInputLayer->supportBackend(preferableBackend)) // DataLayer
{
ninputs = netInputLayer->inputsData.size();
ld.inputBlobsWrappers.resize(ninputs);
@ -500,7 +495,8 @@ void Net::Impl::allocateLayer(int lid, const LayersShapesMap& layersShapes)
CV_Assert(layerShapesIt != layersShapes.end());
if (preferableBackend == DNN_BACKEND_OPENCV && preferableTarget == DNN_TARGET_OPENCL_FP16 && ld.dtype == CV_32F)
if (preferableBackend == DNN_BACKEND_OPENCV && ld.dtype == CV_32F
&& preferableTarget == DNN_TARGET_OPENCL_FP16)
ld.dtype = CV_16F;
std::vector<LayerPin> pinsForInternalBlobs;
@ -522,7 +518,6 @@ void Net::Impl::allocateLayer(int lid, const LayersShapesMap& layersShapes)
inps[i] = *ld.inputBlobs[i];
}
layerPtr->finalize(inps, ld.outputBlobs);
layerPtr->preferableTarget = preferableTarget;
#if 0
std::cout << "\toutputs:";
size_t noutputs = ld.outputBlobs.size();
@ -551,20 +546,39 @@ void Net::Impl::allocateLayers(const std::vector<LayerPin>& blobsToKeep_)
CV_Assert(!layers[0].outputBlobs.empty());
ShapesVec inputShapes;
TypesVec inputTypes;
for (int i = 0; i < layers[0].outputBlobs.size(); i++)
{
Mat& inp = layers[0].outputBlobs[i];
CV_Assert(inp.total());
if (preferableBackend == DNN_BACKEND_OPENCV &&
preferableTarget == DNN_TARGET_OPENCL_FP16 &&
layers[0].dtype == CV_32F)
int type = inp.type();
if (type != CV_32S && type != CV_64S)
{
layers[0].outputBlobs[i].create(inp.dims, inp.size, CV_16F);
type = CV_32F;
if (preferableBackend == DNN_BACKEND_OPENCV &&
preferableTarget == DNN_TARGET_OPENCL_FP16)
{
type = CV_16F;
if (layers[0].dtype == CV_32F)
layers[0].outputBlobs[i].create(inp.dims, inp.size, CV_16F);
}
if (netWasQuantized && inp.type() == CV_8S) {
type = CV_8S;
}
}
inputShapes.push_back(shape(inp));
inputTypes.push_back(type);
}
for (auto& layer : layers)
{
auto& ld = layer.second;
Ptr<Layer> layerPtr = getLayerInstance(ld);
layerPtr->preferableTarget = preferableTarget;
}
LayersShapesMap layersShapes;
getLayersShapes(inputShapes, layersShapes);
getLayersShapes(inputShapes, inputTypes, layersShapes);
blobManager.reset();
backendWrappers.clear();
@ -969,7 +983,12 @@ void Net::Impl::forward(OutputArrayOfArrays outputBlobs, const String& outputNam
std::vector<Mat>& outputvec = *(std::vector<Mat>*)outputBlobs.getObj();
outputvec.resize(ld.outputBlobs.size());
for (int i = 0; i < outputvec.size(); i++)
ld.outputBlobs[i].convertTo(outputvec[i], CV_32F);
{
if (ld.outputBlobs[i].depth() == CV_32S || ld.outputBlobs[i].depth() == CV_64S)
outputvec[i] = ld.outputBlobs[i];
else
ld.outputBlobs[i].convertTo(outputvec[i], CV_32F);
}
}
else
{
@ -1079,13 +1098,16 @@ void Net::Impl::getLayerShapesRecursively(int id, LayersShapesMap& inOutShapes)
if (!layerData.outputBlobs.empty())
{
ShapesVec shapes;
TypesVec types;
for (int i = 0; i < layerData.outputBlobs.size(); i++)
{
Mat& inp = layerData.outputBlobs[i];
CV_Assert(!inp.empty());
shapes.push_back(shape(inp));
types.push_back(inp.type());
}
layerShapes.in = shapes;
layerShapes.inTypes = types;
}
else
{
@ -1102,11 +1124,13 @@ void Net::Impl::getLayerShapesRecursively(int id, LayersShapesMap& inOutShapes)
if (none)
{
layerShapes.out.clear();
layerShapes.outTypes.clear();
return;
}
else
{
layerShapes.in = inputShapes;
layerShapes.inTypes.assign(inputShapes.size(), layerData.dtype);
}
}
}
@ -1126,7 +1150,9 @@ void Net::Impl::getLayerShapesRecursively(int id, LayersShapesMap& inOutShapes)
const int out_port = inputLayerIds[i].oid;
CV_CheckLT(out_port, (int)it->second.out.size(), "");
const MatShape& shape = it->second.out[out_port];
const MatType& type = it->second.outTypes[out_port];
layerShapes.in.push_back(shape);
layerShapes.inTypes.push_back(type);
}
}
const ShapesVec& is = layerShapes.in;
@ -1138,7 +1164,11 @@ void Net::Impl::getLayerShapesRecursively(int id, LayersShapesMap& inOutShapes)
bool layerSupportInPlace = false;
try
{
l->updateMemoryShapes(layerShapes.in);
layerSupportInPlace = l->getMemoryShapes(is, requiredOutputs, os, ints);
l->getTypes(layerShapes.inTypes, os.size(), ints.size(), layerShapes.outTypes, layerShapes.internalTypes);
CV_CheckEQ(layerShapes.out.size(), layerShapes.outTypes.size(), "Number of shapes and types should be equal");
CV_CheckEQ(layerShapes.internal.size(), layerShapes.internalTypes.size(), "Number of shapes and types should be equal");
}
catch (const cv::Exception& e)
{
@ -1197,6 +1227,7 @@ void Net::Impl::getLayerShapesRecursively(int id, LayersShapesMap& inOutShapes)
void Net::Impl::getLayersShapes(
const ShapesVec& netInputShapes,
const TypesVec& netInputTypes,
std::vector<int>& layersIds,
std::vector<ShapesVec>& inLayersShapes,
std::vector<ShapesVec>& outLayersShapes) /*const*/
@ -1206,7 +1237,7 @@ void Net::Impl::getLayersShapes(
outLayersShapes.clear();
Impl::LayersShapesMap inOutShapes;
getLayersShapes(netInputShapes, inOutShapes);
getLayersShapes(netInputShapes, netInputTypes, inOutShapes);
for (Impl::LayersShapesMap::const_iterator it = inOutShapes.begin();
it != inOutShapes.end(); it++)
@ -1219,11 +1250,13 @@ void Net::Impl::getLayersShapes(
void Net::Impl::getLayersShapes(const ShapesVec& netInputShapes,
const TypesVec& netInputTypes,
LayersShapesMap& inOutShapes)
{
inOutShapes.clear();
inOutShapes[0].in = netInputShapes; // insert shape for first input layer
inOutShapes[0].inTypes = netInputTypes;
for (MapIdToLayerData::const_iterator it = layers.begin();
it != layers.end(); it++)
{
@ -1232,11 +1265,13 @@ void Net::Impl::getLayersShapes(const ShapesVec& netInputShapes,
}
void Net::Impl::getLayerShapes(const ShapesVec& netInputShapes,
const TypesVec& netInputTypes,
const int layerId,
LayerShapes& shapes)
{
LayersShapesMap inOutShapes;
inOutShapes[0].in = netInputShapes; // insert shape for first input layer
inOutShapes[0].inTypes = netInputTypes;
getLayerShapesRecursively(layerId, inOutShapes);
shapes = inOutShapes[layerId];
}
@ -1250,6 +1285,7 @@ void Net::Impl::updateLayersShapes()
CV_Assert(inputLayerData.layerInstance.get() == &inputLayer);
CV_Assert(!inputLayerData.outputBlobs.empty());
ShapesVec inputShapes;
TypesVec inputTypes;
for (int i = 0; i < inputLayerData.outputBlobs.size(); i++)
{
Mat& inp = inputLayerData.outputBlobs[i];
@ -1261,10 +1297,12 @@ void Net::Impl::updateLayersShapes()
inp.create(inp.dims, inp.size, CV_16F);
}
inputShapes.push_back(shape(inp));
inputTypes.push_back(inp.type());
}
CV_LOG_DEBUG(NULL, toString(inputShapes, "Network input shapes"));
LayersShapesMap layersShapes;
layersShapes[0].in = inputShapes;
layersShapes[0].inTypes = inputTypes;
for (MapIdToLayerData::iterator it = layers.begin(); it != layers.end(); it++)
{
int layerId = it->first;
@ -1285,7 +1323,9 @@ void Net::Impl::updateLayersShapes()
getLayerShapesRecursively(inputLayerId, layersShapes);
}
const MatShape& shape = layersShapes[inputLayerId].out[inputPin.oid];
const MatType& type = layersShapes[inputLayerId].outTypes[inputPin.oid];
layerShapes.in.push_back(shape);
layerShapes.inTypes.push_back(type);
}
getLayerInstance(layerData)->updateMemoryShapes(layerShapes.in);
}
@ -1910,12 +1950,13 @@ std::vector<String> Net::Impl::getUnconnectedOutLayersNames() /*const*/
}
int64 Net::Impl::getFLOPS(const std::vector<MatShape>& netInputShapes) /*const*/
int64 Net::Impl::getFLOPS(const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes) /*const*/
{
int64 flops = 0;
std::vector<int> ids;
std::vector<std::vector<MatShape>> inShapes, outShapes;
getLayersShapes(netInputShapes, ids, inShapes, outShapes);
getLayersShapes(netInputShapes, netInputTypes, ids, inShapes, outShapes);
CV_Assert(inShapes.size() == outShapes.size());
CV_Assert(inShapes.size() == ids.size());
@ -1930,13 +1971,14 @@ int64 Net::Impl::getFLOPS(const std::vector<MatShape>& netInputShapes) /*const*/
int64 Net::Impl::getFLOPS(
const int layerId,
const std::vector<MatShape>& netInputShapes) /*const*/
const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes) /*const*/
{
Impl::MapIdToLayerData::const_iterator layer = layers.find(layerId);
CV_Assert(layer != layers.end());
LayerShapes shapes;
getLayerShapes(netInputShapes, layerId, shapes);
getLayerShapes(netInputShapes, netInputTypes, layerId, shapes);
return getLayerInstance(const_cast<LayerData&>(layer->second))->getFLOPS(shapes.in, shapes.out);
}
@ -1945,6 +1987,7 @@ int64 Net::Impl::getFLOPS(
void Net::Impl::getMemoryConsumption(
const int layerId,
const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes,
size_t& weights, size_t& blobs) /*const*/
{
Impl::MapIdToLayerData::const_iterator layer = layers.find(layerId);
@ -1959,25 +2002,22 @@ void Net::Impl::getMemoryConsumption(
}
LayerShapes shapes;
getLayerShapes(netInputShapes, layerId, shapes);
getLayerShapes(netInputShapes, netInputTypes, layerId, shapes);
const ShapesVec& outLayerShapes = shapes.out;
// FIXIT netWasQuantized check is not enough - per layer check should be done
size_t elemSize = netWasQuantized ? sizeof(char) : sizeof(float);
for (int i = 0; i < outLayerShapes.size(); i++)
{
blobs += total(outLayerShapes[i]) * elemSize;
}
blobs += total(outLayerShapes[i]) * CV_ELEM_SIZE(shapes.outTypes[i]);
}
void Net::Impl::getMemoryConsumption(
const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes,
size_t& weights, size_t& blobs) /*const*/
{
std::vector<int> layerIds;
std::vector<size_t> w, b;
getMemoryConsumption(netInputShapes, layerIds, w, b);
getMemoryConsumption(netInputShapes, netInputTypes, layerIds, w, b);
weights = blobs = 0;
for (int i = 0; i < layerIds.size(); i++)
@ -1997,6 +2037,7 @@ int64 Net::Impl::getPerfProfile(std::vector<double>& timings) const
void Net::Impl::getMemoryConsumption(
const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes,
std::vector<int>& layerIds, std::vector<size_t>& weights,
std::vector<size_t>& blobs) /*const*/
{
@ -2006,7 +2047,7 @@ void Net::Impl::getMemoryConsumption(
std::vector<std::vector<MatShape>> inLayerShapes, outLayerShapes;
getLayersShapes(netInputShapes, layerIds, inLayerShapes, outLayerShapes);
getLayersShapes(netInputShapes, netInputTypes, layerIds, inLayerShapes, outLayerShapes);
// FIXIT netWasQuantized check is not enough - per layer check should be done
size_t elemSize = netWasQuantized ? sizeof(char) : sizeof(float);
for (int i = 0; i < layerIds.size(); i++)

View File

@ -227,33 +227,41 @@ struct Net::Impl : public detail::NetImplBase
void getLayersShapes(
const ShapesVec& netInputShapes,
const TypesVec& netInputTypes,
std::vector<int>& layersIds,
std::vector<ShapesVec>& inLayersShapes,
std::vector<ShapesVec>& outLayersShapes) /*const*/;
void getLayersShapes(const ShapesVec& netInputShapes,
const TypesVec& netInputTypes,
LayersShapesMap& inOutShapes);
void getLayerShapes(const ShapesVec& netInputShapes,
const TypesVec& netInputTypes,
const int layerId,
LayerShapes& shapes);
void updateLayersShapes();
int64 getFLOPS(const std::vector<MatShape>& netInputShapes) /*const*/;
int64 getFLOPS(const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes) /*const*/;
int64 getFLOPS(
const int layerId,
const std::vector<MatShape>& netInputShapes) /*const*/;
const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes) /*const*/;
void getMemoryConsumption(
const int layerId,
const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes,
size_t& weights, size_t& blobs) /*const*/;
void getMemoryConsumption(
const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes,
size_t& weights, size_t& blobs) /*const*/;
void getMemoryConsumption(
const std::vector<MatShape>& netInputShapes,
const std::vector<MatType>& netInputTypes,
std::vector<int>& layerIds, std::vector<size_t>& weights,
std::vector<size_t>& blobs) /*const*/;
int64 getPerfProfile(std::vector<double>& timings) const;

View File

@ -62,14 +62,21 @@ Ptr<BackendWrapper> Net::Impl::wrap(Mat& host)
{
CV_Assert(haveCUDA());
#ifdef HAVE_CUDA
switch (preferableTarget)
CV_CheckType(host.depth(), host.depth() == CV_32F || host.depth() == CV_32S || host.depth() == CV_64S, "Unsupported type for CUDA");
CV_Assert(IS_DNN_CUDA_TARGET(preferableTarget));
switch (host.depth())
{
case DNN_TARGET_CUDA:
return CUDABackendWrapperFP32::create(baseBuffer, shape);
case DNN_TARGET_CUDA_FP16:
return CUDABackendWrapperFP16::create(baseBuffer, shape);
case CV_32F:
if (preferableTarget == DNN_TARGET_CUDA_FP16)
return CUDABackendWrapperFP16::create(baseBuffer, shape);
else
return CUDABackendWrapperFP32::create(baseBuffer, shape);
case CV_32S:
return CUDABackendWrapperINT32::create(baseBuffer, shape);
case CV_64S:
return CUDABackendWrapperINT64::create(baseBuffer, shape);
default:
CV_Assert(IS_DNN_CUDA_TARGET(preferableTarget));
CV_Error(Error::BadDepth, "Unsupported mat type for CUDA");
}
#endif
}

View File

@ -381,27 +381,24 @@ void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
CV_Assert((bool)layer);
std::vector<MatShape> inpShapes(inputs.size());
int ddepth = params.get<int>("depth", CV_32F);
std::vector<MatType> inpTypes(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i)
{
inpShapes[i] = shape(inputs[i]);
if (i > 0 && ddepth != inputs[i].depth())
CV_Error(Error::StsNotImplemented, cv::format("Mixed input data types. Required type: %d, actual type: %d", ddepth, inputs[i].depth()));
// Quantize and Dequantize layer have different output type than input.
if (params.type != "Quantize" && params.type != "Dequantize")
ddepth = inputs[i].depth();
inpTypes[i] = inputs[i].type();
}
std::vector<MatShape> outShapes, internalShapes;
std::vector<MatType> outTypes, internalTypes;
layer->getMemoryShapes(inpShapes, 0, outShapes, internalShapes);
layer->getTypes(inpTypes, outShapes.size(), internalShapes.size(), outTypes, internalTypes);
std::vector<Mat> internals(internalShapes.size());
outputs.resize(outShapes.size());
for (size_t i = 0; i < outShapes.size(); ++i)
outputs[i].create(outShapes[i], ddepth);
outputs[i].create(outShapes[i], outTypes[i]);
for (size_t i = 0; i < internalShapes.size(); ++i)
internals[i].create(internalShapes[i], ddepth);
internals[i].create(internalShapes[i], internalTypes[i]);
layer->finalize(inputs, outputs);
layer->forward(inputs, outputs, internals);
@ -2506,7 +2503,6 @@ void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::Node
inputs.push_back(input);
Mat indices = getBlob(node_proto, 1);
indices.convertTo(indices, CV_32FC1);
inputs.push_back(indices);
runLayer(layerParams, inputs, output);
@ -2525,10 +2521,6 @@ void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::Node
constParams.name = node_proto.input(i);
constParams.type = "Const";
Mat blob = getBlob(node_proto, i);
if (i == 1)
{
blob.convertTo(blob, CV_32FC1);
}
constParams.blobs.push_back(blob);
opencv_onnx::NodeProto proto;
@ -3037,8 +3029,6 @@ void ONNXImporter::parseScatter(LayerParams& layerParams, const opencv_onnx::Nod
for (size_t i = 0; i < node_proto.input_size(); i++)
{
Mat blob = getBlob(node_proto, i);
if (i == 1) // indices
blob.convertTo(blob, CV_32F);
inputs.push_back(blob);
}
runLayer(layerParams, inputs, output);

View File

@ -51,7 +51,8 @@ void Net::Impl::initCUDABackend(const std::vector<LayerPin>& blobsToKeep_)
for (auto& layer : layers)
{
auto& ld = layer.second;
if (ld.id == 0)
if (ld.id == 0 && netInputLayer->supportBackend(preferableBackend))
{
for (auto& wrapper : ld.inputBlobsWrappers)
{

View File

@ -58,6 +58,16 @@ namespace cv { namespace dnn {
return Tensor<T>(std::begin(sizes), std::end(sizes));
}
template <class T> inline
void copyMatToTensorImpl(const Mat& srcMat, const TensorSpan<T> destTensor, const Stream& stream) {
CV_Assert(srcMat.total() >= destTensor.size());
Mat temp = srcMat.isContinuous() ? srcMat : srcMat.clone();
CV_Assert(temp.isContinuous());
memcpy<T>(destTensor.get(), reinterpret_cast<T*>(temp.data), destTensor.size(), stream);
}
/** @brief copies data from a cv::Mat to TensorType
*
* \tparam T the type of the elements contained in TensorType object
@ -81,8 +91,7 @@ namespace cv { namespace dnn {
template <> inline
void copyMatToTensor(const Mat& srcMat, const TensorSpan<half> destTensor, const Stream& stream) {
/* should perhaps convert cv::Mat of different type to the required type and copy */
CV_Assert(srcMat.type() == CV_32F);
CV_CheckTypeEQ(srcMat.type(), CV_32F, "");
CV_Assert(srcMat.total() >= destTensor.size());
Mat temp;
@ -94,14 +103,20 @@ namespace cv { namespace dnn {
template <> inline
void copyMatToTensor(const Mat& srcMat, const TensorSpan<float> destTensor, const Stream& stream) {
/* should perhaps convert cv::Mat of different type to the required type and copy */
CV_Assert(srcMat.type() == CV_32F);
CV_Assert(srcMat.total() >= destTensor.size());
CV_CheckTypeEQ(srcMat.type(), CV_32F, "");
copyMatToTensorImpl(srcMat, destTensor, stream);
}
Mat temp = srcMat.isContinuous() ? srcMat : srcMat.clone();
CV_Assert(temp.isContinuous());
template <> inline
void copyMatToTensor(const Mat& srcMat, const TensorSpan<int32_t> destTensor, const Stream& stream) {
CV_CheckTypeEQ(srcMat.type(), CV_32S, "");
copyMatToTensorImpl(srcMat, destTensor, stream);
}
memcpy<float>(destTensor.get(), reinterpret_cast<float*>(temp.data), destTensor.size(), stream);
template <> inline
void copyMatToTensor(const Mat& srcMat, const TensorSpan<int64_t> destTensor, const Stream& stream) {
CV_CheckTypeEQ(srcMat.type(), CV_64S, "");
copyMatToTensorImpl(srcMat, destTensor, stream);
}
/** @brief copies data from a TensorType to a cv::Mat
@ -126,7 +141,7 @@ namespace cv { namespace dnn {
template <> inline
void copyTensorToMat(TensorView<half> srcTensor, Mat& destMat, const Stream& stream) {
CV_Assert(destMat.type() == CV_32F);
CV_CheckTypeEQ(destMat.type(), CV_32F, "Unsupported type");
CV_Assert(destMat.total() >= srcTensor.size());
Mat temp(shape(destMat), CV_16F);
@ -139,7 +154,7 @@ namespace cv { namespace dnn {
template <> inline
void copyTensorToMat(TensorView<float> srcTensor, Mat& destMat, const Stream& stream) {
CV_Assert(destMat.type() == CV_32F);
CV_CheckTypeEQ(destMat.type(), CV_32F, "Unsupported type");
CV_Assert(destMat.total() >= srcTensor.size());
Mat temp = destMat.isContinuous() ? destMat : destMat.clone();
@ -200,6 +215,44 @@ namespace cv { namespace dnn {
return Ptr<BackendNode>();
}
template <template <class> class NodeType, class ...Args>
cv::Ptr<BackendNode> make_cuda_node_with_type(int targetId, int hostMatType, Args&& ...args) {
CV_CheckType(hostMatType, hostMatType == CV_32F || hostMatType == CV_32S || hostMatType == CV_64S, "");
if (hostMatType == CV_32S)
return Ptr<BackendNode>(new NodeType<int32_t>(std::forward<Args>(args)...));
else if (hostMatType == CV_64S)
return Ptr<BackendNode>(new NodeType<int64_t>(std::forward<Args>(args)...));
else if (hostMatType == CV_32F)
{
if (targetId == DNN_TARGET_CUDA_FP16)
return Ptr<BackendNode>(new NodeType<half>(std::forward<Args>(args)...));
else if (targetId == DNN_TARGET_CUDA)
return Ptr<BackendNode>(new NodeType<float>(std::forward<Args>(args)...));
}
CV_Error(Error::BadDepth, "Unsupported mat type");
return Ptr<BackendNode>();
}
template <template <class, class> class NodeType, class T_INDEX, class ...Args>
cv::Ptr<BackendNode> make_cuda_node_with_indices(int targetId, int hostMatType, Args&& ...args) {
CV_CheckType(hostMatType, hostMatType == CV_32F || hostMatType == CV_32S || hostMatType == CV_64S, "");
if (hostMatType == CV_32S)
return Ptr<BackendNode>(new NodeType<int32_t, T_INDEX>(std::forward<Args>(args)...));
else if (hostMatType == CV_64S)
return Ptr<BackendNode>(new NodeType<int64_t, T_INDEX>(std::forward<Args>(args)...));
else if (hostMatType == CV_32F)
{
if (targetId == DNN_TARGET_CUDA_FP16)
return Ptr<BackendNode>(new NodeType<half, T_INDEX>(std::forward<Args>(args)...));
else if (targetId == DNN_TARGET_CUDA)
return Ptr<BackendNode>(new NodeType<float, T_INDEX>(std::forward<Args>(args)...));
}
CV_Error(Error::BadDepth, "Unsupported mat type");
return Ptr<BackendNode>();
}
/* base class for all CUDA backend/target wrappers */
class CUDABackendWrapper : public BackendWrapper {
public:
@ -224,11 +277,11 @@ namespace cv { namespace dnn {
namespace cuda4dnn { namespace detail {
template <class U>
void convert_D2H(const cv::Mat& mat, cuda4dnn::csl::View<U> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream);
template <class DEVICE_T, class HOST_T>
void convert_D2H(const cv::Mat& mat, cuda4dnn::csl::View<DEVICE_T> view, cuda4dnn::csl::ManagedPtr<HOST_T>& device_temp, const cuda4dnn::csl::Stream& stream);
template <> inline
void convert_D2H<half>(const cv::Mat& mat, cuda4dnn::csl::View<half> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
void convert_D2H<half, float>(const cv::Mat& mat, cuda4dnn::csl::View<half> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
if (device_temp.size() < view.size())
device_temp.reset(view.size());
auto temp_span = cuda4dnn::csl::Span<float>(device_temp.get(), view.size());
@ -238,15 +291,25 @@ namespace cv { namespace dnn {
}
template <> inline
void convert_D2H<float>(const cv::Mat& mat, cuda4dnn::csl::View<float> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
void convert_D2H<float, float>(const cv::Mat& mat, cuda4dnn::csl::View<float> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
cuda4dnn::csl::memcpy<float>(reinterpret_cast<float*>(mat.data), view.data(), view.size(), stream);
}
template <class U>
void convert_D2H_background(const cv::Mat& mat, cuda4dnn::csl::View<U> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event);
template <> inline
void convert_D2H<int32_t, int32_t>(const cv::Mat& mat, cuda4dnn::csl::View<int32_t> view, cuda4dnn::csl::ManagedPtr<int32_t>& device_temp, const cuda4dnn::csl::Stream& stream) {
cuda4dnn::csl::memcpy<int32_t>(reinterpret_cast<int32_t*>(mat.data), view.data(), view.size(), stream);
}
template <> inline
void convert_D2H_background<half>(const cv::Mat& mat, cuda4dnn::csl::View<half> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
void convert_D2H<int64_t, int64_t>(const cv::Mat& mat, cuda4dnn::csl::View<int64_t> view, cuda4dnn::csl::ManagedPtr<int64_t>& device_temp, const cuda4dnn::csl::Stream& stream) {
cuda4dnn::csl::memcpy<int64_t>(reinterpret_cast<int64_t*>(mat.data), view.data(), view.size(), stream);
}
template <class DEVICE_T, class HOST_T>
void convert_D2H_background(const cv::Mat& mat, cuda4dnn::csl::View<DEVICE_T> view, cuda4dnn::csl::ManagedPtr<HOST_T>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event);
template <> inline
void convert_D2H_background<half, float>(const cv::Mat& mat, cuda4dnn::csl::View<half> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
if (device_temp.size() < view.size())
device_temp.reset(view.size());
auto temp_span = cuda4dnn::csl::Span<float>(device_temp.get(), view.size());
@ -266,17 +329,31 @@ namespace cv { namespace dnn {
}
template <> inline
void convert_D2H_background<float>(const cv::Mat& mat, cuda4dnn::csl::View<float> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
void convert_D2H_background<float, float>(const cv::Mat& mat, cuda4dnn::csl::View<float> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
d2h_event.record(stream);
cuda4dnn::csl::StreamWaitOnEvent(d2h_stream, d2h_event);
cuda4dnn::csl::memcpy<float>(reinterpret_cast<float*>(mat.data), view.data(), view.size(), d2h_stream);
}
template <class U>
void convert_H2D(cuda4dnn::csl::Span<U> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream);
template <> inline
void convert_D2H_background<int32_t, int32_t>(const cv::Mat& mat, cuda4dnn::csl::View<int32_t> view, cuda4dnn::csl::ManagedPtr<int32_t>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
d2h_event.record(stream);
cuda4dnn::csl::StreamWaitOnEvent(d2h_stream, d2h_event);
cuda4dnn::csl::memcpy<int32_t>(reinterpret_cast<int32_t*>(mat.data), view.data(), view.size(), d2h_stream);
}
template <> inline
void convert_H2D<half>(cuda4dnn::csl::Span<half> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
void convert_D2H_background<int64_t, int64_t>(const cv::Mat& mat, cuda4dnn::csl::View<int64_t> view, cuda4dnn::csl::ManagedPtr<int64_t>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
d2h_event.record(stream);
cuda4dnn::csl::StreamWaitOnEvent(d2h_stream, d2h_event);
cuda4dnn::csl::memcpy<int64_t>(reinterpret_cast<int64_t*>(mat.data), view.data(), view.size(), d2h_stream);
}
template <class DEVICE_T, class HOST_T>
void convert_H2D(cuda4dnn::csl::Span<DEVICE_T> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<HOST_T>& device_temp, const cuda4dnn::csl::Stream& stream);
template <> inline
void convert_H2D<half, float>(cuda4dnn::csl::Span<half> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
if (device_temp.size() < span.size())
device_temp.reset(span.size());
auto temp_span = cuda4dnn::csl::Span<float>(device_temp.get(), span.size());
@ -286,15 +363,25 @@ namespace cv { namespace dnn {
}
template <> inline
void convert_H2D<float>(cuda4dnn::csl::Span<float> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
void convert_H2D<float, float>(cuda4dnn::csl::Span<float> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
cuda4dnn::csl::memcpy<float>(span.data(), reinterpret_cast<float*>(mat.data), span.size(), stream);
}
template <> inline
void convert_H2D<int32_t, int32_t>(cuda4dnn::csl::Span<int32_t> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<int32_t>& device_temp, const cuda4dnn::csl::Stream& stream) {
cuda4dnn::csl::memcpy<int32_t>(span.data(), reinterpret_cast<int32_t*>(mat.data), span.size(), stream);
}
template <> inline
void convert_H2D<int64_t, int64_t>(cuda4dnn::csl::Span<int64_t> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<int64_t>& device_temp, const cuda4dnn::csl::Stream& stream) {
cuda4dnn::csl::memcpy<int64_t>(span.data(), reinterpret_cast<int64_t*>(mat.data), span.size(), stream);
}
}} /* namespace cuda4dnn::detail */
template <class T, int TargetID>
template <class DEVICE_T, class HOST_T, int TargetID>
class GenericCUDABackendWrapper final : public CUDABackendWrapper {
public:
using value_type = T;
using value_type = DEVICE_T;
using tensor_span_type = cuda4dnn::csl::TensorSpan<value_type>;
using tensor_view_type = cuda4dnn::csl::TensorView<value_type>;
@ -309,6 +396,7 @@ namespace cv { namespace dnn {
: CUDABackendWrapper(TargetID)
{
shape = cv::dnn::shape(m);
hostMatDepth = m.depth();
offset = 0;
shared_block = std::make_shared<shared_block_type>();
@ -324,7 +412,7 @@ namespace cv { namespace dnn {
/* we ignore the failure as this is just an optimization and not a requirement */
}
shared_block->device = cuda4dnn::csl::ManagedPtr<T>(m.total());
shared_block->device = cuda4dnn::csl::ManagedPtr<DEVICE_T>(m.total());
}
GenericCUDABackendWrapper(const Ptr<BackendWrapper>& base_, const MatShape& shape_)
@ -334,6 +422,7 @@ namespace cv { namespace dnn {
CV_Assert(base);
shape = shape_;
hostMatDepth = base_->getHostMatDepth();
offset = 0;
shared_block = base->shared_block;
@ -377,9 +466,8 @@ namespace cv { namespace dnn {
auto& mat = shared_block->host;
CV_Assert(mat.isContinuous());
CV_Assert(mat.type() == CV_32F);
cuda4dnn::detail::convert_D2H<T>(mat, view, shared_block->device_temp, shared_block->stream);
cuda4dnn::detail::convert_D2H<DEVICE_T, HOST_T>(mat, view, shared_block->device_temp, shared_block->stream);
shared_block->stream.synchronize();
} else if(shared_block->d2h_event && shared_block->d2h_event.busy()) {
/* wait for the background copy to finish */
@ -401,7 +489,7 @@ namespace cv { namespace dnn {
if (!shared_block->d2h_event)
shared_block->d2h_event = cuda4dnn::csl::Event(true);
cuda4dnn::detail::convert_D2H_background<T>(mat, view, shared_block->device_temp, shared_block->stream, shared_block->d2h_stream, shared_block->d2h_event);
cuda4dnn::detail::convert_D2H_background<DEVICE_T, HOST_T>(mat, view, shared_block->device_temp, shared_block->stream, shared_block->d2h_stream, shared_block->d2h_event);
shared_block->d2h_event.record(shared_block->d2h_stream); // record position so that we can check status later
}
}
@ -422,9 +510,8 @@ namespace cv { namespace dnn {
auto& mat = shared_block->host;
CV_Assert(mat.isContinuous());
CV_Assert(mat.type() == CV_32F);
cuda4dnn::detail::convert_H2D<T>(span, mat, shared_block->device_temp, shared_block->stream);
cuda4dnn::detail::convert_H2D<DEVICE_T, HOST_T>(span, mat, shared_block->device_temp, shared_block->stream);
}
}
@ -504,8 +591,8 @@ namespace cv { namespace dnn {
cv::Mat host;
cuda4dnn::csl::MemoryLockGuard memGuard; /* keeps host memory page-locked if possible */
cuda4dnn::csl::ManagedPtr<T> device;
cuda4dnn::csl::ManagedPtr<float> device_temp; /* use for conversions */
cuda4dnn::csl::ManagedPtr<DEVICE_T> device;
cuda4dnn::csl::ManagedPtr<HOST_T> device_temp; /* use for conversions */
cuda4dnn::csl::Stream stream;
cuda4dnn::csl::Event d2h_event;
@ -515,12 +602,16 @@ namespace cv { namespace dnn {
std::shared_ptr<shared_block_type> shared_block;
};
using CUDABackendWrapperFP16 = GenericCUDABackendWrapper<half, DNN_TARGET_CUDA_FP16>;
using CUDABackendWrapperFP32 = GenericCUDABackendWrapper<float, DNN_TARGET_CUDA>;
using CUDABackendWrapperFP16 = GenericCUDABackendWrapper<half, float, DNN_TARGET_CUDA_FP16>;
using CUDABackendWrapperFP32 = GenericCUDABackendWrapper<float, float, DNN_TARGET_CUDA>;
using CUDABackendWrapperINT32 = GenericCUDABackendWrapper<int32_t, int32_t, DNN_TARGET_CUDA>;
using CUDABackendWrapperINT64 = GenericCUDABackendWrapper<int64_t, int64_t, DNN_TARGET_CUDA>;
template <class T> struct GetCUDABackendWrapperType_ { };
template <> struct GetCUDABackendWrapperType_<half> { typedef CUDABackendWrapperFP16 type; };
template <> struct GetCUDABackendWrapperType_<float> { typedef CUDABackendWrapperFP32 type; };
template <> struct GetCUDABackendWrapperType_<int32_t> { typedef CUDABackendWrapperINT32 type; };
template <> struct GetCUDABackendWrapperType_<int64_t> { typedef CUDABackendWrapperINT64 type; };
template <class T>
using GetCUDABackendWrapperType = typename GetCUDABackendWrapperType_<T>::type;

View File

@ -1161,8 +1161,8 @@ void TFImporter::parseExpandDims(tensorflow::GraphDef& net, const tensorflow::No
// Get input shape
std::vector<MatShape> inShape_, outShape_;
int inpIdindex = layer_id.find(inpId.name)->second;
dstNet.getLayerShapes(netInputShapes, inpIdindex, inShape_, outShape_);
std::vector<MatType> netInputTypes(netInputShapes.size(), CV_32F);
dstNet.getLayerShapes(netInputShapes, netInputTypes, inpIdindex, inShape_, outShape_);
MatShape inpShape = outShape_[0];
std::vector<int> outShape = inpShape;
@ -1838,7 +1838,8 @@ void TFImporter::parseMul(tensorflow::GraphDef& net, const tensorflow::NodeDef&
// Get input shape
MatShape outShape;
std::vector<MatShape> inpShapes, outShapes;
dstNet.getLayerShapes(netInputShapes, inpId, inpShapes, outShapes);
std::vector<MatType> netInputTypes(netInputShapes.size(), CV_32F);
dstNet.getLayerShapes(netInputShapes, netInputTypes, inpId, inpShapes, outShapes);
CV_CheckGT(static_cast<int>(outShapes.size()), pin.blobIndex, "");
outShape = outShapes[pin.blobIndex];

View File

@ -80,11 +80,20 @@ Mat blobFromNPY(const std::string& path)
ifs.read(&header[0], header.size());
// Extract data type.
CV_Assert(getType(header) == "<f4");
int matType;
if (getType(header) == "<f4")
matType = CV_32F;
else if (getType(header) == "<i4")
matType = CV_32S;
else if (getType(header) == "<i8")
matType = CV_64S;
else
CV_Error(Error::BadDepth, "Unsupported numpy type");
CV_Assert(getFortranOrder(header) == "False");
std::vector<int> shape = getShape(header);
Mat blob(shape, CV_32F);
Mat blob(shape, matType);
ifs.read((char*)blob.data, blob.total() * blob.elemSize());
CV_Assert((size_t)ifs.gcount() == blob.total() * blob.elemSize());

View File

@ -209,7 +209,7 @@ TEST_P(Reproducibility_AlexNet, Accuracy)
// Test input layer size
std::vector<MatShape> inLayerShapes;
std::vector<MatShape> outLayerShapes;
net.getLayerShapes(MatShape(), 0, inLayerShapes, outLayerShapes);
net.getLayerShapes(MatShape(), CV_32F, 0, inLayerShapes, outLayerShapes);
ASSERT_FALSE(inLayerShapes.empty());
ASSERT_EQ(inLayerShapes[0].size(), 4);
ASSERT_EQ(inLayerShapes[0][0], 1);
@ -256,7 +256,7 @@ TEST(Reproducibility_FCN, Accuracy)
std::vector<int> layerIds;
std::vector<size_t> weights, blobs;
net.getMemoryConsumption(shape(1,3,227,227), layerIds, weights, blobs);
net.getMemoryConsumption(shape(1,3,227,227), CV_32F, layerIds, weights, blobs);
net.setInput(blobFromImage(sample, 1.0f, Size(500, 500), Scalar(), false), "data");
Mat out = net.forward("score");

View File

@ -0,0 +1,56 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "test_precomp.hpp"
namespace opencv_test { namespace {
typedef testing::TestWithParam<tuple<Backend, Target> > Test_int64_sum;
TEST_P(Test_int64_sum, basic)
{
Backend backend = get<0>(GetParam());
Target target = get<1>(GetParam());
int64_t a_value = 1000000000000000ll;
int64_t b_value = 1;
int64_t result_value = 1000000000000001ll;
EXPECT_NE(int64_t(float(a_value) + float(b_value)), result_value);
Mat a(3, 5, CV_64SC1, cv::Scalar_<int64_t>(a_value));
Mat b = Mat::ones(3, 5, CV_64S);
Net net;
LayerParams lp;
lp.type = "NaryEltwise";
lp.name = "testLayer";
lp.set("operation", "sum");
int id = net.addLayerToPrev(lp.name, lp.type, lp);
net.connect(0, 1, id, 1);
vector<String> inpNames(2);
inpNames[0] = "a";
inpNames[1] = "b";
net.setInputsNames(inpNames);
net.setInput(a, inpNames[0]);
net.setInput(b, inpNames[1]);
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
Mat re;
re = net.forward();
EXPECT_EQ(re.depth(), CV_64S);
auto ptr_re = (int64_t *) re.data;
for (int i = 0; i < re.total(); i++)
ASSERT_EQ(result_value, ptr_re[i]);
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_int64_sum,
dnnBackendsAndTargets()
);
}} // namespace

View File

@ -1688,7 +1688,7 @@ TEST(Layer_Test_PoolingIndices, Accuracy)
Mat inp(10, 10, CV_8U);
randu(inp, 0, 255);
Mat maxValues(5, 5, CV_32F, Scalar(-1)), indices(5, 5, CV_32F, Scalar(-1));
Mat maxValues(5, 5, CV_32F, Scalar(-1)), indices(5, 5, CV_64S, Scalar(-1));
for (int y = 0; y < 10; ++y)
{
int dstY = y / 2;
@ -1699,7 +1699,7 @@ TEST(Layer_Test_PoolingIndices, Accuracy)
if ((float)inp.at<uint8_t>(y, x) > maxValues.at<float>(dstY, dstX))
{
maxValues.at<float>(dstY, dstX) = val;
indices.at<float>(dstY, dstX) = y * 10 + x;
indices.at<int64_t>(dstY, dstX) = y * 10 + x;
}
}
}

View File

@ -75,7 +75,7 @@ TEST_P(Layer_Gather_1d_Test, Accuracy) {
cv::Mat input = cv::Mat(input_shape, CV_32F, 1.0);
cv::randu(input, 0.0, 1.0);
cv::Mat indices = cv::Mat(indices_shape, CV_32F, 0.0);
cv::Mat indices = cv::Mat(indices_shape, CV_32S, 0.0);
cv::Mat output_ref = cv::Mat(output_shape, CV_32F, input(cv::Range::all(), cv::Range(0, 1)).data);
std::vector<Mat> inputs{input, indices};

View File

@ -111,10 +111,6 @@
"test_eyelike_populate_off_main_diagonal",
"test_eyelike_with_dtype",
"test_eyelike_without_dtype",
"test_gather_0",
"test_gather_1",
"test_gather_2d_indices",
"test_gather_negative_indices",
"test_gathernd_example_float32",
"test_gathernd_example_int32",
"test_gathernd_example_int32_batch_dim1",

View File

@ -44,7 +44,7 @@ public:
{
std::vector<MatShape> inLayerShapes;
std::vector<MatShape> outLayerShapes;
net.getLayerShapes(MatShape(), 0, inLayerShapes, outLayerShapes);
net.getLayerShapes(MatShape(), CV_32F, 0, inLayerShapes, outLayerShapes);
ASSERT_EQ(inLayerShapes.size(), inps.size());
for (int i = 0; i < inps.size(); ++i) {
@ -695,9 +695,6 @@ TEST_P(Test_ONNX_layers, Compare_GT)
testONNXModels("greater");
}
TEST_P(Test_ONNX_layers, Greater_input_dtype_int64) {
testONNXModels("greater_input_dtype_int64");
}
TEST_P(Test_ONNX_layers, Compare_LT)
{
@ -2167,7 +2164,7 @@ TEST_P(Test_ONNX_nets, Alexnet)
expectNoFallbacksFromIE(net);
}
TEST_P(Test_ONNX_nets, RAFT)
TEST_P(Test_ONNX_nets, DISABLED_RAFT)
{
applyTestTag(CV_TEST_TAG_LONG, CV_TEST_TAG_DEBUG_VERYLONG, CV_TEST_TAG_MEMORY_2GB);

View File

@ -31,7 +31,7 @@ public:
void testInputShapes(const Net& net, const std::vector<Mat>& inps) {
std::vector<MatShape> inLayerShapes;
std::vector<MatShape> outLayerShapes;
net.getLayerShapes(MatShape(), 0, inLayerShapes, outLayerShapes);
net.getLayerShapes(MatShape(), CV_32F, 0, inLayerShapes, outLayerShapes);
ASSERT_EQ(inLayerShapes.size(), inps.size());
for (int i = 0; i < inps.size(); ++i) {
@ -179,7 +179,7 @@ TEST_P(Test_TFLite, max_unpooling)
for (int c = 0; c < 32; ++c) {
float *poolInpData = poolInp.ptr<float>(0, c);
float *poolOutData = poolOut.ptr<float>(0, c);
float *poolIdsData = poolIds.ptr<float>(0, c);
int64_t *poolIdsData = poolIds.ptr<int64_t>(0, c);
float *unpoolInpData = unpoolInp.ptr<float>(0, c);
float *unpoolOutData = unpoolOut.ptr<float>(0, c);
for (int y = 0; y < 64; ++y) {
@ -195,7 +195,7 @@ TEST_P(Test_TFLite, max_unpooling)
}
EXPECT_EQ(poolInpData[maxIdx], poolOutData[y * 64 + x]) << errMsg;
if (backend != DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) {
EXPECT_EQ(poolIdsData[y * 64 + x], (float)maxIdx) << errMsg;
EXPECT_EQ(poolIdsData[y * 64 + x], (int64_t)maxIdx) << errMsg;
}
EXPECT_EQ(unpoolOutData[maxIdx], unpoolInpData[y * 64 + x]) << errMsg;
}