mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 06:26:29 +08:00
Merge pull request #12243 from dkurt:dnn_tf_mask_rcnn
* Support Mask-RCNN from TensorFlow * Fix a sample
This commit is contained in:
parent
4f360f8b1a
commit
472b71ecef
@ -99,6 +99,13 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (boxes.rows < out.size[0])
|
||||||
|
{
|
||||||
|
// left = top = right = bottom = 0
|
||||||
|
std::vector<cv::Range> dstRanges(4, Range::all());
|
||||||
|
dstRanges[0] = Range(boxes.rows, out.size[0]);
|
||||||
|
out(dstRanges).setTo(inp.ptr<float>(0, 0, 0)[0]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -115,6 +115,7 @@ public:
|
|||||||
// It's true whenever predicted bounding boxes and proposals are normalized to [0, 1].
|
// It's true whenever predicted bounding boxes and proposals are normalized to [0, 1].
|
||||||
bool _bboxesNormalized;
|
bool _bboxesNormalized;
|
||||||
bool _clip;
|
bool _clip;
|
||||||
|
bool _groupByClasses;
|
||||||
|
|
||||||
enum { _numAxes = 4 };
|
enum { _numAxes = 4 };
|
||||||
static const std::string _layerName;
|
static const std::string _layerName;
|
||||||
@ -183,6 +184,7 @@ public:
|
|||||||
_locPredTransposed = getParameter<bool>(params, "loc_pred_transposed", 0, false, false);
|
_locPredTransposed = getParameter<bool>(params, "loc_pred_transposed", 0, false, false);
|
||||||
_bboxesNormalized = getParameter<bool>(params, "normalized_bbox", 0, false, true);
|
_bboxesNormalized = getParameter<bool>(params, "normalized_bbox", 0, false, true);
|
||||||
_clip = getParameter<bool>(params, "clip", 0, false, false);
|
_clip = getParameter<bool>(params, "clip", 0, false, false);
|
||||||
|
_groupByClasses = getParameter<bool>(params, "group_by_classes", 0, false, true);
|
||||||
|
|
||||||
getCodeType(params);
|
getCodeType(params);
|
||||||
|
|
||||||
@ -381,7 +383,7 @@ public:
|
|||||||
{
|
{
|
||||||
count += outputDetections_(i, &outputsData[count * 7],
|
count += outputDetections_(i, &outputsData[count * 7],
|
||||||
allDecodedBBoxes[i], allConfidenceScores[i],
|
allDecodedBBoxes[i], allConfidenceScores[i],
|
||||||
allIndices[i]);
|
allIndices[i], _groupByClasses);
|
||||||
}
|
}
|
||||||
CV_Assert(count == numKept);
|
CV_Assert(count == numKept);
|
||||||
}
|
}
|
||||||
@ -497,7 +499,7 @@ public:
|
|||||||
{
|
{
|
||||||
count += outputDetections_(i, &outputsData[count * 7],
|
count += outputDetections_(i, &outputsData[count * 7],
|
||||||
allDecodedBBoxes[i], allConfidenceScores[i],
|
allDecodedBBoxes[i], allConfidenceScores[i],
|
||||||
allIndices[i]);
|
allIndices[i], _groupByClasses);
|
||||||
}
|
}
|
||||||
CV_Assert(count == numKept);
|
CV_Assert(count == numKept);
|
||||||
}
|
}
|
||||||
@ -505,9 +507,36 @@ public:
|
|||||||
size_t outputDetections_(
|
size_t outputDetections_(
|
||||||
const int i, float* outputsData,
|
const int i, float* outputsData,
|
||||||
const LabelBBox& decodeBBoxes, Mat& confidenceScores,
|
const LabelBBox& decodeBBoxes, Mat& confidenceScores,
|
||||||
const std::map<int, std::vector<int> >& indicesMap
|
const std::map<int, std::vector<int> >& indicesMap,
|
||||||
|
bool groupByClasses
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
std::vector<int> dstIndices;
|
||||||
|
std::vector<std::pair<float, int> > allScores;
|
||||||
|
for (std::map<int, std::vector<int> >::const_iterator it = indicesMap.begin(); it != indicesMap.end(); ++it)
|
||||||
|
{
|
||||||
|
int label = it->first;
|
||||||
|
if (confidenceScores.rows <= label)
|
||||||
|
CV_Error_(cv::Error::StsError, ("Could not find confidence predictions for label %d", label));
|
||||||
|
const std::vector<float>& scores = confidenceScores.row(label);
|
||||||
|
const std::vector<int>& indices = it->second;
|
||||||
|
|
||||||
|
const int numAllScores = allScores.size();
|
||||||
|
allScores.reserve(numAllScores + indices.size());
|
||||||
|
for (size_t j = 0; j < indices.size(); ++j)
|
||||||
|
{
|
||||||
|
allScores.push_back(std::make_pair(scores[indices[j]], numAllScores + j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!groupByClasses)
|
||||||
|
std::sort(allScores.begin(), allScores.end(), util::SortScorePairDescend<int>);
|
||||||
|
|
||||||
|
dstIndices.resize(allScores.size());
|
||||||
|
for (size_t j = 0; j < dstIndices.size(); ++j)
|
||||||
|
{
|
||||||
|
dstIndices[allScores[j].second] = j;
|
||||||
|
}
|
||||||
|
|
||||||
size_t count = 0;
|
size_t count = 0;
|
||||||
for (std::map<int, std::vector<int> >::const_iterator it = indicesMap.begin(); it != indicesMap.end(); ++it)
|
for (std::map<int, std::vector<int> >::const_iterator it = indicesMap.begin(); it != indicesMap.end(); ++it)
|
||||||
{
|
{
|
||||||
@ -524,14 +553,15 @@ public:
|
|||||||
for (size_t j = 0; j < indices.size(); ++j, ++count)
|
for (size_t j = 0; j < indices.size(); ++j, ++count)
|
||||||
{
|
{
|
||||||
int idx = indices[j];
|
int idx = indices[j];
|
||||||
|
int dstIdx = dstIndices[count];
|
||||||
const util::NormalizedBBox& decode_bbox = label_bboxes->second[idx];
|
const util::NormalizedBBox& decode_bbox = label_bboxes->second[idx];
|
||||||
outputsData[count * 7] = i;
|
outputsData[dstIdx * 7] = i;
|
||||||
outputsData[count * 7 + 1] = label;
|
outputsData[dstIdx * 7 + 1] = label;
|
||||||
outputsData[count * 7 + 2] = scores[idx];
|
outputsData[dstIdx * 7 + 2] = scores[idx];
|
||||||
outputsData[count * 7 + 3] = decode_bbox.xmin;
|
outputsData[dstIdx * 7 + 3] = decode_bbox.xmin;
|
||||||
outputsData[count * 7 + 4] = decode_bbox.ymin;
|
outputsData[dstIdx * 7 + 4] = decode_bbox.ymin;
|
||||||
outputsData[count * 7 + 5] = decode_bbox.xmax;
|
outputsData[dstIdx * 7 + 5] = decode_bbox.xmax;
|
||||||
outputsData[count * 7 + 6] = decode_bbox.ymax;
|
outputsData[dstIdx * 7 + 6] = decode_bbox.ymax;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return count;
|
return count;
|
||||||
|
@ -33,9 +33,7 @@ public:
|
|||||||
interpolation = params.get<String>("interpolation");
|
interpolation = params.get<String>("interpolation");
|
||||||
CV_Assert(interpolation == "nearest" || interpolation == "bilinear");
|
CV_Assert(interpolation == "nearest" || interpolation == "bilinear");
|
||||||
|
|
||||||
bool alignCorners = params.get<bool>("align_corners", false);
|
alignCorners = params.get<bool>("align_corners", false);
|
||||||
if (alignCorners)
|
|
||||||
CV_Error(Error::StsNotImplemented, "Resize with align_corners=true is not implemented");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||||
@ -66,8 +64,15 @@ public:
|
|||||||
outHeight = outputs[0].size[2];
|
outHeight = outputs[0].size[2];
|
||||||
outWidth = outputs[0].size[3];
|
outWidth = outputs[0].size[3];
|
||||||
}
|
}
|
||||||
scaleHeight = static_cast<float>(inputs[0]->size[2]) / outHeight;
|
if (alignCorners && outHeight > 1)
|
||||||
scaleWidth = static_cast<float>(inputs[0]->size[3]) / outWidth;
|
scaleHeight = static_cast<float>(inputs[0]->size[2] - 1) / (outHeight - 1);
|
||||||
|
else
|
||||||
|
scaleHeight = static_cast<float>(inputs[0]->size[2]) / outHeight;
|
||||||
|
|
||||||
|
if (alignCorners && outWidth > 1)
|
||||||
|
scaleWidth = static_cast<float>(inputs[0]->size[3] - 1) / (outWidth - 1);
|
||||||
|
else
|
||||||
|
scaleWidth = static_cast<float>(inputs[0]->size[3]) / outWidth;
|
||||||
}
|
}
|
||||||
|
|
||||||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
||||||
@ -166,6 +171,7 @@ protected:
|
|||||||
int outWidth, outHeight, zoomFactorWidth, zoomFactorHeight;
|
int outWidth, outHeight, zoomFactorWidth, zoomFactorHeight;
|
||||||
String interpolation;
|
String interpolation;
|
||||||
float scaleWidth, scaleHeight;
|
float scaleWidth, scaleHeight;
|
||||||
|
bool alignCorners;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -537,4 +537,56 @@ TEST(Test_TensorFlow, two_inputs)
|
|||||||
normAssert(out, firstInput + secondInput);
|
normAssert(out, firstInput + secondInput);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Test_TensorFlow, Mask_RCNN)
|
||||||
|
{
|
||||||
|
std::string proto = findDataFile("dnn/mask_rcnn_inception_v2_coco_2018_01_28.pbtxt", false);
|
||||||
|
std::string model = findDataFile("dnn/mask_rcnn_inception_v2_coco_2018_01_28.pb", false);
|
||||||
|
|
||||||
|
Net net = readNetFromTensorflow(model, proto);
|
||||||
|
Mat img = imread(findDataFile("dnn/street.png", false));
|
||||||
|
Mat refDetections = blobFromNPY(path("mask_rcnn_inception_v2_coco_2018_01_28.detection_out.npy"));
|
||||||
|
Mat refMasks = blobFromNPY(path("mask_rcnn_inception_v2_coco_2018_01_28.detection_masks.npy"));
|
||||||
|
Mat blob = blobFromImage(img, 1.0f, Size(800, 800), Scalar(), true, false);
|
||||||
|
|
||||||
|
net.setPreferableBackend(DNN_BACKEND_OPENCV);
|
||||||
|
|
||||||
|
net.setInput(blob);
|
||||||
|
|
||||||
|
// Mask-RCNN predicts bounding boxes and segmentation masks.
|
||||||
|
std::vector<String> outNames(2);
|
||||||
|
outNames[0] = "detection_out_final";
|
||||||
|
outNames[1] = "detection_masks";
|
||||||
|
|
||||||
|
std::vector<Mat> outs;
|
||||||
|
net.forward(outs, outNames);
|
||||||
|
|
||||||
|
Mat outDetections = outs[0];
|
||||||
|
Mat outMasks = outs[1];
|
||||||
|
normAssertDetections(refDetections, outDetections, "", /*threshold for zero confidence*/1e-5);
|
||||||
|
|
||||||
|
// Output size of masks is NxCxHxW where
|
||||||
|
// N - number of detected boxes
|
||||||
|
// C - number of classes (excluding background)
|
||||||
|
// HxW - segmentation shape
|
||||||
|
const int numDetections = outDetections.size[2];
|
||||||
|
|
||||||
|
int masksSize[] = {1, numDetections, outMasks.size[2], outMasks.size[3]};
|
||||||
|
Mat masks(4, &masksSize[0], CV_32F);
|
||||||
|
|
||||||
|
std::vector<cv::Range> srcRanges(4, cv::Range::all());
|
||||||
|
std::vector<cv::Range> dstRanges(4, cv::Range::all());
|
||||||
|
|
||||||
|
outDetections = outDetections.reshape(1, outDetections.total() / 7);
|
||||||
|
for (int i = 0; i < numDetections; ++i)
|
||||||
|
{
|
||||||
|
// Get a class id for this bounding box and copy mask only for that class.
|
||||||
|
int classId = static_cast<int>(outDetections.at<float>(i, 1));
|
||||||
|
srcRanges[0] = dstRanges[1] = cv::Range(i, i + 1);
|
||||||
|
srcRanges[1] = cv::Range(classId, classId + 1);
|
||||||
|
outMasks(srcRanges).copyTo(masks(dstRanges));
|
||||||
|
}
|
||||||
|
cv::Range topRefMasks[] = {Range::all(), Range(0, numDetections), Range::all(), Range::all()};
|
||||||
|
normAssert(masks, refMasks(&topRefMasks[0]));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
143
samples/dnn/mask_rcnn.py
Normal file
143
samples/dnn/mask_rcnn.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
import cv2 as cv
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Command-line interface of the sample. --model and --config are mandatory;
# every other option has a usable default.
parser = argparse.ArgumentParser(description=
    'Use this script to run Mask-RCNN object detection and semantic '
    'segmentation network from TensorFlow Object Detection API.')
parser.add_argument('--input', help='Path to input image or video file. Skip this argument to capture frames from a camera.')
parser.add_argument('--model', required=True, help='Path to a .pb file with weights.')
# Help text fixed: the text graph extension is .pbtxt (was misspelled ".pxtxt").
parser.add_argument('--config', required=True, help='Path to a .pbtxt file that contains network configuration.')
parser.add_argument('--classes', help='Optional path to a text file with names of classes.')
parser.add_argument('--colors', help='Optional path to a text file with colors for every class. '
                                     'Every color is represented with three values from 0 to 255 in BGR channels order.')
parser.add_argument('--width', type=int, default=800,
                    help='Preprocess input image by resizing to a specific width.')
parser.add_argument('--height', type=int, default=800,
                    help='Preprocess input image by resizing to a specific height.')
parser.add_argument('--thr', type=float, default=0.5, help='Confidence threshold')
args = parser.parse_args()
|
||||||
|
|
||||||
|
np.random.seed(324)
|
||||||
|
|
||||||
|
# Load names of classes
|
||||||
|
classes = None
|
||||||
|
if args.classes:
|
||||||
|
with open(args.classes, 'rt') as f:
|
||||||
|
classes = f.read().rstrip('\n').split('\n')
|
||||||
|
|
||||||
|
# Load colors
|
||||||
|
colors = None
|
||||||
|
if args.colors:
|
||||||
|
with open(args.colors, 'rt') as f:
|
||||||
|
colors = [np.array(color.split(' '), np.uint8) for color in f.read().rstrip('\n').split('\n')]
|
||||||
|
|
||||||
|
legend = None
|
||||||
|
def showLegend(classes):
    """Render the class legend once and display it in the 'Legend' window.

    Does nothing when `classes` is None or when the module-level `legend`
    image has already been built. Reads the module-level `colors` list and
    stores the rendered image in the module-level `legend`.
    """
    global legend
    if classes is None or legend is not None:
        return

    blockHeight = 30
    assert(len(classes) == len(colors))

    # One horizontal stripe of `blockHeight` rows per class.
    legend = np.zeros((blockHeight * len(colors), 200, 3), np.uint8)
    for i in range(len(classes)):
        block = legend[i * blockHeight:(i + 1) * blockHeight]
        block[:,:] = colors[i]
        # Floor division keeps the text origin an integer: `/` produces a
        # float on Python 3, which cv.putText does not accept as a coordinate.
        cv.putText(block, classes[i], (0, blockHeight // 2), cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))

    cv.namedWindow('Legend', cv.WINDOW_NORMAL)
    cv.imshow('Legend', legend)
|
||||||
|
|
||||||
|
|
||||||
|
def drawBox(frame, classId, conf, left, top, right, bottom):
    """Outline one detection on `frame` and print its caption at the box top.

    The caption is the confidence with two decimals, prefixed by the class
    name when the module-level `classes` list is available.
    """
    font = cv.FONT_HERSHEY_SIMPLEX

    # Bounding box outline.
    cv.rectangle(frame, (left, top), (right, bottom), (0, 255, 0))

    caption = '%.2f' % conf
    if classes:
        assert(classId < len(classes))
        caption = '%s: %s' % (classes[classId], caption)

    # White background plate sized to the text, then the text itself.
    textSize, baseLine = cv.getTextSize(caption, font, 0.5, 1)
    textWidth, textHeight = textSize[0], textSize[1]
    # Keep the caption inside the image when the box touches the top edge.
    top = max(top, textHeight)
    plateTopLeft = (left, top - textHeight)
    plateBottomRight = (left + textWidth, top + baseLine)
    cv.rectangle(frame, plateTopLeft, plateBottomRight, (255, 255, 255), cv.FILLED)
    cv.putText(frame, caption, (left, top), font, 0.5, (0, 0, 0))
|
||||||
|
|
||||||
|
|
||||||
|
# Load a network
|
||||||
|
net = cv.dnn.readNet(args.model, args.config)
|
||||||
|
net.setPreferableBackend(cv.dnn.DNN_BACKEND_OPENCV)
|
||||||
|
|
||||||
|
winName = 'Mask-RCNN in OpenCV'
|
||||||
|
cv.namedWindow(winName, cv.WINDOW_NORMAL)
|
||||||
|
|
||||||
|
cap = cv.VideoCapture(args.input if args.input else 0)
|
||||||
|
legend = None
|
||||||
|
while cv.waitKey(1) < 0:
|
||||||
|
hasFrame, frame = cap.read()
|
||||||
|
if not hasFrame:
|
||||||
|
cv.waitKey()
|
||||||
|
break
|
||||||
|
|
||||||
|
frameH = frame.shape[0]
|
||||||
|
frameW = frame.shape[1]
|
||||||
|
|
||||||
|
# Create a 4D blob from a frame.
|
||||||
|
blob = cv.dnn.blobFromImage(frame, size=(args.width, args.height), swapRB=True, crop=False)
|
||||||
|
|
||||||
|
# Run a model
|
||||||
|
net.setInput(blob)
|
||||||
|
|
||||||
|
boxes, masks = net.forward(['detection_out_final', 'detection_masks'])
|
||||||
|
|
||||||
|
numClasses = masks.shape[1]
|
||||||
|
numDetections = boxes.shape[2]
|
||||||
|
|
||||||
|
# Draw segmentation
|
||||||
|
if not colors:
|
||||||
|
# Generate colors
|
||||||
|
colors = [np.array([0, 0, 0], np.uint8)]
|
||||||
|
for i in range(1, numClasses + 1):
|
||||||
|
colors.append((colors[i - 1] + np.random.randint(0, 256, [3], np.uint8)) / 2)
|
||||||
|
del colors[0]
|
||||||
|
|
||||||
|
boxesToDraw = []
|
||||||
|
for i in range(numDetections):
|
||||||
|
box = boxes[0, 0, i]
|
||||||
|
mask = masks[i]
|
||||||
|
score = box[2]
|
||||||
|
if score > args.thr:
|
||||||
|
classId = int(box[1])
|
||||||
|
left = int(frameW * box[3])
|
||||||
|
top = int(frameH * box[4])
|
||||||
|
right = int(frameW * box[5])
|
||||||
|
bottom = int(frameH * box[6])
|
||||||
|
|
||||||
|
left = max(0, min(left, frameW - 1))
|
||||||
|
top = max(0, min(top, frameH - 1))
|
||||||
|
right = max(0, min(right, frameW - 1))
|
||||||
|
bottom = max(0, min(bottom, frameH - 1))
|
||||||
|
|
||||||
|
boxesToDraw.append([frame, classId, score, left, top, right, bottom])
|
||||||
|
|
||||||
|
classMask = mask[classId]
|
||||||
|
classMask = cv.resize(classMask, (right - left + 1, bottom - top + 1))
|
||||||
|
mask = (classMask > 0.5)
|
||||||
|
|
||||||
|
roi = frame[top:bottom+1, left:right+1][mask]
|
||||||
|
frame[top:bottom+1, left:right+1][mask] = (0.7 * colors[classId] + 0.3 * roi).astype(np.uint8)
|
||||||
|
|
||||||
|
for box in boxesToDraw:
|
||||||
|
drawBox(*box)
|
||||||
|
|
||||||
|
# Put efficiency information.
|
||||||
|
t, _ = net.getPerfProfile()
|
||||||
|
label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
|
||||||
|
cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
|
||||||
|
|
||||||
|
showLegend(classes)
|
||||||
|
|
||||||
|
cv.imshow(winName, frame)
|
@ -23,3 +23,98 @@ def addConstNode(name, values, graph_def):
|
|||||||
node.op = 'Const'
|
node.op = 'Const'
|
||||||
text_format.Merge(tensorMsg(values), node.attr["value"])
|
text_format.Merge(tensorMsg(values), node.attr["value"])
|
||||||
graph_def.node.extend([node])
|
graph_def.node.extend([node])
|
||||||
|
|
||||||
|
|
||||||
|
def addSlice(inp, out, begins, sizes, graph_def):
    """Append a 'Slice' node named `out` to `graph_def`.

    Two 'Const' nodes, `out + '/begins'` and `out + '/sizes'`, are appended
    first and wired in as the second and third inputs after `inp`.
    """
    operandNames = []
    for suffix, values in (('/begins', begins), ('/sizes', sizes)):
        constNode = NodeDef()
        constNode.name = out + suffix
        constNode.op = 'Const'
        text_format.Merge(tensorMsg(values), constNode.attr["value"])
        graph_def.node.extend([constNode])
        operandNames.append(constNode.name)

    sliceNode = NodeDef()
    sliceNode.name = out
    sliceNode.op = 'Slice'
    sliceNode.input.append(inp)
    for operandName in operandNames:
        sliceNode.input.append(operandName)
    graph_def.node.extend([sliceNode])
|
||||||
|
|
||||||
|
|
||||||
|
def addReshape(inp, out, shape, graph_def):
    """Append a 'Reshape' node named `out` to `graph_def`.

    The target shape is stored in a 'Const' node named `out + '/shape'`,
    which is appended first and used as the second input after `inp`.
    """
    targetShape = NodeDef()
    targetShape.name = out + '/shape'
    targetShape.op = 'Const'
    text_format.Merge(tensorMsg(shape), targetShape.attr["value"])
    graph_def.node.extend([targetShape])

    reshapeNode = NodeDef()
    reshapeNode.name = out
    reshapeNode.op = 'Reshape'
    for source in (inp, targetShape.name):
        reshapeNode.input.append(source)
    graph_def.node.extend([reshapeNode])
|
||||||
|
|
||||||
|
|
||||||
|
def addSoftMax(inp, out, graph_def):
    """Append a 'Softmax' node named `out`, fed by `inp`, with axis set to -1."""
    node = NodeDef()
    node.name = out
    node.op = 'Softmax'
    node.input.append(inp)
    text_format.Merge('i: -1', node.attr['axis'])
    graph_def.node.extend([node])
|
||||||
|
|
||||||
|
|
||||||
|
def addFlatten(inp, out, graph_def):
    """Append a 'Flatten' node named `out`, fed by `inp`, to `graph_def`."""
    node = NodeDef()
    node.op = 'Flatten'
    node.name = out
    node.input.append(inp)
    graph_def.node.extend([node])
|
||||||
|
|
||||||
|
|
||||||
|
# Removes Identity nodes
def removeIdentity(graph_def):
    """Delete every 'Identity' node from `graph_def` and rewire its consumers.

    Each input that names an Identity node is replaced by the name of the
    node that Identity forwards. Chains of Identity nodes are followed to
    the first non-Identity producer, so no input is left pointing at a
    deleted node.
    """
    # Pass 1: record what each Identity forwards, without touching the list.
    # Removing entries while iterating would skip the element that follows
    # every removed one, leaving back-to-back Identity nodes in the graph.
    identities = {}
    for node in graph_def.node:
        if node.op == 'Identity':
            identities[node.name] = node.input[0]

    def resolve(name):
        # Walk Identity -> Identity links; `visited` guards against a
        # malformed cyclic graph turning this into an endless loop.
        visited = set()
        while name in identities and name not in visited:
            visited.add(name)
            name = identities[name]
        return name

    # Pass 2: delete back to front so the indices still to visit stay valid.
    for i in reversed(range(len(graph_def.node))):
        if graph_def.node[i].op == 'Identity':
            del graph_def.node[i]

    # Pass 3: redirect the remaining nodes to the real producers.
    for node in graph_def.node:
        for i in range(len(node.input)):
            if node.input[i] in identities:
                node.input[i] = resolve(node.input[i])
|
||||||
|
|
||||||
|
|
||||||
|
def removeUnusedNodesAndAttrs(to_remove, graph_def):
    """Prune `graph_def` in place.

    Every 'Const' node, and every node for which `to_remove(name, op)` is
    true, is deleted. Surviving nodes lose the attributes listed in
    `unusedAttrs`, and any input naming a deleted non-Const node is dropped.
    Inputs naming deleted Const nodes are kept.
    """
    unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
                   'Index', 'Tperm', 'is_training', 'Tpaddings']

    # A set keeps the membership test in the rewiring loop below O(1);
    # a list made that loop quadratic in the number of removed nodes.
    removedNodes = set()

    # Walk back to front so deleting node i never shifts an unvisited index.
    for i in reversed(range(len(graph_def.node))):
        op = graph_def.node[i].op
        name = graph_def.node[i].name

        if op == 'Const' or to_remove(name, op):
            if op != 'Const':
                removedNodes.add(name)

            del graph_def.node[i]
        else:
            for attr in unusedAttrs:
                if attr in graph_def.node[i].attr:
                    del graph_def.node[i].attr[attr]

    # Remove references to removed nodes except Const nodes.
    for node in graph_def.node:
        for i in reversed(range(len(node.input))):
            if node.input[i] in removedNodes:
                del node.input[i]
|
||||||
|
@ -6,7 +6,7 @@ from tensorflow.core.framework.node_def_pb2 import NodeDef
|
|||||||
from tensorflow.tools.graph_transforms import TransformGraph
|
from tensorflow.tools.graph_transforms import TransformGraph
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
|
|
||||||
from tf_text_graph_common import tensorMsg, addConstNode
|
from tf_text_graph_common import *
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
|
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
|
||||||
'SSD model from TensorFlow Object Detection API. '
|
'SSD model from TensorFlow Object Detection API. '
|
||||||
@ -37,50 +37,17 @@ scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
|
|||||||
'FirstStageFeatureExtractor/GreaterEqual',
|
'FirstStageFeatureExtractor/GreaterEqual',
|
||||||
'FirstStageFeatureExtractor/LogicalAnd')
|
'FirstStageFeatureExtractor/LogicalAnd')
|
||||||
|
|
||||||
unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
|
|
||||||
'Index', 'Tperm', 'is_training', 'Tpaddings']
|
|
||||||
|
|
||||||
# Read the graph.
|
# Read the graph.
|
||||||
with tf.gfile.FastGFile(args.input, 'rb') as f:
|
with tf.gfile.FastGFile(args.input, 'rb') as f:
|
||||||
graph_def = tf.GraphDef()
|
graph_def = tf.GraphDef()
|
||||||
graph_def.ParseFromString(f.read())
|
graph_def.ParseFromString(f.read())
|
||||||
|
|
||||||
# Removes Identity nodes
|
removeIdentity(graph_def)
|
||||||
def removeIdentity():
|
|
||||||
identities = {}
|
|
||||||
for node in graph_def.node:
|
|
||||||
if node.op == 'Identity':
|
|
||||||
identities[node.name] = node.input[0]
|
|
||||||
graph_def.node.remove(node)
|
|
||||||
|
|
||||||
for node in graph_def.node:
|
def to_remove(name, op):
|
||||||
for i in range(len(node.input)):
|
return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep)
|
||||||
if node.input[i] in identities:
|
|
||||||
node.input[i] = identities[node.input[i]]
|
|
||||||
|
|
||||||
removeIdentity()
|
removeUnusedNodesAndAttrs(to_remove, graph_def)
|
||||||
|
|
||||||
removedNodes = []
|
|
||||||
|
|
||||||
for i in reversed(range(len(graph_def.node))):
|
|
||||||
op = graph_def.node[i].op
|
|
||||||
name = graph_def.node[i].name
|
|
||||||
|
|
||||||
if op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep):
|
|
||||||
if op != 'Const':
|
|
||||||
removedNodes.append(name)
|
|
||||||
|
|
||||||
del graph_def.node[i]
|
|
||||||
else:
|
|
||||||
for attr in unusedAttrs:
|
|
||||||
if attr in graph_def.node[i].attr:
|
|
||||||
del graph_def.node[i].attr[attr]
|
|
||||||
|
|
||||||
# Remove references to removed nodes except Const nodes.
|
|
||||||
for node in graph_def.node:
|
|
||||||
for i in reversed(range(len(node.input))):
|
|
||||||
if node.input[i] in removedNodes:
|
|
||||||
del node.input[i]
|
|
||||||
|
|
||||||
|
|
||||||
# Connect input node to the first layer
|
# Connect input node to the first layer
|
||||||
@ -95,68 +62,18 @@ while True:
|
|||||||
if node.op == 'CropAndResize':
|
if node.op == 'CropAndResize':
|
||||||
break
|
break
|
||||||
|
|
||||||
def addSlice(inp, out, begins, sizes):
|
|
||||||
beginsNode = NodeDef()
|
|
||||||
beginsNode.name = out + '/begins'
|
|
||||||
beginsNode.op = 'Const'
|
|
||||||
text_format.Merge(tensorMsg(begins), beginsNode.attr["value"])
|
|
||||||
graph_def.node.extend([beginsNode])
|
|
||||||
|
|
||||||
sizesNode = NodeDef()
|
|
||||||
sizesNode.name = out + '/sizes'
|
|
||||||
sizesNode.op = 'Const'
|
|
||||||
text_format.Merge(tensorMsg(sizes), sizesNode.attr["value"])
|
|
||||||
graph_def.node.extend([sizesNode])
|
|
||||||
|
|
||||||
sliced = NodeDef()
|
|
||||||
sliced.name = out
|
|
||||||
sliced.op = 'Slice'
|
|
||||||
sliced.input.append(inp)
|
|
||||||
sliced.input.append(beginsNode.name)
|
|
||||||
sliced.input.append(sizesNode.name)
|
|
||||||
graph_def.node.extend([sliced])
|
|
||||||
|
|
||||||
def addReshape(inp, out, shape):
|
|
||||||
shapeNode = NodeDef()
|
|
||||||
shapeNode.name = out + '/shape'
|
|
||||||
shapeNode.op = 'Const'
|
|
||||||
text_format.Merge(tensorMsg(shape), shapeNode.attr["value"])
|
|
||||||
graph_def.node.extend([shapeNode])
|
|
||||||
|
|
||||||
reshape = NodeDef()
|
|
||||||
reshape.name = out
|
|
||||||
reshape.op = 'Reshape'
|
|
||||||
reshape.input.append(inp)
|
|
||||||
reshape.input.append(shapeNode.name)
|
|
||||||
graph_def.node.extend([reshape])
|
|
||||||
|
|
||||||
def addSoftMax(inp, out):
|
|
||||||
softmax = NodeDef()
|
|
||||||
softmax.name = out
|
|
||||||
softmax.op = 'Softmax'
|
|
||||||
text_format.Merge('i: -1', softmax.attr['axis'])
|
|
||||||
softmax.input.append(inp)
|
|
||||||
graph_def.node.extend([softmax])
|
|
||||||
|
|
||||||
def addFlatten(inp, out):
|
|
||||||
flatten = NodeDef()
|
|
||||||
flatten.name = out
|
|
||||||
flatten.op = 'Flatten'
|
|
||||||
flatten.input.append(inp)
|
|
||||||
graph_def.node.extend([flatten])
|
|
||||||
|
|
||||||
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
|
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
|
||||||
'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2])
|
'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)
|
||||||
|
|
||||||
addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
|
addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
|
||||||
'FirstStageBoxPredictor/ClassPredictor/softmax') # Compare with Reshape_4
|
'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def) # Compare with Reshape_4
|
||||||
|
|
||||||
addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
|
addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
|
||||||
'FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
|
'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)
|
||||||
|
|
||||||
# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
|
# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
|
||||||
addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
|
addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
|
||||||
'FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
|
'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)
|
||||||
|
|
||||||
proposals = NodeDef()
|
proposals = NodeDef()
|
||||||
proposals.name = 'proposals' # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
|
proposals.name = 'proposals' # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
|
||||||
@ -218,14 +135,14 @@ graph_def.node.extend([clipByValueNode])
|
|||||||
for node in reversed(topNodes):
|
for node in reversed(topNodes):
|
||||||
graph_def.node.extend([node])
|
graph_def.node.extend([node])
|
||||||
|
|
||||||
addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax')
|
addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)
|
||||||
|
|
||||||
addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
|
addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
|
||||||
'SecondStageBoxPredictor/Reshape_1/slice',
|
'SecondStageBoxPredictor/Reshape_1/slice',
|
||||||
[0, 0, 1], [-1, -1, -1])
|
[0, 0, 1], [-1, -1, -1], graph_def)
|
||||||
|
|
||||||
addReshape('SecondStageBoxPredictor/Reshape_1/slice',
|
addReshape('SecondStageBoxPredictor/Reshape_1/slice',
|
||||||
'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1])
|
'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)
|
||||||
|
|
||||||
# Replace Flatten subgraph onto a single node.
|
# Replace Flatten subgraph onto a single node.
|
||||||
for i in reversed(range(len(graph_def.node))):
|
for i in reversed(range(len(graph_def.node))):
|
||||||
@ -255,7 +172,7 @@ for node in graph_def.node:
|
|||||||
################################################################################
|
################################################################################
|
||||||
### Postprocessing
|
### Postprocessing
|
||||||
################################################################################
|
################################################################################
|
||||||
addSlice('detection_out/clip_by_value', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4])
|
addSlice('detection_out/clip_by_value', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)
|
||||||
|
|
||||||
variance = NodeDef()
|
variance = NodeDef()
|
||||||
variance.name = 'proposals/variance'
|
variance.name = 'proposals/variance'
|
||||||
@ -271,8 +188,8 @@ varianceEncoder.input.append(variance.name)
|
|||||||
text_format.Merge('i: 2', varianceEncoder.attr["axis"])
|
text_format.Merge('i: 2', varianceEncoder.attr["axis"])
|
||||||
graph_def.node.extend([varianceEncoder])
|
graph_def.node.extend([varianceEncoder])
|
||||||
|
|
||||||
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1])
|
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
|
||||||
addFlatten('variance_encoded', 'variance_encoded/flatten')
|
addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)
|
||||||
|
|
||||||
detectionOut = NodeDef()
|
detectionOut = NodeDef()
|
||||||
detectionOut.name = 'detection_out_final'
|
detectionOut.name = 'detection_out_final'
|
||||||
|
230
samples/dnn/tf_text_graph_mask_rcnn.py
Normal file
230
samples/dnn/tf_text_graph_mask_rcnn.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.core.framework.node_def_pb2 import NodeDef
|
||||||
|
from tensorflow.tools.graph_transforms import TransformGraph
|
||||||
|
from google.protobuf import text_format
|
||||||
|
|
||||||
|
from tf_text_graph_common import *
|
||||||
|
|
||||||
|
# Command-line interface of the converter. --input and --output are
# mandatory; the remaining options mirror hyper-parameters of the training
# config and default to the values below.
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                             'Mask-RCNN model from TensorFlow Object Detection API. '
                                             'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
parser.add_argument('--output', required=True, help='Path to output text graph.')
parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+',
                    help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+',
                    help='Hyper-parameter of grid_anchor_generator from a config file.')
# The stride is a single scalar: it is later formatted with '%f'. With
# nargs='+' a user-supplied value arrived as a list and that formatting
# raised TypeError, so the option only worked when left at its default.
parser.add_argument('--features_stride', default=16, type=float,
                    help='Hyper-parameter from a config file.')
args = parser.parse_args()
|
||||||
|
|
||||||
|
scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
|
||||||
|
'FirstStageBoxPredictor/BoxEncodingPredictor',
|
||||||
|
'FirstStageBoxPredictor/ClassPredictor',
|
||||||
|
'CropAndResize',
|
||||||
|
'MaxPool2D',
|
||||||
|
'SecondStageFeatureExtractor',
|
||||||
|
'SecondStageBoxPredictor',
|
||||||
|
'Preprocessor/sub',
|
||||||
|
'Preprocessor/mul',
|
||||||
|
'image_tensor')
|
||||||
|
|
||||||
|
scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
|
||||||
|
'FirstStageFeatureExtractor/Shape',
|
||||||
|
'FirstStageFeatureExtractor/strided_slice',
|
||||||
|
'FirstStageFeatureExtractor/GreaterEqual',
|
||||||
|
'FirstStageFeatureExtractor/LogicalAnd')
|
||||||
|
|
||||||
|
|
||||||
|
# Read the graph.
|
||||||
|
with tf.gfile.FastGFile(args.input, 'rb') as f:
|
||||||
|
graph_def = tf.GraphDef()
|
||||||
|
graph_def.ParseFromString(f.read())
|
||||||
|
|
||||||
|
removeIdentity(graph_def)
|
||||||
|
|
||||||
|
def to_remove(name, op):
    """Tell whether the node called `name` has to be dropped from the graph.

    A node is dropped when its name starts with one of `scopesToIgnore`, or when
    it does not start with any of `scopesToKeep`. `op` is accepted because the
    predicate is passed to removeUnusedNodesAndAttrs, but it is not used here.
    """
    if name.startswith(scopesToIgnore):
        return True
    return not name.startswith(scopesToKeep)
|
||||||
|
|
||||||
|
# Drop every node selected by to_remove().
removeUnusedNodesAndAttrs(to_remove, graph_def)


# Connect input node to the first layer: after pruning, node 0 must be the
# input placeholder; its name becomes the first input of node 1.
assert graph_def.node[0].op == 'Placeholder'
graph_def.node[1].input.insert(0, graph_def.node[0].name)
|
||||||
|
|
||||||
|
# Temporarily remove top nodes: pop nodes from the end of the graph until two
# CropAndResize nodes have been taken off. The popped nodes are kept in
# topNodes (last graph node first) so they can be appended again later.
topNodes = []
numCropAndResize = 0
while numCropAndResize < 2:
    node = graph_def.node.pop()
    topNodes.append(node)
    if node.op == 'CropAndResize':
        numCropAndResize += 1
|
||||||
|
|
||||||
|
# First-stage class scores: reshape to [0, -1, 2], apply softmax, then flatten.
clsScope = 'FirstStageBoxPredictor/ClassPredictor'
addReshape(clsScope + '/BiasAdd', clsScope + '/reshape_1', [0, -1, 2], graph_def)

addSoftMax(clsScope + '/reshape_1', clsScope + '/softmax', graph_def)  # Compare with Reshape_4

addFlatten(clsScope + '/softmax', clsScope + '/softmax/flatten', graph_def)

# First-stage box encodings are only flattened.
# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
boxScope = 'FirstStageBoxPredictor/BoxEncodingPredictor'
addFlatten(boxScope + '/BiasAdd', boxScope + '/flatten', graph_def)
|
||||||
|
|
||||||
|
# PriorBox node fed by the first-stage box predictions and the input image.
proposals = NodeDef()
proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
proposals.op = 'PriorBox'
proposals.input.extend(['FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
                        graph_def.node[0].name])  # image_tensor

text_format.Merge('b: false', proposals.attr["flip"])
text_format.Merge('b: true', proposals.attr["clip"])
text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"])
text_format.Merge('f: 0.0', proposals.attr["offset"])
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"])

# One (width, height) pair per aspect ratio / scale combination; the square of
# the features stride is used as the base size.
widths = []
heights = []
baseSize = args.features_stride**2
for a in args.aspect_ratios:
    ar = np.sqrt(a)
    for s in args.scales:
        heights.append(baseSize * s / ar)
        widths.append(baseSize * s * ar)

text_format.Merge(tensorMsg(widths), proposals.attr["width"])
text_format.Merge(tensorMsg(heights), proposals.attr["height"])

graph_def.node.extend([proposals])
|
||||||
|
|
||||||
|
# First-stage DetectionOutput node: combines the flattened box encodings, the
# flattened class scores and the 'proposals' node.
# Compare with Reshape_5
detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'

detectionOut.input.extend(['FirstStageBoxPredictor/BoxEncodingPredictor/flatten',
                           'FirstStageBoxPredictor/ClassPredictor/softmax/flatten',
                           'proposals'])

text_format.Merge('i: 2', detectionOut.attr['num_classes'])
text_format.Merge('b: true', detectionOut.attr['share_location'])
text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
text_format.Merge('i: 6000', detectionOut.attr['top_k'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['clip'])

graph_def.node.extend([detectionOut])
|
||||||
|
|
||||||
|
# Put the detached nodes back, starting with the one that was detached last.
# The first CropAndResize met is re-attached; the walk stops at the next one,
# which stays in topNodes together with every node that follows it.
while topNodes:
    node = topNodes[-1]
    if node.op == 'CropAndResize':
        if numCropAndResize == 1:
            break
        numCropAndResize -= 1
    graph_def.node.extend([node])
    topNodes.pop()
|
||||||
|
|
||||||
|
# Second-stage class scores: softmax, slice starting at index 1 of the last
# axis (begin [0, 0, 1], size [-1, -1, -1]), then reshape to [1, -1].
scoresScope = 'SecondStageBoxPredictor/Reshape_1'
addSoftMax(scoresScope, scoresScope + '/softmax', graph_def)

addSlice(scoresScope + '/softmax', scoresScope + '/slice',
         [0, 0, 1], [-1, -1, -1], graph_def)

addReshape(scoresScope + '/slice', scoresScope + '/Reshape', [1, -1], graph_def)
|
||||||
|
|
||||||
|
# Replace Flatten subgraph with a single node.
# Nodes of the original flatten subgraph that are deleted below.
flattenHelperNodes = ('SecondStageBoxPredictor/Flatten/flatten/Shape',
                      'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                      'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape')

# Walk backwards so that deleting node i does not shift the indices still to visit.
for i in range(len(graph_def.node) - 1, -1, -1):
    # Every CropAndResize present at this point takes 'detection_out' as input #1.
    if graph_def.node[i].op == 'CropAndResize':
        graph_def.node[i].input.insert(1, 'detection_out')

    # Swap the last input of the Reshape for a constant shape [1, -1, 4].
    if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
        addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)

        graph_def.node[i].input.pop()
        graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')

    if graph_def.node[i].name in flattenHelperNodes:
        del graph_def.node[i]
|
||||||
|
|
||||||
|
# Nodes that get the 'loc_pred_transposed' attribute.
transposedLocNodes = ('FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                      'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul')

for node in graph_def.node:
    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
        # The Reshape becomes a plain Flatten and loses its last input.
        node.op = 'Flatten'
        node.input.pop()
    elif node.name in transposedLocNodes:
        text_format.Merge('b: true', node.attr["loc_pred_transposed"])
|
||||||
|
|
||||||
|
################################################################################
### Postprocessing
################################################################################
# Slice 'detection_out' with begin [0, 0, 0, 3] and size [-1, -1, -1, 4]
# (presumably the box coordinates of every detection - confirm against the
# DetectionOutput layout).
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)

# Constant holding the same variance values as the 'proposals' node.
variance = NodeDef()
variance.op = 'Const'
variance.name = 'proposals/variance'
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"])
graph_def.node.extend([variance])
|
||||||
|
|
||||||
|
# Multiply the second-stage box predictions by the variance constant (axis 2).
varianceEncoder = NodeDef()
varianceEncoder.op = 'Mul'
varianceEncoder.name = 'variance_encoded'
varianceEncoder.input.extend(['SecondStageBoxPredictor/Reshape', variance.name])
text_format.Merge('i: 2', varianceEncoder.attr["axis"])
graph_def.node.extend([varianceEncoder])

# Bring both inputs of the final DetectionOutput into their expected layout.
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)
|
||||||
|
|
||||||
|
# Final DetectionOutput node: combines the variance-encoded second-stage boxes,
# the second-stage class scores and the first-stage boxes.
detectionOut = NodeDef()
detectionOut.name = 'detection_out_final'
detectionOut.op = 'DetectionOutput'

detectionOut.input.extend(['variance_encoded/flatten',
                           'SecondStageBoxPredictor/Reshape_1/Reshape',
                           'detection_out/slice/reshape'])

text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes'])
text_format.Merge('b: false', detectionOut.attr['share_location'])
# The background id is set outside of the range of valid class ids.
text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['clip'])
text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
text_format.Merge('f: 0.3', detectionOut.attr['confidence_threshold'])
text_format.Merge('b: false', detectionOut.attr['group_by_classes'])
graph_def.node.extend([detectionOut])
|
||||||
|
|
||||||
|
# Append the nodes that are still detached, in their original graph order
# (topNodes stores them last-node-first).
graph_def.node.extend(topNodes[::-1])

# The last CropAndResize of the graph takes 'detection_out_final' as input #1.
for i in range(len(graph_def.node) - 1, -1, -1):
    if graph_def.node[i].op == 'CropAndResize':
        graph_def.node[i].input.insert(1, 'detection_out_final')
        break

# Turn the very last node into a Sigmoid called 'detection_masks' and drop its
# last input.
lastNode = graph_def.node[-1]
lastNode.name = 'detection_masks'
lastNode.op = 'Sigmoid'
lastNode.input.pop()

# Write the text graph to the path given by --output.
tf.train.write_graph(graph_def, "", args.output, as_text=True)
|
@ -15,7 +15,7 @@ from math import sqrt
|
|||||||
from tensorflow.core.framework.node_def_pb2 import NodeDef
|
from tensorflow.core.framework.node_def_pb2 import NodeDef
|
||||||
from tensorflow.tools.graph_transforms import TransformGraph
|
from tensorflow.tools.graph_transforms import TransformGraph
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
from tf_text_graph_common import tensorMsg, addConstNode
|
from tf_text_graph_common import *
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
|
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
|
||||||
'SSD model from TensorFlow Object Detection API. '
|
'SSD model from TensorFlow Object Detection API. '
|
||||||
@ -41,10 +41,6 @@ args = parser.parse_args()
|
|||||||
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm',
|
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm',
|
||||||
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity']
|
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity']
|
||||||
|
|
||||||
# Nodes attributes that could be removed because they are not used during import.
|
|
||||||
unusedAttrs = ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
|
|
||||||
'Index', 'Tperm', 'is_training', 'Tpaddings']
|
|
||||||
|
|
||||||
# Node with which prefixes should be removed
|
# Node with which prefixes should be removed
|
||||||
prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Postprocessor/', 'Preprocessor/')
|
prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Postprocessor/', 'Preprocessor/')
|
||||||
|
|
||||||
@ -66,7 +62,6 @@ def getUnconnectedNodes():
|
|||||||
unconnected.remove(inp)
|
unconnected.remove(inp)
|
||||||
return unconnected
|
return unconnected
|
||||||
|
|
||||||
removedNodes = []
|
|
||||||
|
|
||||||
# Detect unfused batch normalization nodes and fuse them.
|
# Detect unfused batch normalization nodes and fuse them.
|
||||||
def fuse_batch_normalization():
|
def fuse_batch_normalization():
|
||||||
@ -118,41 +113,13 @@ def fuse_batch_normalization():
|
|||||||
|
|
||||||
fuse_batch_normalization()
|
fuse_batch_normalization()
|
||||||
|
|
||||||
# Removes Identity nodes
|
removeIdentity(graph_def)
|
||||||
def removeIdentity():
|
|
||||||
identities = {}
|
|
||||||
for node in graph_def.node:
|
|
||||||
if node.op == 'Identity':
|
|
||||||
identities[node.name] = node.input[0]
|
|
||||||
graph_def.node.remove(node)
|
|
||||||
|
|
||||||
for node in graph_def.node:
|
def to_remove(name, op):
|
||||||
for i in range(len(node.input)):
|
return (not op in keepOps) or name.startswith(prefixesToRemove)
|
||||||
if node.input[i] in identities:
|
|
||||||
node.input[i] = identities[node.input[i]]
|
|
||||||
|
|
||||||
removeIdentity()
|
removeUnusedNodesAndAttrs(to_remove, graph_def)
|
||||||
|
|
||||||
# Remove extra nodes and attributes.
|
|
||||||
for i in reversed(range(len(graph_def.node))):
|
|
||||||
op = graph_def.node[i].op
|
|
||||||
name = graph_def.node[i].name
|
|
||||||
|
|
||||||
if (not op in keepOps) or name.startswith(prefixesToRemove):
|
|
||||||
if op != 'Const':
|
|
||||||
removedNodes.append(name)
|
|
||||||
|
|
||||||
del graph_def.node[i]
|
|
||||||
else:
|
|
||||||
for attr in unusedAttrs:
|
|
||||||
if attr in graph_def.node[i].attr:
|
|
||||||
del graph_def.node[i].attr[attr]
|
|
||||||
|
|
||||||
# Remove references to removed nodes except Const nodes.
|
|
||||||
for node in graph_def.node:
|
|
||||||
for i in reversed(range(len(node.input))):
|
|
||||||
if node.input[i] in removedNodes:
|
|
||||||
del node.input[i]
|
|
||||||
|
|
||||||
# Connect input node to the first layer
|
# Connect input node to the first layer
|
||||||
assert(graph_def.node[0].op == 'Placeholder')
|
assert(graph_def.node[0].op == 'Placeholder')
|
||||||
@ -175,8 +142,8 @@ def addConcatNode(name, inputs, axisNodeName):
|
|||||||
concat.input.append(axisNodeName)
|
concat.input.append(axisNodeName)
|
||||||
graph_def.node.extend([concat])
|
graph_def.node.extend([concat])
|
||||||
|
|
||||||
addConstNode('concat/axis_flatten', [-1])
|
addConstNode('concat/axis_flatten', [-1], graph_def)
|
||||||
addConstNode('PriorBox/concat/axis', [-2])
|
addConstNode('PriorBox/concat/axis', [-2], graph_def)
|
||||||
|
|
||||||
for label in ['ClassPredictor', 'BoxEncodingPredictor' if args.box_predictor is 'convolutional' else 'BoxPredictor']:
|
for label in ['ClassPredictor', 'BoxEncodingPredictor' if args.box_predictor is 'convolutional' else 'BoxPredictor']:
|
||||||
concatInputs = []
|
concatInputs = []
|
||||||
|
Loading…
Reference in New Issue
Block a user