diff --git a/modules/dnn/src/int8layers/softmax_layer.cpp b/modules/dnn/src/int8layers/softmax_layer.cpp
index b2caf56fb0..5096e541e6 100644
--- a/modules/dnn/src/int8layers/softmax_layer.cpp
+++ b/modules/dnn/src/int8layers/softmax_layer.cpp
@@ -22,15 +22,29 @@ public:
     SoftMaxLayerInt8Impl(const LayerParams& params)
     {
-        axisRaw = params.get<int>("axis", 1);
+        setParamsFrom(params);
+
+        axis = params.get<int>("axis", 1);
         logSoftMax = params.get<bool>("log_softmax", false);
+        coerced_2d = params.get<bool>("coerced_2d", false);

         input_sc = params.get<float>("input_scale");
         input_zp = params.get<int>("input_zeropoint");
         output_sc = params.get<float>("scales");
         output_zp = params.get<int>("zeropoints");
-        setParamsFrom(params);
+
+        if (blobs.empty()) // if no lookUpTable is found
+        {
+            Mat lookUpTable(1, 256, CV_32F);
+            float* table = lookUpTable.ptr<float>();
+            for (int i = -128; i < 128; i++)
+            {
+                float x = input_sc * (i - 127); // ensures exp(x) is always between (0, 1)
+                table[i + 128] = std::exp(x);
+            }
+            blobs.push_back(lookUpTable);
+        }
     }

     bool getMemoryShapes(const std::vector<MatShape> &inputs,
@@ -40,12 +54,39 @@ public:
     {
         bool inplace = Layer::getMemoryShapes(inputs, requiredOutputs, outputs, internals);
         MatShape shape = inputs[0];
-        int cAxis = normalize_axis(axisRaw, shape.size());
-        shape[cAxis] = 1;
         internals.assign(1, shape);
         return inplace;
     }

+    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
+        std::vector<Mat> inputs;
+        inputs_arr.getMatVector(inputs);
+        auto src = inputs[0];
+        auto dims_src = src.dims;
+        auto shape_src = shape(src);
+        axis = normalize_axis(axis, dims_src);
+
+        if (!coerced_2d) {
+            is_transpose_needed = (axis == dims_src - 1) ? false : true;
+            if (is_transpose_needed) {
+                permutation.resize(dims_src);
+                std::iota(permutation.begin(), permutation.end(), 0);
+                permutation[axis] = dims_src - 1;
+                permutation[dims_src - 1] = axis;
+
+                transposed_shape.resize(dims_src);
+                std::transform(permutation.begin(), permutation.end(), transposed_shape.begin(), [&shape_src](int axis) { return shape_src[axis]; });
+                N = std::accumulate(transposed_shape.begin(), transposed_shape.end() - 1, 1, std::multiplies<int>());
+                D = transposed_shape.back();
+
+                return;
+            }
+        }
+
+        N = src.total(0, axis);
+        D = src.total(axis);
+    }
+
     virtual bool supportBackend(int backendId) CV_OVERRIDE
     {
         return backendId == DNN_BACKEND_OPENCV ||
@@ -80,8 +121,7 @@ public:
         const Mat &src = inputWrapper->getMat();

         // convert axis from OpenCV NCHW to TimVX WHCN.
-        int axis = normalize_axis(axisRaw, src.dims);
-        int tvAxis = src.dims - 1 - axis;
+        int tvAxis = src.dims - 1 - normalize_axis(axis, src.dims);
         if(tvAxis < 0)
             tvAxis = 0; // default value is 0.
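Note on the lookup table built in the constructor: since the input is int8, exp(input_sc * q) can take at most 256 distinct values, so they are precomputed once. Baking the shift (i - 127) into the table keeps every entry in (0, 1], so the sum never overflows, and because softmax is invariant to adding a constant to all inputs, both this shift and the input zero point cancel in the normalization (which is why input_zp never enters the table). A minimal standalone sketch of the idea, with illustrative names rather than the layer code:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Build the 256-entry table: index = q + 128 for q in [-128, 127].
std::vector<float> buildExpTable(float input_sc) {
    std::vector<float> table(256);
    for (int i = -128; i < 128; i++)
        table[i + 128] = std::exp(input_sc * (i - 127)); // always in (0, 1]
    return table;
}

// Float softmax of one int8 row via the table; the uniform shift by
// 127 * input_sc cancels when dividing by the sum.
void softmaxRow(const int8_t* x, float* y, int D, const std::vector<float>& table) {
    float sum = 0.f;
    for (int j = 0; j < D; j++)
        sum += table[x[j] + 128];
    for (int j = 0; j < D; j++)
        y[j] = table[x[j] + 128] / sum;
}
```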
@@ -154,103 +194,188 @@ public:
         return Ptr<BackendNode>();
     }

+    template <bool with_log>
+    class SoftmaxInt8Invoker : public ParallelLoopBody {
+    public:
+        const Mat& src_;
+        Mat& dst_;
+
+        const Mat& lookup_table_;
+
+        int N_;
+        int D_;
+
+        float y_scale_;
+        int y_zero_point_;
+
+        int threads;
+        int cost_per_thread;
+
+        SoftmaxInt8Invoker(const Mat& src, Mat& dst, const Mat& lookup_table, int N, int D, float y_scale, int y_zero_point)
+            : src_(src), dst_(dst), lookup_table_(lookup_table), N_(N), D_(D), y_scale_(1.f / y_scale), y_zero_point_(y_zero_point) {
+            threads = N_;
+            cost_per_thread = D_;
+        }
+
+        static void run(const Mat& src, Mat& dst, const Mat& lookup_table, int N, int D, float y_scale, int y_zero_point) {
+            CV_Assert(src.isContinuous());
+            CV_Assert(dst.isContinuous());
+            CV_CheckTypeEQ(src.type(), CV_8S, "DNN/SoftmaxInt8: type of input must be int8");
+            CV_CheckTypeEQ(dst.type(), CV_8S, "DNN/SoftmaxInt8: type of output must be int8");
+
+            SoftmaxInt8Invoker p(src, dst, lookup_table, N, D, y_scale, y_zero_point);
+
+            double nstripes = ((size_t)p.threads * p.cost_per_thread) * (1 / 1024.0);
+            parallel_for_(Range(0, p.threads), p, nstripes);
+        }
+
+        void operator()(const Range& r) const CV_OVERRIDE {
+            int start = r.start;
+            int end = r.end;
+
+            const int8_t* p_src = src_.ptr<int8_t>();
+            int8_t* p_dst = dst_.ptr<int8_t>();
+            const float* table = lookup_table_.ptr<float>();
+
+            for (int i = start; i < end; ++i) {
+                const int8_t* x = p_src + i * D_;
+                int8_t* y = p_dst + i * D_;
+
+                float vsum = 0;
+                for (int j = 0; j < D_; ++j) {
+                    const uint8_t idx = uint8_t((*x++) + 128);
+                    vsum += table[idx];
+                }
+
+                // FIXME: avoid divide by vsum==0
+
+                x = p_src + i * D_;
+                if (with_log) {
+                    for (int j = 0; j < D_; ++j) {
+                        const uint8_t idx = uint8_t((*x++) + 128);
+                        const float v = table[idx];
+                        *y++ = saturate_cast<int8_t>(std::nearbyintf(y_scale_ * std::log(v / vsum)) + y_zero_point_);
+                    }
+                } else {
+                    for (int j = 0; j < D_; ++j) {
+                        const uint8_t idx = uint8_t((*x++) + 128);
+                        const float v = table[idx];
+                        *y++ = saturate_cast<int8_t>(std::nearbyintf(y_scale_ * v / vsum) + y_zero_point_);
+                    }
+                }
+            }
+        }
+    };
+
+    template <bool with_log>
+    class SoftmaxInt8OutputFloatInvoker : public ParallelLoopBody {
+    public:
+        const Mat& src_;
+        Mat& dst_;
+
+        const Mat& lookup_table_;
+
+        int N_;
+        int D_;
+
+        int threads;
+        int cost_per_thread;
+
+        SoftmaxInt8OutputFloatInvoker(const Mat& src, Mat& dst, const Mat& lookup_table, int N, int D)
+            : src_(src), dst_(dst), lookup_table_(lookup_table), N_(N), D_(D) {
+            threads = N_;
+            cost_per_thread = D_;
+        }
+
+        static void run(const Mat& src, Mat& dst, const Mat& lookup_table, int N, int D) {
+            CV_Assert(src.isContinuous());
+            CV_Assert(dst.isContinuous());
+            CV_CheckTypeEQ(src.type(), CV_8S, "DNN/SoftmaxInt8: type of input must be int8");
+            CV_CheckTypeEQ(dst.type(), CV_32F, "DNN/SoftmaxInt8: type of output must be float32 since Dequantization is fused");
+
+            SoftmaxInt8OutputFloatInvoker p(src, dst, lookup_table, N, D);
+
+            double nstripes = ((size_t)p.threads * p.cost_per_thread) * (1 / 1024.0);
+            parallel_for_(Range(0, p.threads), p, nstripes);
+        }
+
+        void operator()(const Range& r) const CV_OVERRIDE {
+            int start = r.start;
+            int end = r.end;
+
+            const int8_t* p_src = src_.ptr<int8_t>();
+            float* p_dst = dst_.ptr<float>();
+            const float* table = lookup_table_.ptr<float>();
+
+            for (int i = start; i < end; ++i) {
+                const int8_t* x = p_src + i * D_;
+                float* y = p_dst + i * D_;
+
+                float vsum = 0;
+                for (int j = 0; j < D_; ++j) {
+                    const uint8_t idx = uint8_t((*x++) + 128);
+                    vsum += table[idx];
+                }
+
+                // FIXME: avoid divide by vsum==0
+
+                x = p_src + i * D_;
+                if (with_log) {
+                    for (int j = 0; j < D_; ++j) {
+                        const uint8_t idx = uint8_t((*x++) + 128);
+                        const float v = table[idx];
+                        *y++ = std::log(v / vsum);
+                    }
+                } else {
+                    for (int j = 0; j < D_; ++j) {
+                        const uint8_t idx = uint8_t((*x++) + 128);
+                        const float v = table[idx];
+                        *y++ = v / vsum;
+                    }
+                }
+            }
+        }
+    };
+
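The two invokers above parallelize over the N rows and fold the output quantization into the inner loop; the constructor stores 1.f / y_scale so the loop multiplies instead of dividing. A scalar sketch of that final step (quantizeOutput is a hypothetical helper; in the patch itself, saturate_cast<int8_t> performs the clamping):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize a float probability (or log-probability) p to int8 with the
// layer's output scale and zero point: q = clamp(round(p / scale) + zp).
static int8_t quantizeOutput(float p, float inv_scale, int zero_point) {
    float q = std::nearbyintf(p * inv_scale) + (float)zero_point;
    return (int8_t)std::max(-128.f, std::min(127.f, q)); // saturate to int8 range
}
```

For example, with y_scale = 1/256 and zero point -128, p = 1.0 maps to 256 - 128 = 128, which saturates to 127. The float invoker skips this step entirely, which is why its destination type is CV_32F when a following Dequantize has been fused into the layer.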
     void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
     {
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());

-        std::vector<Mat> inputs, outputs, internals;
+        std::vector<Mat> inputs, outputs;
         inputs_arr.getMatVector(inputs);
         outputs_arr.getMatVector(outputs);
-        internals_arr.getMatVector(internals);

-        const Mat &src = inputs[0];
-        Mat &dst = outputs[0];
+        Mat src, dst;

-        int axis = normalize_axis(axisRaw, src.dims);
-        size_t outerSize = src.total(0, axis), channels = src.size[axis],
-               innerSize = src.total(axis + 1);
-
-        CV_Assert(src.type() == CV_8S && (dst.type() == CV_8S || dst.type() == CV_32F));
-        CV_Assert(src.isContinuous() && dst.isContinuous());
-
-        size_t outerStep = src.total(axis);
-        size_t cnStep = src.total(axis + 1);
-        const int8_t *srcPtr = src.ptr<int8_t>();
-        const float *expPtr = blobs[0].ptr<float>();
-
-        if (dst.type() == CV_32F)
-        {
-            float *dstPtr = dst.ptr<float>();
-            for (size_t outerDim = 0; outerDim < outerSize; outerDim++)
-            {
-                size_t srcOffset = outerDim * outerStep;
-                std::vector<float> expSum(innerSize, 0.f);
-
-                // sum exp along axis
-                for (size_t cnDim = 0; cnDim < channels; cnDim++)
-                {
-                    const int offset = srcOffset + cnDim * cnStep;
-                    for (size_t i = 0; i < innerSize; i++)
-                        expSum[i] += expPtr[srcPtr[offset + i] + 128];
-                }
-
-                // divide by computed sum
-                for (size_t cnDim = 0; cnDim < channels; cnDim++)
-                {
-                    const int offset = srcOffset + cnDim * cnStep;
-                    for (size_t i = 0; i < innerSize; i++)
-                        dstPtr[offset + i] = expPtr[srcPtr[offset + i] + 128]/expSum[i];
-                }
-
-                if (logSoftMax)
-                {
-                    for (size_t cnDim = 0; cnDim < channels; cnDim++)
-                    {
-                        const int offset = srcOffset + cnDim * cnStep;
-                        for (size_t i = 0; i < innerSize; i++)
-                            dstPtr[offset + i] = log(dstPtr[offset + i]);
-                    }
-                }
-            }
+        if (!coerced_2d && is_transpose_needed) {
+            transposeND(inputs[0], permutation, src);
+            dst = Mat::zeros(transposed_shape.size(), transposed_shape.data(), outputs[0].type());
+        } else {
+            src = inputs[0];
+            dst = outputs[0];
         }
-        else
-        {
-            const float inv_scale = 1.f/output_sc;
-            int8_t *dstPtr = dst.ptr<int8_t>();
-            for (size_t outerDim = 0; outerDim < outerSize; outerDim++)
-            {
-                size_t srcOffset = outerDim * outerStep;
-                std::vector<float> expSum(innerSize, 0.f);

-                // sum exp along axis
-                for (size_t cnDim = 0; cnDim < channels; cnDim++)
-                {
-                    const int offset = srcOffset + cnDim * cnStep;
-                    for (size_t i = 0; i < innerSize; i++)
-                        expSum[i] += expPtr[srcPtr[offset + i] + 128];
+        switch (dst.type()) {
+            case CV_8S: {
+                if (logSoftMax) {
+                    SoftmaxInt8Invoker<true>::run(src, dst, blobs[0], N, D, output_sc, output_zp);
+                } else {
+                    SoftmaxInt8Invoker<false>::run(src, dst, blobs[0], N, D, output_sc, output_zp);
                 }
+            } break;
+            case CV_32F: {
+                if (logSoftMax) {
+                    SoftmaxInt8OutputFloatInvoker<true>::run(src, dst, blobs[0], N, D);
+                } else {
+                    SoftmaxInt8OutputFloatInvoker<false>::run(src, dst, blobs[0], N, D);
+                }
+            } break;
+            default: CV_Error(cv::Error::BadDepth, "DNN/SoftmaxInt8: Unsupported output type");
+        }

-                // divide by computed sum and quantize to int8
-                if (logSoftMax)
-                {
-                    for (size_t cnDim = 0; cnDim < channels; cnDim++)
-                    {
-                        const int offset = srcOffset + cnDim * cnStep;
-                        for (size_t i = 0; i < innerSize; i++)
-                            dstPtr[offset + i] = saturate_cast<int8_t>(output_zp + std::round(inv_scale*log(expPtr[srcPtr[offset + i] + 128]/expSum[i])));
-                    }
-                }
-                else
-                {
-                    for (size_t cnDim = 0; cnDim < channels; cnDim++)
-                    {
-                        const int offset = srcOffset + cnDim * cnStep;
-                        for (size_t i = 0; i < innerSize; i++)
-                            dstPtr[offset + i] = saturate_cast<int8_t>(output_zp + std::round(inv_scale*(expPtr[srcPtr[offset + i] + 128]/expSum[i])));
-                    }
-                }
-            }
+        if (!coerced_2d && is_transpose_needed) {
+            transposeND(dst, permutation, outputs[0]);
         }
     }

@@ -268,7 +393,14 @@ public:
         return flops;
     }

-    int axisRaw;
+private:
+    int axis;
+    int N;
+    int D;
+    bool coerced_2d;
+    bool is_transpose_needed;
+    std::vector<int> permutation;
+    std::vector<int> transposed_shape;
 };

 Ptr<SoftmaxLayerInt8> SoftmaxLayerInt8::create(const LayerParams& params)
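forward() relies on the N/D bookkeeping done in finalize(): when the softmax axis is not innermost, the input is transposed so each row of length D is contiguous, the row-wise kernel runs, and the result is transposed back. A sketch of that shape planning under the same conventions (planRows is an illustrative name; finalize() builds the permutation by assignment, which is equivalent to the swap below):

```cpp
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Given an input shape and a normalized softmax axis, compute the
// permutation that swaps `axis` with the innermost dimension, plus the
// row count N and row length D consumed by the row-wise kernels.
void planRows(const std::vector<int>& shape, int axis,
              std::vector<int>& perm, int& N, int& D) {
    const int dims = (int)shape.size();
    perm.resize(dims);
    std::iota(perm.begin(), perm.end(), 0);
    std::swap(perm[axis], perm[dims - 1]); // identity when axis is already last

    std::vector<int> tshape(dims);
    for (int i = 0; i < dims; i++) tshape[i] = shape[perm[i]];
    D = tshape.back();
    N = std::accumulate(tshape.begin(), tshape.end() - 1, 1, std::multiplies<int>());
}
```

For shape {2, 3, 4} and axis 1, this yields perm = {0, 2, 1}, D = 3, and N = 8.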
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index 196928b3cd..508ba0299e 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -206,6 +206,7 @@ private:
     void parseQAvgPool           (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseQConcat            (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseQGemm              (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseQSoftmax           (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);

     // '???' domain or '???' layer type
     void parseCustomLayer        (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
@@ -758,6 +759,7 @@ static bool ifInt8Output(const String& layerType)
         "QLinearSigmoid",
         "QLinearConcat",
         "QGemm",
+        "QLinearSoftmax",
         "QLinearConv",
         "QLinearMatMul",
         "MaxPool",
@@ -3929,6 +3931,29 @@ void ONNXImporter::parseQConcat(LayerParams& layerParams, const opencv_onnx::Nod
     addLayer(layerParams, node_proto);
 }

+void ONNXImporter::parseQSoftmax(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    CV_CheckEQ(node_proto.input_size(), 5, "DNN/ONNX: QLinearSoftmax requires 5 inputs: X, X_scale, X_zero_point, Y_scale, Y_zero_point");
+
+    int opset = layerParams.get<int>("opset");
+    if (opset < 13) {
+        layerParams.set("coerced_2d", true);
+    }
+
+    float x_scale = getScalarFromMat<float>(getBlob(node_proto, 1));
+    int8_t x_zero_point = getScalarFromMat<int8_t>(getBlob(node_proto, 2));
+    float y_scale = getScalarFromMat<float>(getBlob(node_proto, 3));
+    int8_t y_zero_point = getScalarFromMat<int8_t>(getBlob(node_proto, 4));
+
+    layerParams.type = "SoftmaxInt8";
+    // layerParams also has "axis" and "opset" attrs
+    layerParams.set("input_scale", x_scale);
+    layerParams.set("input_zeropoint", x_zero_point);
+    layerParams.set("scales", y_scale);
+    layerParams.set("zeropoints", y_zero_point);
+    addLayer(layerParams, node_proto);
+}
+
 // Domain: ai.onnx (default)
 // URL: https://github.com/onnx/onnx/blob/master/docs/Operators.md
 void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
@@ -4026,6 +4051,7 @@ void ONNXImporter::buildDispatchMap_COM_MICROSOFT(int opset_version)
     dispatch["QLinearSigmoid"] = &ONNXImporter::parseQSigmoid;
     dispatch["QLinearConcat"] = &ONNXImporter::parseQConcat;
     dispatch["QGemm"] = &ONNXImporter::parseQGemm;
+    dispatch["QLinearSoftmax"] = &ONNXImporter::parseQSoftmax;

     domain_dispatch_map["com.microsoft"] = dispatch;
 }
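parseQSoftmax maps the com.microsoft QLinearSoftmax contract, y = quantize(softmax(dequantize(x))), onto the SoftmaxInt8 layer. For opset < 13, ONNX Softmax semantics coerce the input to 2D, normalizing over all dimensions from `axis` onward, hence the coerced_2d flag; from opset 13 on, normalization runs over the single axis. A float reference of the per-row contract, useful for intuition (a sketch only, not the code path the layer actually takes):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference semantics of QLinearSoftmax over one row:
// y = quantize(softmax(dequantize(x))).
std::vector<int8_t> qlinearSoftmaxRow(const std::vector<int8_t>& x,
                                      float x_scale, int8_t x_zp,
                                      float y_scale, int8_t y_zp) {
    std::vector<float> e(x.size());
    float mx = -1e30f, sum = 0.f;
    for (int8_t q : x) mx = std::max(mx, x_scale * (q - x_zp)); // dequantize, track max
    for (size_t i = 0; i < x.size(); i++) {
        e[i] = std::exp(x_scale * (x[i] - x_zp) - mx); // numerically stable exp
        sum += e[i];
    }
    std::vector<int8_t> y(x.size());
    for (size_t i = 0; i < x.size(); i++) {
        float q = std::nearbyintf(e[i] / sum / y_scale) + y_zp; // requantize
        y[i] = (int8_t)std::max(-128.f, std::min(127.f, q));
    }
    return y;
}
```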
testONNXModels("output_registration", npy, 0, 0, false, true, 2); } +TEST_P(Test_ONNX_layers, QLinearSoftmax) +{ + testONNXModels("qlinearsoftmax_v11", npy, 0.002, 0.002); // 2D coerced + testONNXModels("qlinearsoftmax_v13", npy, 0.002, 0.002); +} + INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets()); class Test_ONNX_nets : public Test_ONNX_layers