mirror of
https://github.com/opencv/opencv.git
synced 2025-07-25 22:57:53 +08:00
Add global_pooling_dim flags
This commit is contained in:
parent
7eba3a7c96
commit
a33d50084d
@ -150,18 +150,12 @@ void getPoolingKernelParams(const LayerParams ¶ms, std::vector<size_t>& kern
|
|||||||
{
|
{
|
||||||
bool is_global = params.get<bool>("global_pooling", false);
|
bool is_global = params.get<bool>("global_pooling", false);
|
||||||
globalPooling = std::vector<bool>(3, is_global);
|
globalPooling = std::vector<bool>(3, is_global);
|
||||||
if (params.has("global_d"))
|
if (params.has("global_pooling_d"))
|
||||||
{
|
globalPooling[0] = params.get<bool>("global_pooling_d");
|
||||||
globalPooling[0] = params.get<bool>("global_d");
|
else if (params.has("global_pooling_h"))
|
||||||
}
|
globalPooling[1] = params.get<bool>("global_pooling_h");
|
||||||
else if (params.has("global_h"))
|
else if (params.has("global_pooling_w"))
|
||||||
{
|
globalPooling[2] = params.get<bool>("global_pooling_w");
|
||||||
globalPooling[1] = params.get<bool>("global_h");
|
|
||||||
}
|
|
||||||
else if (params.has("global_w"))
|
|
||||||
{
|
|
||||||
globalPooling[2] = params.get<bool>("global_w");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_global)
|
if (is_global)
|
||||||
{
|
{
|
||||||
|
@ -1961,7 +1961,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
CV_Assert(layer_id.find(avgName) == layer_id.end());
|
CV_Assert(layer_id.find(avgName) == layer_id.end());
|
||||||
avgLp.set("pool", "ave");
|
avgLp.set("pool", "ave");
|
||||||
// pooling kernel H x 1
|
// pooling kernel H x 1
|
||||||
avgLp.set("global_h", true);
|
avgLp.set("global_pooling_h", true);
|
||||||
avgLp.set("kernel_size", 1);
|
avgLp.set("kernel_size", 1);
|
||||||
int avgId = dstNet.addLayer(avgName, "Pooling", avgLp);
|
int avgId = dstNet.addLayer(avgName, "Pooling", avgLp);
|
||||||
layer_id[avgName] = avgId;
|
layer_id[avgName] = avgId;
|
||||||
|
Loading…
Reference in New Issue
Block a user