diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp index 97033a313e..67042a14b7 100644 --- a/modules/dnn/include/opencv2/dnn/dnn.hpp +++ b/modules/dnn/include/opencv2/dnn/dnn.hpp @@ -1310,6 +1310,9 @@ CV__DNN_INLINE_NS_BEGIN class CV_EXPORTS_W_SIMPLE ClassificationModel : public Model { public: + CV_DEPRECATED_EXTERNAL // avoid using in C++ code, will be moved to "protected" (need to fix bindings first) + ClassificationModel(); + /** * @brief Create classification model from network represented in one of the supported formats. * An order of @p model and @p config arguments does not matter. @@ -1324,6 +1327,24 @@ CV__DNN_INLINE_NS_BEGIN */ CV_WRAP ClassificationModel(const Net& network); + /** + * @brief Set enable/disable softmax post processing option. + * + * If this option is true, softmax is applied after forward inference within the classify() function + * to convert the confidences range to [0.0-1.0]. + * This function allows you to toggle this behavior. + * Please turn true when not contain softmax layer in model. + * @param[in] enable Set enable softmax post processing within the classify() function. + */ + CV_WRAP ClassificationModel& setEnableSoftmaxPostProcessing(bool enable); + + /** + * @brief Get enable/disable softmax post processing option. + * + * This option defaults to false, softmax post processing is not applied within the classify() function. + */ + CV_WRAP bool getEnableSoftmaxPostProcessing() const; + /** @brief Given the @p input frame, create input blob, run net and return top-1 prediction. * @param[in] frame The input image. */ diff --git a/modules/dnn/src/model.cpp b/modules/dnn/src/model.cpp index bc8709d22e..22d5681d5b 100644 --- a/modules/dnn/src/model.cpp +++ b/modules/dnn/src/model.cpp @@ -197,28 +197,95 @@ void Model::predict(InputArray frame, OutputArrayOfArrays outs) const } +class ClassificationModel_Impl : public Model::Impl +{ +public: + virtual ~ClassificationModel_Impl() {} + ClassificationModel_Impl() : Impl() {} + ClassificationModel_Impl(const ClassificationModel_Impl&) = delete; + ClassificationModel_Impl(ClassificationModel_Impl&&) = delete; + + void setEnableSoftmaxPostProcessing(bool enable) + { + applySoftmax = enable; + } + + bool getEnableSoftmaxPostProcessing() const + { + return applySoftmax; + } + + std::pair classify(InputArray frame) + { + std::vector outs; + processFrame(frame, outs); + CV_Assert(outs.size() == 1); + + Mat out = outs[0].reshape(1, 1); + + if(getEnableSoftmaxPostProcessing()) + { + softmax(out, out); + } + + double conf; + Point maxLoc; + cv::minMaxLoc(out, nullptr, &conf, nullptr, &maxLoc); + return {maxLoc.x, static_cast(conf)}; + } + +protected: + void softmax(InputArray inblob, OutputArray outblob) + { + const Mat input = inblob.getMat(); + outblob.create(inblob.size(), inblob.type()); + + Mat exp; + const float max = *std::max_element(input.begin(), input.end()); + cv::exp((input - max), exp); + outblob.getMat() = exp / cv::sum(exp)[0]; + } + +protected: + bool applySoftmax = false; +}; + +ClassificationModel::ClassificationModel() + : Model() +{ + // nothing +} + ClassificationModel::ClassificationModel(const String& model, const String& config) - : Model(model, config) + : ClassificationModel(readNet(model, config)) { // nothing } ClassificationModel::ClassificationModel(const Net& network) - : Model(network) + : Model() { - // nothing + impl = makePtr(); + impl->initNet(network); +} + +ClassificationModel& ClassificationModel::setEnableSoftmaxPostProcessing(bool enable) +{ + CV_Assert(impl != nullptr && impl.dynamicCast() != nullptr); + impl.dynamicCast()->setEnableSoftmaxPostProcessing(enable); + return *this; +} + +bool ClassificationModel::getEnableSoftmaxPostProcessing() const +{ + CV_Assert(impl != nullptr && impl.dynamicCast() != nullptr); + return impl.dynamicCast()->getEnableSoftmaxPostProcessing(); } std::pair ClassificationModel::classify(InputArray frame) { - std::vector outs; - impl->processFrame(frame, outs); - CV_Assert(outs.size() == 1); - - double conf; - cv::Point maxLoc; - minMaxLoc(outs[0].reshape(1, 1), nullptr, &conf, nullptr, &maxLoc); - return {maxLoc.x, static_cast(conf)}; + CV_Assert(impl != nullptr && impl.dynamicCast() != nullptr); + return impl.dynamicCast()->classify(frame); } void ClassificationModel::classify(InputArray frame, int& classId, float& conf)