mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 22:44:02 +08:00
Fix slice layer from TensorFlow
This commit is contained in:
parent
f57630d92b
commit
184862582c
@ -91,7 +91,7 @@ public:
|
||||
{
|
||||
int size = sizeOrEnd;
|
||||
CV_Assert(size == -1 || size > 0); // -1 value means range [start, axis_size).
|
||||
sliceRanges[0][i].end = start > 0 ? start + size : -1; // We'll finalize a negative value later.
|
||||
sliceRanges[0][i].end = size > 0 ? (start + size) : -1; // We'll finalize a negative value later.
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -1119,21 +1119,21 @@ void TFImporter::populateNet(Net dstNet)
|
||||
// input: "Slice/begin"
|
||||
// input: "Slice/size"
|
||||
CV_Assert(layer.input_size() == 3);
|
||||
Mat begins = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||
Mat sizes = getTensorContent(getConstBlob(layer, value_id, 2));
|
||||
CV_Assert(!begins.empty(), !sizes.empty(), begins.type() == CV_32SC1,
|
||||
sizes.type() == CV_32SC1);
|
||||
|
||||
const tensorflow::TensorProto begins = getConstBlob(layer, value_id, 1);
|
||||
const tensorflow::TensorProto sizes = getConstBlob(layer, value_id, 2);
|
||||
std::string beginsData = begins.tensor_content();
|
||||
std::string sizesData = sizes.tensor_content();
|
||||
CV_Assert(begins.dtype() == tensorflow::DT_INT32);
|
||||
CV_Assert(sizes.dtype() == tensorflow::DT_INT32);
|
||||
CV_Assert(!beginsData.empty());
|
||||
CV_Assert(!sizesData.empty());
|
||||
CV_Assert(beginsData.size() == sizesData.size());
|
||||
|
||||
layerParams.set("begin", DictValue::arrayInt((int*)beginsData.c_str(),
|
||||
beginsData.size() / 4));
|
||||
layerParams.set("size", DictValue::arrayInt((int*)sizesData.c_str(),
|
||||
sizesData.size() / 4));
|
||||
if (begins.total() == 4)
|
||||
{
|
||||
// Perhabs, we have an NHWC order. Swap it to NCHW.
|
||||
std::swap(*begins.ptr<int32_t>(0, 2), *begins.ptr<int32_t>(0, 3));
|
||||
std::swap(*begins.ptr<int32_t>(0, 1), *begins.ptr<int32_t>(0, 2));
|
||||
std::swap(*sizes.ptr<int32_t>(0, 2), *sizes.ptr<int32_t>(0, 3));
|
||||
std::swap(*sizes.ptr<int32_t>(0, 1), *sizes.ptr<int32_t>(0, 2));
|
||||
}
|
||||
layerParams.set("begin", DictValue::arrayInt((int*)begins.data, begins.total()));
|
||||
layerParams.set("size", DictValue::arrayInt((int*)sizes.data, sizes.total()));
|
||||
|
||||
int id = dstNet.addLayer(name, "Slice", layerParams);
|
||||
layer_id[name] = id;
|
||||
|
@ -301,6 +301,11 @@ TEST(Test_TensorFlow, resize_nearest_neighbor)
|
||||
runTensorFlowNet("resize_nearest_neighbor");
|
||||
}
|
||||
|
||||
TEST(Test_TensorFlow, slice)
|
||||
{
|
||||
runTensorFlowNet("slice_4d");
|
||||
}
|
||||
|
||||
TEST(Test_TensorFlow, memory_read)
|
||||
{
|
||||
double l1 = 1e-5;
|
||||
|
Loading…
Reference in New Issue
Block a user