diff --git a/modules/dnn/src/cuda/eltwise_ops.cu b/modules/dnn/src/cuda/eltwise_ops.cu index 16f6cccf6b..8a861b3067 100644 --- a/modules/dnn/src/cuda/eltwise_ops.cu +++ b/modules/dnn/src/cuda/eltwise_ops.cu @@ -324,7 +324,19 @@ void eltwise_sub_2(const Stream& stream, TensorSpan output, TensorView x, eltwise_op>(stream, output, x, y); } +template +void eltwise_mod_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y) { + eltwise_op>(stream, output, x, y); +} + +template +void eltwise_fmod_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y) { + eltwise_op>(stream, output, x, y); +} + #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530) + template void eltwise_mod_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); + template void eltwise_fmod_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); template void eltwise_sub_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); template void eltwise_div_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); template void eltwise_prod_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); @@ -333,6 +345,8 @@ void eltwise_sub_2(const Stream& stream, TensorSpan output, TensorView x, template void eltwise_max_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); template void eltwise_min_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); #endif + template void eltwise_mod_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); + template void eltwise_fmod_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); template void eltwise_sub_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); template void eltwise_div_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); template void eltwise_prod_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); diff --git a/modules/dnn/src/cuda/functors.hpp b/modules/dnn/src/cuda/functors.hpp index 2df32030f0..cada43387e 100644 --- a/modules/dnn/src/cuda/functors.hpp +++ b/modules/dnn/src/cuda/functors.hpp @@ -799,6 +799,40 @@ struct ReciprocalFunctor { } }; +template +struct ModFunctor { + struct Params { + CUDA4DNN_HOST_DEVICE Params() {} + }; + + CUDA4DNN_DEVICE ModFunctor() { } + CUDA4DNN_DEVICE ModFunctor(const Params& params) { } + + CUDA4DNN_DEVICE T operator()(T x, T y) { + int res = (int)x % (int)y; + T zero = T(0); + if ((res > (int)zero && y < zero) || (res < (int)zero && y > zero)) { + res += (int)y; + } + return res; + } +}; + +template +struct FModFunctor { + struct Params { + CUDA4DNN_HOST_DEVICE Params() {} + }; + + CUDA4DNN_DEVICE FModFunctor() { } + CUDA4DNN_DEVICE FModFunctor(const Params& params) { } + + CUDA4DNN_DEVICE T operator()(T x, T y) { + using csl::device::fmod; + return fmod(x, y); + } +}; + }}}} /* namespace cv::dnn::cuda4dnn::kernels */ #endif /* OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP */ diff --git a/modules/dnn/src/cuda/math.hpp b/modules/dnn/src/cuda/math.hpp index 0a312a250d..8e4f091f4f 100644 --- a/modules/dnn/src/cuda/math.hpp +++ b/modules/dnn/src/cuda/math.hpp @@ -36,6 +36,13 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de template <> inline __device__ float min(float x, float y) { return fminf(x, y); } template <> inline __device__ double min(double x, double y) { return fmin(x, y); } + template __device__ T fmod(T x, T y) { return x % y; } + template <> inline __device__ float fmod(float x, float y) { return fmodf(x, y); } + template <> inline __device__ double fmod(double x, double y) { return fmod(x, y); } +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530) + template <> inline __device__ half fmod(half x, half y) { return fmodf((float)x, (float)y); } +#endif + template __device__ T log1p(T val); #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530) template <> inline __device__ __half log1p(__half val) { return hlog(__half(1) + val); } diff --git a/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp b/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp index 3dc3355b3b..e80db943ae 100644 --- a/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp +++ b/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp @@ -33,6 +33,12 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { template void eltwise_sub_2(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView x, csl::TensorView y); + template + void eltwise_mod_2(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView x, csl::TensorView y); + + template + void eltwise_fmod_2(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView x, csl::TensorView y); + }}}} /* namespace cv::dnn::cuda4dnn::kernels */ #endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ELTWISE_OPS_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp b/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp index 05bca83820..5822f48061 100644 --- a/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp +++ b/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp @@ -28,6 +28,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { DIV, MIN, SUB, + MOD, + FMOD, }; class EltwiseOpBase : public CUDABackendNode { @@ -90,6 +92,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { kernels::eltwise_sum_coeff_2(stream, output, coeffs[0], input_x, coeffs[1], input_y); break; case EltwiseOpType::SUB: kernels::eltwise_sub_2(stream, output, input_x, input_y); break; + case EltwiseOpType::MOD: kernels::eltwise_mod_2(stream, output, input_x, input_y); break; + case EltwiseOpType::FMOD: kernels::eltwise_fmod_2(stream, output, input_x, input_y); break; } } else @@ -122,6 +126,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { } break; case EltwiseOpType::SUB: kernels::eltwise_sub_2(stream, output, output, input); break; + case EltwiseOpType::MOD: kernels::eltwise_mod_2(stream, output, output, input); break; + case EltwiseOpType::FMOD: kernels::eltwise_fmod_2(stream, output, output, input); break; } } } diff --git a/modules/dnn/src/layers/nary_eltwise_layers.cpp b/modules/dnn/src/layers/nary_eltwise_layers.cpp index c988ec69f2..661861cbe3 100644 --- a/modules/dnn/src/layers/nary_eltwise_layers.cpp +++ b/modules/dnn/src/layers/nary_eltwise_layers.cpp @@ -24,6 +24,16 @@ namespace cv namespace dnn { +namespace { +static int _mod(int x, int y) { + int res = x % y; + if ((res < 0 && y > 0) || (res > 0 && y < 0)) { + res += y; + } + return res; +} +} + class NaryEltwiseLayerImpl CV_FINAL : public NaryEltwiseLayer { public: @@ -42,7 +52,8 @@ public: MAX, MEAN, MIN, - MOD, + MOD, // Integer Mod. Reminder's sign = Divisor's sign. + FMOD, // Floating-point Mod. Reminder's sign = Dividend's sign. PROD, SUB, SUM, @@ -79,6 +90,8 @@ public: op = OPERATION::MIN; else if (operation == "mod") op = OPERATION::MOD; + else if (operation == "fmod") + op = OPERATION::FMOD; else if (operation == "mul") op = OPERATION::PROD; else if (operation == "sub") @@ -106,18 +119,21 @@ public: #ifdef HAVE_CANN if (backendId == DNN_BACKEND_CANN) return op == OPERATION::ADD || op == OPERATION::PROD || op == OPERATION::SUB || - op == OPERATION::DIV || op == OPERATION::MAX || op == OPERATION::MIN; + op == OPERATION::DIV || op == OPERATION::MAX || op == OPERATION::MIN || + op == OPERATION::MOD || op == OPERATION::FMOD; #endif if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) return (op == OPERATION::ADD || op == OPERATION::PROD || op == OPERATION::GREATER_EQUAL || - op == OPERATION::LESS_EQUAL + op == OPERATION::LESS_EQUAL || + op == OPERATION::MOD || + op == OPERATION::FMOD ); if (backendId == DNN_BACKEND_CUDA) { - return op == OPERATION::MAX || op == OPERATION::MIN || op == OPERATION::SUM || - op == OPERATION::PROD || op == OPERATION::DIV || op == OPERATION::ADD || - op == OPERATION::SUB; + return op == OPERATION::MAX || op == OPERATION::MIN || op == OPERATION::SUM || + op == OPERATION::PROD || op == OPERATION::DIV || op == OPERATION::ADD || + op == OPERATION::SUB || op == OPERATION::MOD || op == OPERATION::FMOD; } return backendId == DNN_BACKEND_OPENCV; } @@ -703,10 +719,16 @@ public: } case OPERATION::MOD: { - auto mod = [](const uint8_t &a, const uint8_t &b) { return a % b; }; + auto mod = [] (const T &a, const T &b) { return static_cast(_mod(int(a), int(b))); }; binary_forward(mod, std::forward(args)...); break; } + case OPERATION::FMOD: + { + auto fmod = [](const T &a, const T &b) { return std::fmod(a, b); }; + binary_forward(fmod, std::forward(args)...); + break; + } case OPERATION::PROD: { auto prod = [](const T &a, const T &b) { return a * b; }; @@ -778,9 +800,8 @@ public: opDispatch(std::forward(args)...); break; case CV_32F: - CV_Assert(op != OPERATION::BITSHIFT && op != OPERATION::MOD && - op != OPERATION::AND && op != OPERATION::OR && - op != OPERATION::XOR); + CV_Assert(op != OPERATION::BITSHIFT && op != OPERATION::AND && + op != OPERATION::OR && op != OPERATION::XOR); opDispatch(std::forward(args)...); break; default: @@ -833,6 +854,12 @@ public: case OPERATION::SUB: op_ = cuda4dnn::EltwiseOpType::SUB; break; + case OPERATION::MOD: + op_ = cuda4dnn::EltwiseOpType::MOD; + break; + case OPERATION::FMOD: + op_ = cuda4dnn::EltwiseOpType::FMOD; + break; default: return Ptr(); // return empty cuda_node if the EltwiseOpType is unsupported type. }; @@ -877,6 +904,8 @@ public: BUILD_CANN_ELTWISE_OP(OPERATION::DIV, Xdivy, name); BUILD_CANN_ELTWISE_OP(OPERATION::MAX, Maximum, name); BUILD_CANN_ELTWISE_OP(OPERATION::MIN, Minimum, name); + BUILD_CANN_ELTWISE_OP(OPERATION::MOD, Mod, name); + BUILD_CANN_ELTWISE_OP(OPERATION::FMOD, Mod, name); #undef BUILD_CANN_ELTWISE_OP default: CV_Error(Error::StsNotImplemented, "Unsupported eltwise operation"); } @@ -923,6 +952,16 @@ public: node = std::make_shared(inp0, inp1); else if (op == OPERATION::LESS_EQUAL) node = std::make_shared(inp0, inp1); + // Ideally we should do this but int32 internal blobs are converted to float32 data type in inference. + // TODO: Remove data type convertion when we have type inference. + else if (op == OPERATION::MOD) { + auto inp0_i64 = std::make_shared(inp0, ngraph::element::i64); + auto inp1_i64 = std::make_shared(inp1, ngraph::element::i64); + auto mod = std::make_shared(inp0_i64, inp1_i64); + node = std::make_shared(mod, ngraph::element::f32); + } + else if (op == OPERATION::FMOD) + node = std::make_shared(inp0, inp1); else CV_Error(Error::StsNotImplemented, "Operation is not implemented for nGraph backend"); return Ptr(new InfEngineNgraphNode(node)); diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 115738999a..f0b33d111b 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -2830,6 +2830,11 @@ void ONNXImporter::parseElementWise(LayerParams& layerParams, const opencv_onnx: layerParams.type = "NaryEltwise"; layerParams.set("operation", toLowerCase(node_proto.op_type())); + if (node_proto.op_type() == "Mod") { + if (layerParams.get("fmod", 0)) { + layerParams.set("operation", "fmod"); + }; + } // element-wise layers that can have >=1 inputs but actually have one input if (node_proto.input_size() == 1 && (op_type == "max" || op_type == "min" || op_type == "mean" || op_type == "sum")) @@ -4006,7 +4011,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version) dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] = dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] = - dispatch["LessOrEqual"] = &ONNXImporter::parseElementWise; + dispatch["LessOrEqual"] = dispatch["Mod"] = &ONNXImporter::parseElementWise; dispatch["Sum"] = dispatch["Min"] = dispatch["Max"] = &ONNXImporter::parseElementWise; dispatch["Where"] = &ONNXImporter::parseElementWise; diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp index 17d561d64b..199bfdcd18 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp @@ -1056,10 +1056,25 @@ CASE(test_mod_int64_fmod) // no filter CASE(test_mod_mixed_sign_float16) // no filter + if (target == DNN_TARGET_OPENCL) + { + default_l1 = 0.0011; // Expected: (normL1) <= (l1), actual: 0.00104141 vs 1e-05 + default_lInf = 0.0016; // Expected: (normInf) <= (lInf), actual: 0.00156212 vs 0.0001 + } CASE(test_mod_mixed_sign_float32) // no filter + if (target == DNN_TARGET_OPENCL) + { + default_l1 = 0.0011; // Expected: (normL1) <= (l1), actual: 0.00104141 vs 1e-05 + default_lInf = 0.0016; // Expected: (normInf) <= (lInf), actual: 0.00156212 vs 0.0001 + } CASE(test_mod_mixed_sign_float64) // no filter + if (target == DNN_TARGET_OPENCL) + { + default_l1 = 0.0011; // Expected: (normL1) <= (l1), actual: 0.00104167 vs 1e-05 + default_lInf = 0.0016; // Expected: (normInf) <= (lInf), actual: 0.00156251 vs 0.0001 + } CASE(test_mod_mixed_sign_int16) // no filter CASE(test_mod_mixed_sign_int32) diff --git a/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp index be60c38b86..68f49e5fa4 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp @@ -210,9 +210,6 @@ "test_min_uint8", "test_mod_broadcast", "test_mod_int64_fmod", -"test_mod_mixed_sign_float16", -"test_mod_mixed_sign_float32", -"test_mod_mixed_sign_float64", "test_mod_mixed_sign_int16", "test_mod_mixed_sign_int32", "test_mod_mixed_sign_int64",