mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 09:25:45 +08:00
Update global pooling
This commit is contained in:
parent
cf477f7e9f
commit
752653c70b
@ -251,7 +251,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
CV_DEPRECATED_EXTERNAL Size kernel, stride, pad;
|
||||
CV_DEPRECATED_EXTERNAL int pad_l, pad_t, pad_r, pad_b;
|
||||
bool globalPooling;
|
||||
int global_axis;
|
||||
std::vector<bool> isGlobalPooling;
|
||||
bool computeMaxIdx;
|
||||
String padMode;
|
||||
bool ceilMode;
|
||||
|
@ -122,9 +122,17 @@ public:
|
||||
}
|
||||
else
|
||||
CV_Error(Error::StsBadArg, "Cannot determine pooling type");
|
||||
|
||||
setParamsFrom(params);
|
||||
ceilMode = params.get<bool>("ceil_mode", true);
|
||||
global_axis = params.get<int>("global_axis", -1);
|
||||
if (params.has("is_global_pooling"))
|
||||
{
|
||||
const DictValue &global_axis = params.get("is_global_pooling");
|
||||
int size = global_axis.size();
|
||||
isGlobalPooling.resize(size);
|
||||
for (int i = 0; i < size; i++)
|
||||
isGlobalPooling[i] = global_axis.get<bool>(i);
|
||||
}
|
||||
spatialScale = params.get<float>("spatial_scale", 1);
|
||||
avePoolPaddedArea = params.get<bool>("ave_pool_padded_area", true);
|
||||
}
|
||||
@ -150,8 +158,12 @@ public:
|
||||
if (globalPooling) {
|
||||
kernel = Size(inp[1], inp[0]);
|
||||
kernel_size = std::vector<size_t>(inp.begin(), inp.end());
|
||||
} else if (global_axis != -1) {
|
||||
kernel_size[global_axis] = inp[global_axis];
|
||||
} else if (!isGlobalPooling.empty()) {
|
||||
for (int i = 0; i < isGlobalPooling.size(); i++)
|
||||
{
|
||||
if (isGlobalPooling[i])
|
||||
kernel_size[i] = inp[i];
|
||||
}
|
||||
kernel = Size(kernel_size[1], kernel_size[0]);
|
||||
}
|
||||
|
||||
@ -1041,10 +1053,14 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
|
||||
outShape[0] = inputs[1][0]; // Number of proposals;
|
||||
outShape[1] = psRoiOutChannels;
|
||||
}
|
||||
else if (global_axis != -1)
|
||||
else if (!isGlobalPooling.empty())
|
||||
{
|
||||
CV_Assert(global_axis >= 0 && global_axis < inpShape.size());
|
||||
outShape[2 + global_axis] = 1;
|
||||
CV_Assert(isGlobalPooling.size() == inpShape.size());
|
||||
for (int i = 0; i < isGlobalPooling.size(); i++)
|
||||
{
|
||||
if (isGlobalPooling[i])
|
||||
outShape[2 + i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
int numOutputs = requiredOutputs ? requiredOutputs : (type == MAX ? 2 : 1);
|
||||
|
@ -1961,7 +1961,8 @@ void TFImporter::populateNet(Net dstNet)
|
||||
CV_Assert(layer_id.find(avgName) == layer_id.end());
|
||||
avgLp.set("pool", "ave");
|
||||
// pooling kernel H x 1
|
||||
avgLp.set("global_axis", 0);
|
||||
bool isGlobalPooling[] = {true, false};
|
||||
avgLp.set("is_global_pooling", DictValue::arrayInt(&isGlobalPooling[0], 2));
|
||||
avgLp.set("kernel_size", 1);
|
||||
int avgId = dstNet.addLayer(avgName, "Pooling", avgLp);
|
||||
layer_id[avgName] = avgId;
|
||||
|
Loading…
Reference in New Issue
Block a user