mirror of
https://github.com/opencv/opencv.git
synced 2025-06-30 17:02:12 +08:00
Modify nGraph's ConvolutionBackpropData and GroupConvolution
This commit is contained in:
parent
a2642d83d3
commit
fe77223dee
@ -544,6 +544,12 @@ public:
|
|||||||
const int group = inpCn / inpGroupCn;
|
const int group = inpCn / inpGroupCn;
|
||||||
|
|
||||||
std::vector<size_t> kernel_shape = getShape<size_t>(blobs[0]);
|
std::vector<size_t> kernel_shape = getShape<size_t>(blobs[0]);
|
||||||
|
if (group != 1)
|
||||||
|
{
|
||||||
|
kernel_shape[0] /= group;
|
||||||
|
kernel_shape.insert(kernel_shape.begin(), group);
|
||||||
|
}
|
||||||
|
|
||||||
auto ieWeights = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, kernel_shape, blobs[0].data);
|
auto ieWeights = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, kernel_shape, blobs[0].data);
|
||||||
if (fusedWeights)
|
if (fusedWeights)
|
||||||
{
|
{
|
||||||
@ -566,14 +572,12 @@ public:
|
|||||||
|
|
||||||
std::shared_ptr<ngraph::Node> conv_node;
|
std::shared_ptr<ngraph::Node> conv_node;
|
||||||
if (group != 1) {
|
if (group != 1) {
|
||||||
conv_node = std::make_shared<ngraph::op::GroupConvolution>(
|
conv_node = std::make_shared<ngraph::op::v1::GroupConvolution>(
|
||||||
ieInpNode, ieWeights,
|
ieInpNode, ieWeights,
|
||||||
ngraph::Strides(strides),
|
ngraph::Strides(strides),
|
||||||
ngraph::Strides(dilations),
|
|
||||||
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_begin.begin(), pads_begin.end())),
|
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_begin.begin(), pads_begin.end())),
|
||||||
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_end.begin(), pads_end.end())),
|
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_end.begin(), pads_end.end())),
|
||||||
ngraph::Strides{},
|
ngraph::Strides(dilations),
|
||||||
group,
|
|
||||||
pad_type);
|
pad_type);
|
||||||
} else {
|
} else {
|
||||||
conv_node = std::make_shared<ngraph::op::v1::Convolution>(
|
conv_node = std::make_shared<ngraph::op::v1::Convolution>(
|
||||||
@ -2037,37 +2041,29 @@ public:
|
|||||||
Mat newWeights = blobs[0].reshape(1, inpCn);
|
Mat newWeights = blobs[0].reshape(1, inpCn);
|
||||||
transpose(weightsMat, newWeights);
|
transpose(weightsMat, newWeights);
|
||||||
}
|
}
|
||||||
size_t batch = ieInpNode->get_shape()[0];
|
|
||||||
std::vector<size_t> out_shape = {batch, (size_t)numOutput};
|
|
||||||
std::vector<size_t> paddings_end;
|
std::vector<size_t> paddings_end;
|
||||||
std::vector<size_t> inpShape = ieInpNode->get_shape();
|
|
||||||
if (padMode.empty())
|
if (padMode.empty())
|
||||||
{
|
{
|
||||||
for (int i = 0; i < pads_end.size(); i++) {
|
for (int i = 0; i < pads_end.size(); i++) {
|
||||||
out_shape.push_back(strides[i] * (inpShape[2 + i] - 1) +
|
|
||||||
kernel_size[i] - pads_begin[i] - pads_end[i] + adjust_pads[i]);
|
|
||||||
paddings_end.push_back(pads_end[i] - adjust_pads[i]);
|
paddings_end.push_back(pads_end[i] - adjust_pads[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (padMode == "SAME")
|
else if (padMode == "SAME")
|
||||||
{
|
{
|
||||||
for (int i = 0; i < pads_begin.size(); i++) {
|
for (int i = 0; i < pads_begin.size(); i++) {
|
||||||
out_shape.push_back(strides[i] * (inpShape[2 + i] - 1) + 1 + adjust_pads[i]);
|
|
||||||
paddings_end.push_back(kernel_size[i] - pads_begin[i] - 1 - adjust_pads[i]);
|
paddings_end.push_back(kernel_size[i] - pads_begin[i] - 1 - adjust_pads[i]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
paddings_end = pads_end;
|
paddings_end = pads_end;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto deconv = std::make_shared<ngraph::op::ConvolutionBackpropData>(
|
auto deconv = std::make_shared<ngraph::op::v1::ConvolutionBackpropData>(
|
||||||
ngraph::Shape{out_shape},
|
|
||||||
ieWeights,
|
|
||||||
ieInpNode,
|
ieInpNode,
|
||||||
|
ieWeights,
|
||||||
ngraph::Strides(strides),
|
ngraph::Strides(strides),
|
||||||
ngraph::Strides(dilations),
|
|
||||||
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_begin.begin(), pads_begin.end())),
|
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(pads_begin.begin(), pads_begin.end())),
|
||||||
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(paddings_end.begin(), paddings_end.end())),
|
ngraph::CoordinateDiff(std::vector<std::ptrdiff_t>(paddings_end.begin(), paddings_end.end())),
|
||||||
(strides.size() == 2 ? ngraph::Strides{1, 1} : ngraph::Strides{1, 1, 1}));
|
ngraph::Strides(dilations));
|
||||||
if (hasBias() || fusedBias)
|
if (hasBias() || fusedBias)
|
||||||
{
|
{
|
||||||
std::vector<size_t> shape(deconv->get_shape().size(), 1);
|
std::vector<size_t> shape(deconv->get_shape().size(), 1);
|
||||||
|
Loading…
Reference in New Issue
Block a user