mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 09:25:45 +08:00
Merge pull request #15082 from dvd42:segmentation-module
Segmentation module (#15082)
This commit is contained in:
parent
2ad0487cec
commit
f7f2438478
@ -1109,6 +1109,36 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
CV_WRAP void classify(InputArray frame, CV_OUT int& classId, CV_OUT float& conf);
|
||||
};
|
||||
|
||||
/** @brief This class represents high-level API for segmentation models
|
||||
*
|
||||
* SegmentationModel allows to set params for preprocessing input image.
|
||||
* SegmentationModel creates net from file with trained weights and config,
|
||||
* sets preprocessing input, runs forward pass and returns the class prediction for each pixel.
|
||||
*/
|
||||
class CV_EXPORTS_W SegmentationModel: public Model
|
||||
{
|
||||
public:
|
||||
/**
|
||||
* @brief Create segmentation model from network represented in one of the supported formats.
|
||||
* An order of @p model and @p config arguments does not matter.
|
||||
* @param[in] model Binary file contains trained weights.
|
||||
* @param[in] config Text file contains network configuration.
|
||||
*/
|
||||
CV_WRAP SegmentationModel(const String& model, const String& config = "");
|
||||
|
||||
/**
|
||||
* @brief Create model from deep learning network.
|
||||
* @param[in] network Net object.
|
||||
*/
|
||||
CV_WRAP SegmentationModel(const Net& network);
|
||||
|
||||
/** @brief Given the @p input frame, create input blob, run net
|
||||
* @param[in] frame The input image.
|
||||
* @param[out] mask Allocated class prediction for each pixel
|
||||
*/
|
||||
CV_WRAP void segment(InputArray frame, OutputArray mask);
|
||||
};
|
||||
|
||||
/** @brief This class represents high-level API for object detection networks.
|
||||
*
|
||||
* DetectionModel allows to set params for preprocessing input image.
|
||||
|
@ -137,6 +137,47 @@ void ClassificationModel::classify(InputArray frame, int& classId, float& conf)
|
||||
std::tie(classId, conf) = classify(frame);
|
||||
}
|
||||
|
||||
// Creates the model from files: both arguments are forwarded unchanged to the
// Model base-class constructor; no segmentation-specific state is set up here.
SegmentationModel::SegmentationModel(const String& model, const String& config)
    : Model(model, config)
{
}
|
||||
|
||||
// Creates the model from an existing Net: the network is forwarded unchanged to
// the Model base-class constructor.
SegmentationModel::SegmentationModel(const Net& network)
    : Model(network)
{
}
|
||||
|
||||
/**
 * Runs the network on @p frame and stores, for every output pixel, the index of
 * the highest-scoring channel in @p mask.
 *
 * The single network output is read as a 4-D CV_32F blob indexed
 * [batch, channel, row, col]; only batch item 0 is used. @p mask is (re)created
 * as a rows x cols CV_8U matrix, so it has the spatial size of the network
 * output. On equal scores the lower channel index wins, because a channel only
 * replaces the current best on a strictly greater score.
 *
 * @note Channel 0 of the output blob doubles as the running-maximum buffer and
 *       is overwritten in place.
 */
void SegmentationModel::segment(InputArray frame, OutputArray mask)
{
    std::vector<Mat> outs;
    impl->predict(*this, frame.getMat(), outs);
    CV_Assert(outs.size() == 1);
    Mat score = outs[0];

    // size[1..3] and the three-index ptr<float>() calls below are only valid
    // for a 4-D float blob, so reject anything else up front.
    CV_Assert(score.dims == 4);
    CV_Assert(score.depth() == CV_32F);

    const int chns = score.size[1];
    const int rows = score.size[2];
    const int cols = score.size[3];

    // Class ids are written into an 8-bit mask; a larger channel count would
    // silently wrap around when narrowed to uint8_t.
    CV_Assert(chns <= 256);

    mask.create(rows, cols, CV_8U);
    Mat classIds = mask.getMat();
    classIds.setTo(0);

    // Channel 0 is the initial best score (class id 0); every later channel
    // challenges it pixel by pixel.
    for (int ch = 1; ch < chns; ch++)
    {
        for (int row = 0; row < rows; row++)
        {
            const float *ptrScore = score.ptr<float>(0, ch, row);
            uint8_t *ptrMaxCl = classIds.ptr<uint8_t>(row);
            // Address channel 0 through the blob's own strides rather than
            // through a separate header that assumes densely packed rows.
            float *ptrMaxVal = score.ptr<float>(0, 0, row);
            for (int col = 0; col < cols; col++)
            {
                if (ptrScore[col] > ptrMaxVal[col])
                {
                    ptrMaxVal[col] = ptrScore[col];
                    ptrMaxCl[col] = static_cast<uint8_t>(ch);
                }
            }
        }
    }
}
|
||||
|
||||
// Creates the model from files: both arguments are forwarded unchanged to the
// Model base-class constructor.
DetectionModel::DetectionModel(const String& model, const String& config)
    : Model(model, config)
{
}
|
||||
|
||||
|
@ -69,6 +69,25 @@ public:
|
||||
EXPECT_EQ(prediction.first, ref.first);
|
||||
ASSERT_NEAR(prediction.second, ref.second, norm);
|
||||
}
|
||||
|
||||
void testSegmentationModel(const std::string& weights_file, const std::string& config_file,
|
||||
const std::string& inImgPath, const std::string& outImgPath,
|
||||
float norm, const Size& size = {-1, -1}, Scalar mean = Scalar(),
|
||||
double scale = 1.0, bool swapRB = false, bool crop = false)
|
||||
{
|
||||
checkBackend();
|
||||
|
||||
Mat frame = imread(inImgPath);
|
||||
Mat mask;
|
||||
Mat exp = imread(outImgPath, 0);
|
||||
|
||||
SegmentationModel model(weights_file, config_file);
|
||||
model.setInputSize(size).setInputMean(mean).setInputScale(scale)
|
||||
.setInputSwapRB(swapRB).setInputCrop(crop);
|
||||
|
||||
model.segment(frame, mask);
|
||||
normAssert(mask, exp, "", norm, norm);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(Test_Model, Classify)
|
||||
@ -202,6 +221,22 @@ TEST_P(Test_Model, DetectionMobilenetSSD)
|
||||
scoreDiff, iouDiff, confThreshold, nmsThreshold, size, mean, scale);
|
||||
}
|
||||
|
||||
// Segments dog416.png with the fcn8s-heavy-pascal Caffe network at a 128x128
// input size and compares the mask with segmentation_exp.png using a tolerance of 0.
TEST_P(Test_Model, Segmentation)
{
    std::string inp = _tf("dog416.png");
    // The .caffemodel carries the trained weights and the .prototxt the network
    // description, so each variable now holds what its name says.
    std::string weights_file = _tf("fcn8s-heavy-pascal.caffemodel");
    std::string config_file = _tf("fcn8s-heavy-pascal.prototxt");
    std::string exp = _tf("segmentation_exp.png");

    Size size{128, 128};
    float norm = 0;
    double scale = 1.0;
    Scalar mean = Scalar();
    bool swapRB = false;

    testSegmentationModel(weights_file, config_file, inp, exp, norm, size, mean, scale, swapRB);
}
|
||||
|
||||
// Instantiate the parameterized Test_Model tests with the values produced by
// dnnBackendsAndTargets(); the first macro argument (instantiation prefix) is left empty.
INSTANTIATE_TEST_CASE_P(/**/, Test_Model, dnnBackendsAndTargets());
|
||||
|
||||
}} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user