Merge pull request #25613 from CNOCycle:tflite/ops

Support Global_Pool_2D ops in .tflite model #25613

### Pull Request Readiness Checklist

**Merge with extra**: https://github.com/opencv/opencv_extra/pull/1180

This PR adds support for `GlobalAveragePooling2D` and `GlobalMaxPool2D` on the TFlite backend. When the k`eep_dims` option is enabled, the output is a 2D tensor, necessitating the inclusion of an additional flatten layer. Additionally, the names of these layers have been updated to match the output tensor names generated by `generate.py` from the opencv_extra repository.

- [X] I agree to contribute to the project under Apache 2 License.
- [X] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [X] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [X] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [X] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
CNOCycle 2024-06-01 00:31:21 +08:00 committed by GitHub
parent 29f91a08d5
commit 98b8825031
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 0 deletions

View File

@ -71,6 +71,7 @@ private:
void parseSoftmax(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseSoftmax(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseCast(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseCast(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseTranspose(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseTranspose(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseGlobalPooling(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseFusedActivation(const Operator& op, ActivationFunctionType activ); void parseFusedActivation(const Operator& op, ActivationFunctionType activ);
void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams, bool isFused); void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams, bool isFused);
@ -78,6 +79,8 @@ private:
int addPermuteLayer(const std::vector<int>& order, const std::string& permName, const std::pair<int, int>& inpId, int dtype); int addPermuteLayer(const std::vector<int>& order, const std::string& permName, const std::pair<int, int>& inpId, int dtype);
int addReshapeLayer(const std::vector<int>& shape, int axis, int num_axes, int addReshapeLayer(const std::vector<int>& shape, int axis, int num_axes,
const std::string& name, const std::pair<int, int>& inpId, int dtype); const std::string& name, const std::pair<int, int>& inpId, int dtype);
int addFlattenLayer(int axis, int end_axis, const std::string& name, const std::pair<int, int>& inpId, int dtype);
inline bool isInt8(const Operator& op); inline bool isInt8(const Operator& op);
inline void getQuantParams(const Operator& op, float& inpScale, int& inpZero, float& outScale, int& outZero); inline void getQuantParams(const Operator& op, float& inpScale, int& inpZero, float& outScale, int& outZero);
}; };
@ -286,6 +289,7 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap()
dispatch["CAST"] = &TFLiteImporter::parseCast; dispatch["CAST"] = &TFLiteImporter::parseCast;
dispatch["TFLite_Detection_PostProcess"] = &TFLiteImporter::parseDetectionPostProcess; dispatch["TFLite_Detection_PostProcess"] = &TFLiteImporter::parseDetectionPostProcess;
dispatch["TRANSPOSE"] = &TFLiteImporter::parseTranspose; dispatch["TRANSPOSE"] = &TFLiteImporter::parseTranspose;
dispatch["MEAN"] = dispatch["REDUCE_MAX"] = &TFLiteImporter::parseGlobalPooling;
return dispatch; return dispatch;
} }
@ -764,6 +768,37 @@ void TFLiteImporter::parseTranspose(const Operator& op, const std::string& opcod
addLayer(layerParams, op); addLayer(layerParams, op);
} }
void TFLiteImporter::parseGlobalPooling(const Operator& op, const std::string& opcode, LayerParams& layerParams)
{
layerParams.type = "Pooling";
if(opcode == "MEAN") {
layerParams.set("pool", "ave");
}
else if (opcode == "REDUCE_MAX") {
layerParams.set("pool", "max");
}
else {
CV_Error(Error::StsNotImplemented, "Unsupported pooling " + opcode);
}
layerParams.set("global_pooling", true);
auto options = op.builtin_options_as_ReducerOptions();
bool keep_dims = options->keep_dims();
if (!keep_dims) {
const auto name = layerParams.name;
layerParams.name += "/global_pooling";
addLayer(layerParams, op);
int out = op.outputs()->Get(0);
auto outId = layerIds[out];
int flattenId = addFlattenLayer(1, -1, name, outId, isInt8(op) ? CV_8S : CV_32F);
layerIds[out] = std::make_pair(flattenId, 0);
}
else {
addLayer(layerParams, op);
}
}
int TFLiteImporter::addPermuteLayer(const std::vector<int>& order, const std::string& permName, int TFLiteImporter::addPermuteLayer(const std::vector<int>& order, const std::string& permName,
const std::pair<int, int>& inpId, int dtype) const std::pair<int, int>& inpId, int dtype)
{ {
@ -786,6 +821,16 @@ int TFLiteImporter::addReshapeLayer(const std::vector<int>& shape, int axis, int
return id; return id;
} }
int TFLiteImporter::addFlattenLayer(int axis, int end_axis, const std::string& name, const std::pair<int, int>& inpId, int dtype)
{
LayerParams lp;
lp.set("axis", axis);
lp.set("end_axis", end_axis);
int id = dstNet.addLayer(name, "Flatten", dtype, lp);
dstNet.connect(inpId.first, inpId.second, id, 0);
return id;
}
void TFLiteImporter::parseDeconvolution(const Operator& op, const std::string& opcode, LayerParams& layerParams) { void TFLiteImporter::parseDeconvolution(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Deconvolution"; layerParams.type = "Deconvolution";

View File

@ -260,6 +260,14 @@ TEST_P(Test_TFLite, permute) {
testLayer("permutation_4d_0231"); testLayer("permutation_4d_0231");
} }
TEST_P(Test_TFLite, global_average_pooling_2d) {
testLayer("global_average_pooling_2d");
}
TEST_P(Test_TFLite, global_max_pooling_2d) {
testLayer("global_max_pooling_2d");
}
INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets()); INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets());
}} // namespace }} // namespace