mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Update global pooling
This commit is contained in:
parent
cf477f7e9f
commit
752653c70b
@ -251,7 +251,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
|||||||
CV_DEPRECATED_EXTERNAL Size kernel, stride, pad;
|
CV_DEPRECATED_EXTERNAL Size kernel, stride, pad;
|
||||||
CV_DEPRECATED_EXTERNAL int pad_l, pad_t, pad_r, pad_b;
|
CV_DEPRECATED_EXTERNAL int pad_l, pad_t, pad_r, pad_b;
|
||||||
bool globalPooling;
|
bool globalPooling;
|
||||||
int global_axis;
|
std::vector<bool> isGlobalPooling;
|
||||||
bool computeMaxIdx;
|
bool computeMaxIdx;
|
||||||
String padMode;
|
String padMode;
|
||||||
bool ceilMode;
|
bool ceilMode;
|
||||||
|
@ -122,9 +122,17 @@ public:
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
CV_Error(Error::StsBadArg, "Cannot determine pooling type");
|
CV_Error(Error::StsBadArg, "Cannot determine pooling type");
|
||||||
|
|
||||||
setParamsFrom(params);
|
setParamsFrom(params);
|
||||||
ceilMode = params.get<bool>("ceil_mode", true);
|
ceilMode = params.get<bool>("ceil_mode", true);
|
||||||
global_axis = params.get<int>("global_axis", -1);
|
if (params.has("is_global_pooling"))
|
||||||
|
{
|
||||||
|
const DictValue &global_axis = params.get("is_global_pooling");
|
||||||
|
int size = global_axis.size();
|
||||||
|
isGlobalPooling.resize(size);
|
||||||
|
for (int i = 0; i < size; i++)
|
||||||
|
isGlobalPooling[i] = global_axis.get<bool>(i);
|
||||||
|
}
|
||||||
spatialScale = params.get<float>("spatial_scale", 1);
|
spatialScale = params.get<float>("spatial_scale", 1);
|
||||||
avePoolPaddedArea = params.get<bool>("ave_pool_padded_area", true);
|
avePoolPaddedArea = params.get<bool>("ave_pool_padded_area", true);
|
||||||
}
|
}
|
||||||
@ -150,8 +158,12 @@ public:
|
|||||||
if (globalPooling) {
|
if (globalPooling) {
|
||||||
kernel = Size(inp[1], inp[0]);
|
kernel = Size(inp[1], inp[0]);
|
||||||
kernel_size = std::vector<size_t>(inp.begin(), inp.end());
|
kernel_size = std::vector<size_t>(inp.begin(), inp.end());
|
||||||
} else if (global_axis != -1) {
|
} else if (!isGlobalPooling.empty()) {
|
||||||
kernel_size[global_axis] = inp[global_axis];
|
for (int i = 0; i < isGlobalPooling.size(); i++)
|
||||||
|
{
|
||||||
|
if (isGlobalPooling[i])
|
||||||
|
kernel_size[i] = inp[i];
|
||||||
|
}
|
||||||
kernel = Size(kernel_size[1], kernel_size[0]);
|
kernel = Size(kernel_size[1], kernel_size[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1041,10 +1053,14 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
|
|||||||
outShape[0] = inputs[1][0]; // Number of proposals;
|
outShape[0] = inputs[1][0]; // Number of proposals;
|
||||||
outShape[1] = psRoiOutChannels;
|
outShape[1] = psRoiOutChannels;
|
||||||
}
|
}
|
||||||
else if (global_axis != -1)
|
else if (!isGlobalPooling.empty())
|
||||||
{
|
{
|
||||||
CV_Assert(global_axis >= 0 && global_axis < inpShape.size());
|
CV_Assert(isGlobalPooling.size() == inpShape.size());
|
||||||
outShape[2 + global_axis] = 1;
|
for (int i = 0; i < isGlobalPooling.size(); i++)
|
||||||
|
{
|
||||||
|
if (isGlobalPooling[i])
|
||||||
|
outShape[2 + i] = 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int numOutputs = requiredOutputs ? requiredOutputs : (type == MAX ? 2 : 1);
|
int numOutputs = requiredOutputs ? requiredOutputs : (type == MAX ? 2 : 1);
|
||||||
|
@ -1961,7 +1961,8 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
CV_Assert(layer_id.find(avgName) == layer_id.end());
|
CV_Assert(layer_id.find(avgName) == layer_id.end());
|
||||||
avgLp.set("pool", "ave");
|
avgLp.set("pool", "ave");
|
||||||
// pooling kernel H x 1
|
// pooling kernel H x 1
|
||||||
avgLp.set("global_axis", 0);
|
bool isGlobalPooling[] = {true, false};
|
||||||
|
avgLp.set("is_global_pooling", DictValue::arrayInt(&isGlobalPooling[0], 2));
|
||||||
avgLp.set("kernel_size", 1);
|
avgLp.set("kernel_size", 1);
|
||||||
int avgId = dstNet.addLayer(avgName, "Pooling", avgLp);
|
int avgId = dstNet.addLayer(avgName, "Pooling", avgLp);
|
||||||
layer_id[avgName] = avgId;
|
layer_id[avgName] = avgId;
|
||||||
|
Loading…
Reference in New Issue
Block a user