Merge pull request #21692 from UnaNancyOwen:add_softmax

* add apply softmax option to ClassificationModel

* remove default arguments of ClassificationModel::setSoftMax()

* fix build for python

* fix docs warning for setSoftMax()

* add impl for ClassificationModel()

* fix docs build failure caused by trailing whitespace

* move classify() implementation to ClassificationModel_Impl

* move softmax() implementation to ClassificationModel_Impl

* remove softmax from public method in ClassificationModel
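
For context, a minimal usage sketch of the new option from the caller's side; the model file name and input geometry below are placeholders for illustration, not part of this PR:

    #include <opencv2/dnn.hpp>
    #include <opencv2/imgcodecs.hpp>

    int main()
    {
        // Hypothetical classification model exported without a softmax layer.
        cv::dnn::ClassificationModel model("squeezenet.onnx");
        model.setInputParams(1.0 / 255.0, cv::Size(224, 224), cv::Scalar(), /*swapRB=*/true);

        // New in this PR: apply softmax inside classify() so the returned
        // confidence is a probability in [0.0, 1.0] rather than a raw logit.
        model.setEnableSoftmaxPostProcessing(true);

        int classId = 0;
        float conf = 0.f;
        model.classify(cv::imread("image.jpg"), classId, conf);
        return 0;
    }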
Tsukasa Sugiura 2022-03-08 05:26:15 +09:00 committed by GitHub
parent 901e0ddfe4
commit 8db7d435b9
2 changed files with 99 additions and 11 deletions


@@ -1310,6 +1310,9 @@ CV__DNN_INLINE_NS_BEGIN
     class CV_EXPORTS_W_SIMPLE ClassificationModel : public Model
     {
     public:
+        CV_DEPRECATED_EXTERNAL // avoid using in C++ code, will be moved to "protected" (need to fix bindings first)
+        ClassificationModel();
+
         /**
          * @brief Create classification model from network represented in one of the supported formats.
          * An order of @p model and @p config arguments does not matter.
@@ -1324,6 +1327,24 @@ CV__DNN_INLINE_NS_BEGIN
          */
         CV_WRAP ClassificationModel(const Net& network);
 
+        /**
+         * @brief Set the enable/disable softmax post-processing option.
+         *
+         * If this option is true, softmax is applied after forward inference within the classify() function
+         * to map the confidence values into the range [0.0, 1.0].
+         * Set this option to true when the model does not contain a softmax layer itself.
+         * @param[in] enable Whether to apply softmax post-processing within the classify() function.
+         */
+        CV_WRAP ClassificationModel& setEnableSoftmaxPostProcessing(bool enable);
+
+        /**
+         * @brief Get the softmax post-processing option.
+         *
+         * This option defaults to false; softmax post-processing is not applied within the classify() function.
+         */
+        CV_WRAP bool getEnableSoftmaxPostProcessing() const;
+
         /** @brief Given the @p input frame, create input blob, run net and return top-1 prediction.
          * @param[in] frame The input image.
          */
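
Since setEnableSoftmaxPostProcessing() returns the model by reference, calls can be chained, and the getter reports the current state; a brief sketch, reusing the hypothetical model object from the example after the commit message above:

    // Setter returns ClassificationModel&, so it chains with other calls.
    model.setEnableSoftmaxPostProcessing(true);
    CV_Assert(model.getEnableSoftmaxPostProcessing());

    // Top-1 prediction as {classId, confidence}. With the option enabled,
    // confidence lies in [0.0, 1.0]; with it disabled (the default), it is
    // whatever the network's last layer produces.
    std::pair<int, float> top1 = model.classify(cv::imread("image.jpg"));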


@@ -197,28 +197,95 @@ void Model::predict(InputArray frame, OutputArrayOfArrays outs) const
 }
 
+class ClassificationModel_Impl : public Model::Impl
+{
+public:
+    virtual ~ClassificationModel_Impl() {}
+    ClassificationModel_Impl() : Impl() {}
+    ClassificationModel_Impl(const ClassificationModel_Impl&) = delete;
+    ClassificationModel_Impl(ClassificationModel_Impl&&) = delete;
+
+    void setEnableSoftmaxPostProcessing(bool enable)
+    {
+        applySoftmax = enable;
+    }
+
+    bool getEnableSoftmaxPostProcessing() const
+    {
+        return applySoftmax;
+    }
+
+    std::pair<int, float> classify(InputArray frame)
+    {
+        std::vector<Mat> outs;
+        processFrame(frame, outs);
+        CV_Assert(outs.size() == 1);
+
+        Mat out = outs[0].reshape(1, 1);
+
+        if (getEnableSoftmaxPostProcessing())
+        {
+            softmax(out, out);
+        }
+
+        double conf;
+        cv::Point maxLoc;
+        cv::minMaxLoc(out, nullptr, &conf, nullptr, &maxLoc);
+
+        return {maxLoc.x, static_cast<float>(conf)};
+    }
+
+protected:
+    void softmax(InputArray inblob, OutputArray outblob)
+    {
+        const Mat input = inblob.getMat();
+        outblob.create(inblob.size(), inblob.type());
+
+        Mat exp;
+        const float max = *std::max_element(input.begin<float>(), input.end<float>());
+        cv::exp((input - max), exp);
+        outblob.getMat() = exp / cv::sum(exp)[0];
+    }
+
+protected:
+    bool applySoftmax = false;
+};
+
+ClassificationModel::ClassificationModel()
+    : Model()
+{
+    // nothing
+}
+
 ClassificationModel::ClassificationModel(const String& model, const String& config)
-    : Model(model, config)
+    : ClassificationModel(readNet(model, config))
 {
     // nothing
 }
 
 ClassificationModel::ClassificationModel(const Net& network)
-    : Model(network)
+    : Model()
 {
-    // nothing
+    impl = makePtr<ClassificationModel_Impl>();
+    impl->initNet(network);
+}
+
+ClassificationModel& ClassificationModel::setEnableSoftmaxPostProcessing(bool enable)
+{
+    CV_Assert(impl != nullptr && impl.dynamicCast<ClassificationModel_Impl>() != nullptr);
+    impl.dynamicCast<ClassificationModel_Impl>()->setEnableSoftmaxPostProcessing(enable);
+    return *this;
+}
+
+bool ClassificationModel::getEnableSoftmaxPostProcessing() const
+{
+    CV_Assert(impl != nullptr && impl.dynamicCast<ClassificationModel_Impl>() != nullptr);
+    return impl.dynamicCast<ClassificationModel_Impl>()->getEnableSoftmaxPostProcessing();
 }
 
 std::pair<int, float> ClassificationModel::classify(InputArray frame)
 {
-    std::vector<Mat> outs;
-    impl->processFrame(frame, outs);
-    CV_Assert(outs.size() == 1);
-
-    double conf;
-    cv::Point maxLoc;
-    minMaxLoc(outs[0].reshape(1, 1), nullptr, &conf, nullptr, &maxLoc);
-
-    return {maxLoc.x, static_cast<float>(conf)};
+    CV_Assert(impl != nullptr && impl.dynamicCast<ClassificationModel_Impl>() != nullptr);
+    return impl.dynamicCast<ClassificationModel_Impl>()->classify(frame);
 }
 
 void ClassificationModel::classify(InputArray frame, int& classId, float& conf)
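
For reference, the softmax in ClassificationModel_Impl above subtracts the row maximum before exponentiating; this is the standard numerically stable formulation, since exp(x_i - max) cannot overflow and the shift cancels out in the normalization. A standalone sketch of the same computation on a plain vector, for illustration only and not part of the PR:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
    // Shifting by max(x) keeps exp() in a safe range; the factor exp(-max(x))
    // cancels between numerator and denominator, so the result is unchanged.
    std::vector<float> softmax(const std::vector<float>& x)
    {
        const float m = *std::max_element(x.begin(), x.end());
        std::vector<float> y(x.size());
        float sum = 0.f;
        for (size_t i = 0; i < x.size(); ++i)
        {
            y[i] = std::exp(x[i] - m);
            sum += y[i];
        }
        for (float& v : y)
            v /= sum;
        return y;
    }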