mirror of
https://github.com/opencv/opencv.git
synced 2025-07-24 22:16:27 +08:00
Add global_pooling_dim flags
This commit is contained in:
parent
7eba3a7c96
commit
a33d50084d
@ -150,18 +150,12 @@ void getPoolingKernelParams(const LayerParams ¶ms, std::vector<size_t>& kern
|
||||
{
|
||||
bool is_global = params.get<bool>("global_pooling", false);
|
||||
globalPooling = std::vector<bool>(3, is_global);
|
||||
if (params.has("global_d"))
|
||||
{
|
||||
globalPooling[0] = params.get<bool>("global_d");
|
||||
}
|
||||
else if (params.has("global_h"))
|
||||
{
|
||||
globalPooling[1] = params.get<bool>("global_h");
|
||||
}
|
||||
else if (params.has("global_w"))
|
||||
{
|
||||
globalPooling[2] = params.get<bool>("global_w");
|
||||
}
|
||||
if (params.has("global_pooling_d"))
|
||||
globalPooling[0] = params.get<bool>("global_pooling_d");
|
||||
else if (params.has("global_pooling_h"))
|
||||
globalPooling[1] = params.get<bool>("global_pooling_h");
|
||||
else if (params.has("global_pooling_w"))
|
||||
globalPooling[2] = params.get<bool>("global_pooling_w");
|
||||
|
||||
if (is_global)
|
||||
{
|
||||
|
@ -1961,7 +1961,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
CV_Assert(layer_id.find(avgName) == layer_id.end());
|
||||
avgLp.set("pool", "ave");
|
||||
// pooling kernel H x 1
|
||||
avgLp.set("global_h", true);
|
||||
avgLp.set("global_pooling_h", true);
|
||||
avgLp.set("kernel_size", 1);
|
||||
int avgId = dstNet.addLayer(avgName, "Pooling", avgLp);
|
||||
layer_id[avgName] = avgId;
|
||||
|
Loading…
Reference in New Issue
Block a user