mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 22:44:02 +08:00
Merge pull request #9750 from dkurt:feature_dnn_tf_text_graph
This commit is contained in:
commit
046045239c
@ -629,7 +629,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
/** @brief Reads a network model stored in Tensorflow model file.
|
||||
* @details This is a shortcut consisting of createTensorflowImporter and Net::populateNet calls.
|
||||
*/
|
||||
CV_EXPORTS_W Net readNetFromTensorflow(const String &model);
|
||||
CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String());
|
||||
|
||||
/** @brief Reads a network model stored in Torch model file.
|
||||
* @details This is a shortcut consisting of createTorchImporter and Net::populateNet calls.
|
||||
|
@ -81,6 +81,8 @@ public:
|
||||
|
||||
float _nmsThreshold;
|
||||
int _topK;
|
||||
// Whether predicted bounding boxes are represented in YXHW instead of XYWH layout.
|
||||
bool _locPredTransposed;
|
||||
|
||||
enum { _numAxes = 4 };
|
||||
static const std::string _layerName;
|
||||
@ -148,6 +150,7 @@ public:
|
||||
_keepTopK = getParameter<int>(params, "keep_top_k");
|
||||
_confidenceThreshold = getParameter<float>(params, "confidence_threshold", 0, false, -FLT_MAX);
|
||||
_topK = getParameter<int>(params, "top_k", 0, false, -1);
|
||||
_locPredTransposed = getParameter<bool>(params, "loc_pred_transposed", 0, false, false);
|
||||
|
||||
getCodeType(params);
|
||||
|
||||
@ -209,7 +212,7 @@ public:
|
||||
// Retrieve all location predictions
|
||||
std::vector<LabelBBox> allLocationPredictions;
|
||||
GetLocPredictions(locationData, num, numPriors, _numLocClasses,
|
||||
_shareLocation, allLocationPredictions);
|
||||
_shareLocation, _locPredTransposed, allLocationPredictions);
|
||||
|
||||
// Retrieve all confidences
|
||||
GetConfidenceScores(confidenceData, num, numPriors, _numClasses, allConfidenceScores);
|
||||
@ -540,11 +543,14 @@ public:
|
||||
// num_loc_classes: number of location classes. It is 1 if share_location is
|
||||
// true; and is equal to number of classes needed to predict otherwise.
|
||||
// share_location: if true, all classes share the same location prediction.
|
||||
// loc_pred_transposed: if true, represent four bounding box values as
|
||||
// [y,x,height,width] or [x,y,width,height] otherwise.
|
||||
// loc_preds: stores the location prediction, where each item contains
|
||||
// location prediction for an image.
|
||||
static void GetLocPredictions(const float* locData, const int num,
|
||||
const int numPredsPerClass, const int numLocClasses,
|
||||
const bool shareLocation, std::vector<LabelBBox>& locPreds)
|
||||
const bool shareLocation, const bool locPredTransposed,
|
||||
std::vector<LabelBBox>& locPreds)
|
||||
{
|
||||
locPreds.clear();
|
||||
if (shareLocation)
|
||||
@ -566,10 +572,20 @@ public:
|
||||
labelBBox[label].resize(numPredsPerClass);
|
||||
}
|
||||
caffe::NormalizedBBox& bbox = labelBBox[label][p];
|
||||
bbox.set_xmin(locData[startIdx + c * 4]);
|
||||
bbox.set_ymin(locData[startIdx + c * 4 + 1]);
|
||||
bbox.set_xmax(locData[startIdx + c * 4 + 2]);
|
||||
bbox.set_ymax(locData[startIdx + c * 4 + 3]);
|
||||
if (locPredTransposed)
|
||||
{
|
||||
bbox.set_ymin(locData[startIdx + c * 4]);
|
||||
bbox.set_xmin(locData[startIdx + c * 4 + 1]);
|
||||
bbox.set_ymax(locData[startIdx + c * 4 + 2]);
|
||||
bbox.set_xmax(locData[startIdx + c * 4 + 3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
bbox.set_xmin(locData[startIdx + c * 4]);
|
||||
bbox.set_ymin(locData[startIdx + c * 4 + 1]);
|
||||
bbox.set_xmax(locData[startIdx + c * 4 + 2]);
|
||||
bbox.set_ymax(locData[startIdx + c * 4 + 3]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -124,6 +124,20 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void getScales(const LayerParams ¶ms)
|
||||
{
|
||||
DictValue scalesParameter;
|
||||
bool scalesRetieved = getParameterDict(params, "scales", scalesParameter);
|
||||
if (scalesRetieved)
|
||||
{
|
||||
_scales.resize(scalesParameter.size());
|
||||
for (int i = 0; i < scalesParameter.size(); ++i)
|
||||
{
|
||||
_scales[i] = scalesParameter.get<float>(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void getVariance(const LayerParams ¶ms)
|
||||
{
|
||||
DictValue varianceParameter;
|
||||
@ -169,13 +183,14 @@ public:
|
||||
_flip = getParameter<bool>(params, "flip");
|
||||
_clip = getParameter<bool>(params, "clip");
|
||||
|
||||
_scales.clear();
|
||||
_aspectRatios.clear();
|
||||
_aspectRatios.push_back(1.);
|
||||
|
||||
getAspectRatios(params);
|
||||
getVariance(params);
|
||||
getScales(params);
|
||||
|
||||
_numPriors = _aspectRatios.size();
|
||||
_numPriors = _aspectRatios.size() + 1; // + 1 for an aspect ratio 1.0
|
||||
|
||||
_maxSize = -1;
|
||||
if (params.has("max_size"))
|
||||
@ -231,6 +246,11 @@ public:
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
|
||||
|
||||
if (_scales.empty())
|
||||
_scales.resize(_numPriors, 1.0f);
|
||||
else
|
||||
CV_Assert(_scales.size() == _numPriors);
|
||||
|
||||
int _layerWidth = inputs[0]->size[3];
|
||||
int _layerHeight = inputs[0]->size[2];
|
||||
|
||||
@ -256,7 +276,7 @@ public:
|
||||
{
|
||||
for (size_t w = 0; w < _layerWidth; ++w)
|
||||
{
|
||||
_boxWidth = _boxHeight = _minSize;
|
||||
_boxWidth = _boxHeight = _minSize * _scales[0];
|
||||
|
||||
float center_x = (w + 0.5) * stepX;
|
||||
float center_y = (h + 0.5) * stepY;
|
||||
@ -272,7 +292,7 @@ public:
|
||||
if (_maxSize > 0)
|
||||
{
|
||||
// second prior: aspect_ratio = 1, size = sqrt(min_size * max_size)
|
||||
_boxWidth = _boxHeight = sqrt(_minSize * _maxSize);
|
||||
_boxWidth = _boxHeight = sqrt(_minSize * _maxSize) * _scales[1];
|
||||
// xmin
|
||||
outputPtr[idx++] = (center_x - _boxWidth / 2.) / _imageWidth;
|
||||
// ymin
|
||||
@ -284,15 +304,13 @@ public:
|
||||
}
|
||||
|
||||
// rest of priors
|
||||
CV_Assert((_maxSize > 0 ? 2 : 1) + _aspectRatios.size() == _scales.size());
|
||||
for (size_t r = 0; r < _aspectRatios.size(); ++r)
|
||||
{
|
||||
float ar = _aspectRatios[r];
|
||||
if (fabs(ar - 1.) < 1e-6)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
_boxWidth = _minSize * sqrt(ar);
|
||||
_boxHeight = _minSize / sqrt(ar);
|
||||
float scale = _scales[(_maxSize > 0 ? 2 : 1) + r];
|
||||
_boxWidth = _minSize * sqrt(ar) * scale;
|
||||
_boxHeight = _minSize / sqrt(ar) * scale;
|
||||
// xmin
|
||||
outputPtr[idx++] = (center_x - _boxWidth / 2.) / _imageWidth;
|
||||
// ymin
|
||||
@ -363,6 +381,7 @@ public:
|
||||
|
||||
std::vector<float> _aspectRatios;
|
||||
std::vector<float> _variance;
|
||||
std::vector<float> _scales;
|
||||
|
||||
bool _flip;
|
||||
bool _clip;
|
||||
|
@ -321,10 +321,10 @@ DictValue parseDims(const tensorflow::TensorProto &tensor) {
|
||||
CV_Assert(tensor.dtype() == tensorflow::DT_INT32);
|
||||
CV_Assert(dims == 1);
|
||||
|
||||
int size = tensor.tensor_content().size() / sizeof(int);
|
||||
const int *data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
|
||||
Mat values = getTensorContent(tensor);
|
||||
CV_Assert(values.type() == CV_32SC1);
|
||||
// TODO: add reordering shape if dims == 4
|
||||
return DictValue::arrayInt(data, size);
|
||||
return DictValue::arrayInt((int*)values.data, values.total());
|
||||
}
|
||||
|
||||
void setKSize(LayerParams &layerParams, const tensorflow::NodeDef &layer)
|
||||
@ -448,7 +448,7 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
|
||||
|
||||
class TFImporter : public Importer {
|
||||
public:
|
||||
TFImporter(const char *model);
|
||||
TFImporter(const char *model, const char *config = NULL);
|
||||
void populateNet(Net dstNet);
|
||||
~TFImporter() {}
|
||||
|
||||
@ -463,13 +463,20 @@ private:
|
||||
int input_blob_index = -1, int* actual_inp_blob_idx = 0);
|
||||
|
||||
|
||||
tensorflow::GraphDef net;
|
||||
// Binary serialized TensorFlow graph includes weights.
|
||||
tensorflow::GraphDef netBin;
|
||||
// Optional text definition of TensorFlow graph. More flexible than binary format
|
||||
// and may be used to build the network using binary format only as a weights storage.
|
||||
// This approach is similar to Caffe's `.prototxt` and `.caffemodel`.
|
||||
tensorflow::GraphDef netTxt;
|
||||
};
|
||||
|
||||
TFImporter::TFImporter(const char *model)
|
||||
TFImporter::TFImporter(const char *model, const char *config)
|
||||
{
|
||||
if (model && model[0])
|
||||
ReadTFNetParamsFromBinaryFileOrDie(model, &net);
|
||||
ReadTFNetParamsFromBinaryFileOrDie(model, &netBin);
|
||||
if (config && config[0])
|
||||
ReadTFNetParamsFromTextFileOrDie(config, &netTxt);
|
||||
}
|
||||
|
||||
void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
|
||||
@ -557,21 +564,23 @@ const tensorflow::TensorProto& TFImporter::getConstBlob(const tensorflow::NodeDe
|
||||
*actual_inp_blob_idx = input_blob_index;
|
||||
}
|
||||
|
||||
return net.node(const_layers.at(kernel_inp.name)).attr().at("value").tensor();
|
||||
int nodeIdx = const_layers.at(kernel_inp.name);
|
||||
if (nodeIdx < netBin.node_size() && netBin.node(nodeIdx).name() == kernel_inp.name)
|
||||
{
|
||||
return netBin.node(nodeIdx).attr().at("value").tensor();
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_Assert(nodeIdx < netTxt.node_size(),
|
||||
netTxt.node(nodeIdx).name() == kernel_inp.name);
|
||||
return netTxt.node(nodeIdx).attr().at("value").tensor();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void TFImporter::populateNet(Net dstNet)
|
||||
static void addConstNodes(const tensorflow::GraphDef& net, std::map<String, int>& const_layers,
|
||||
std::set<String>& layers_to_ignore)
|
||||
{
|
||||
RemoveIdentityOps(net);
|
||||
|
||||
std::map<int, String> layers_to_ignore;
|
||||
|
||||
int layersSize = net.node_size();
|
||||
|
||||
// find all Const layers for params
|
||||
std::map<String, int> value_id;
|
||||
for (int li = 0; li < layersSize; li++)
|
||||
for (int li = 0; li < net.node_size(); li++)
|
||||
{
|
||||
const tensorflow::NodeDef &layer = net.node(li);
|
||||
String name = layer.name();
|
||||
@ -582,11 +591,27 @@ void TFImporter::populateNet(Net dstNet)
|
||||
|
||||
if (layer.attr().find("value") != layer.attr().end())
|
||||
{
|
||||
value_id.insert(std::make_pair(name, li));
|
||||
CV_Assert(const_layers.insert(std::make_pair(name, li)).second);
|
||||
}
|
||||
|
||||
layers_to_ignore[li] = name;
|
||||
layers_to_ignore.insert(name);
|
||||
}
|
||||
}
|
||||
|
||||
void TFImporter::populateNet(Net dstNet)
|
||||
{
|
||||
RemoveIdentityOps(netBin);
|
||||
RemoveIdentityOps(netTxt);
|
||||
|
||||
std::set<String> layers_to_ignore;
|
||||
|
||||
tensorflow::GraphDef& net = netTxt.ByteSize() != 0 ? netTxt : netBin;
|
||||
|
||||
int layersSize = net.node_size();
|
||||
|
||||
// find all Const layers for params
|
||||
std::map<String, int> value_id;
|
||||
addConstNodes(netBin, value_id, layers_to_ignore);
|
||||
addConstNodes(netTxt, value_id, layers_to_ignore);
|
||||
|
||||
std::map<String, int> layer_id;
|
||||
|
||||
@ -597,7 +622,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
String type = layer.op();
|
||||
LayerParams layerParams;
|
||||
|
||||
if(layers_to_ignore.find(li) != layers_to_ignore.end())
|
||||
if(layers_to_ignore.find(name) != layers_to_ignore.end())
|
||||
continue;
|
||||
|
||||
if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
|
||||
@ -627,7 +652,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
StrIntVector next_layers = getNextLayers(net, name, "Conv2D");
|
||||
CV_Assert(next_layers.size() == 1);
|
||||
layer = net.node(next_layers[0].second);
|
||||
layers_to_ignore[next_layers[0].second] = next_layers[0].first;
|
||||
layers_to_ignore.insert(next_layers[0].first);
|
||||
name = layer.name();
|
||||
type = layer.op();
|
||||
}
|
||||
@ -644,7 +669,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
|
||||
blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
|
||||
ExcludeLayer(net, weights_layer_index, 0, false);
|
||||
layers_to_ignore[weights_layer_index] = next_layers[0].first;
|
||||
layers_to_ignore.insert(next_layers[0].first);
|
||||
}
|
||||
|
||||
kernelFromTensor(getConstBlob(layer, value_id), layerParams.blobs[0]);
|
||||
@ -684,7 +709,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
layerParams.set("pad_mode", ""); // We use padding values.
|
||||
CV_Assert(next_layers.size() == 1);
|
||||
ExcludeLayer(net, next_layers[0].second, 0, false);
|
||||
layers_to_ignore[next_layers[0].second] = next_layers[0].first;
|
||||
layers_to_ignore.insert(next_layers[0].first);
|
||||
}
|
||||
|
||||
int id = dstNet.addLayer(name, "Convolution", layerParams);
|
||||
@ -748,7 +773,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
int weights_layer_index = next_layers[0].second;
|
||||
blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
|
||||
ExcludeLayer(net, weights_layer_index, 0, false);
|
||||
layers_to_ignore[weights_layer_index] = next_layers[0].first;
|
||||
layers_to_ignore.insert(next_layers[0].first);
|
||||
}
|
||||
|
||||
int kernel_blob_index = -1;
|
||||
@ -778,6 +803,30 @@ void TFImporter::populateNet(Net dstNet)
|
||||
// one input only
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||
}
|
||||
else if (type == "Flatten")
|
||||
{
|
||||
int id = dstNet.addLayer(name, "Flatten", layerParams);
|
||||
layer_id[name] = id;
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||
}
|
||||
else if (type == "Transpose")
|
||||
{
|
||||
Mat perm = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||
CV_Assert(perm.type() == CV_32SC1);
|
||||
int* permData = (int*)perm.data;
|
||||
if (perm.total() == 4)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i)
|
||||
permData[i] = toNCHW[permData[i]];
|
||||
}
|
||||
layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
|
||||
|
||||
int id = dstNet.addLayer(name, "Permute", layerParams);
|
||||
layer_id[name] = id;
|
||||
|
||||
// one input only
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||
}
|
||||
else if (type == "Const")
|
||||
{
|
||||
}
|
||||
@ -807,7 +856,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
{
|
||||
int axisId = (type == "Concat" ? 0 : layer.input_size() - 1);
|
||||
int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0);
|
||||
layerParams.set("axis", toNCHW[axis]);
|
||||
layerParams.set("axis", 0 <= axis && axis < 4 ? toNCHW[axis] : axis);
|
||||
|
||||
int id = dstNet.addLayer(name, "Concat", layerParams);
|
||||
layer_id[name] = id;
|
||||
@ -929,6 +978,19 @@ void TFImporter::populateNet(Net dstNet)
|
||||
else // is a vector
|
||||
{
|
||||
layerParams.blobs.resize(1, scaleMat);
|
||||
|
||||
StrIntVector next_layers = getNextLayers(net, name, "Add");
|
||||
if (!next_layers.empty())
|
||||
{
|
||||
layerParams.set("bias_term", true);
|
||||
layerParams.blobs.resize(2);
|
||||
|
||||
int weights_layer_index = next_layers[0].second;
|
||||
blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs.back());
|
||||
ExcludeLayer(net, weights_layer_index, 0, false);
|
||||
layers_to_ignore.insert(next_layers[0].first);
|
||||
}
|
||||
|
||||
id = dstNet.addLayer(name, "Scale", layerParams);
|
||||
}
|
||||
layer_id[name] = id;
|
||||
@ -1037,7 +1099,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
|
||||
blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
|
||||
ExcludeLayer(net, weights_layer_index, 0, false);
|
||||
layers_to_ignore[weights_layer_index] = next_layers[0].first;
|
||||
layers_to_ignore.insert(next_layers[0].first);
|
||||
}
|
||||
|
||||
kernelFromTensor(getConstBlob(layer, value_id, 1), layerParams.blobs[0]);
|
||||
@ -1148,6 +1210,71 @@ void TFImporter::populateNet(Net dstNet)
|
||||
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||
}
|
||||
else if (type == "PriorBox")
|
||||
{
|
||||
if (hasLayerAttr(layer, "min_size"))
|
||||
layerParams.set("min_size", getLayerAttr(layer, "min_size").i());
|
||||
if (hasLayerAttr(layer, "max_size"))
|
||||
layerParams.set("max_size", getLayerAttr(layer, "max_size").i());
|
||||
if (hasLayerAttr(layer, "flip"))
|
||||
layerParams.set("flip", getLayerAttr(layer, "flip").b());
|
||||
if (hasLayerAttr(layer, "clip"))
|
||||
layerParams.set("clip", getLayerAttr(layer, "clip").b());
|
||||
if (hasLayerAttr(layer, "offset"))
|
||||
layerParams.set("offset", getLayerAttr(layer, "offset").f());
|
||||
if (hasLayerAttr(layer, "variance"))
|
||||
{
|
||||
Mat variance = getTensorContent(getLayerAttr(layer, "variance").tensor());
|
||||
layerParams.set("variance",
|
||||
DictValue::arrayReal<float*>((float*)variance.data, variance.total()));
|
||||
}
|
||||
if (hasLayerAttr(layer, "aspect_ratio"))
|
||||
{
|
||||
Mat aspectRatios = getTensorContent(getLayerAttr(layer, "aspect_ratio").tensor());
|
||||
layerParams.set("aspect_ratio",
|
||||
DictValue::arrayReal<float*>((float*)aspectRatios.data, aspectRatios.total()));
|
||||
}
|
||||
if (hasLayerAttr(layer, "scales"))
|
||||
{
|
||||
Mat scales = getTensorContent(getLayerAttr(layer, "scales").tensor());
|
||||
layerParams.set("scales",
|
||||
DictValue::arrayReal<float*>((float*)scales.data, scales.total()));
|
||||
}
|
||||
int id = dstNet.addLayer(name, "PriorBox", layerParams);
|
||||
layer_id[name] = id;
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
|
||||
}
|
||||
else if (type == "DetectionOutput")
|
||||
{
|
||||
// op: "DetectionOutput"
|
||||
// input_0: "locations"
|
||||
// input_1: "classifications"
|
||||
// input_2: "prior_boxes"
|
||||
if (hasLayerAttr(layer, "num_classes"))
|
||||
layerParams.set("num_classes", getLayerAttr(layer, "num_classes").i());
|
||||
if (hasLayerAttr(layer, "share_location"))
|
||||
layerParams.set("share_location", getLayerAttr(layer, "share_location").b());
|
||||
if (hasLayerAttr(layer, "background_label_id"))
|
||||
layerParams.set("background_label_id", getLayerAttr(layer, "background_label_id").i());
|
||||
if (hasLayerAttr(layer, "nms_threshold"))
|
||||
layerParams.set("nms_threshold", getLayerAttr(layer, "nms_threshold").f());
|
||||
if (hasLayerAttr(layer, "top_k"))
|
||||
layerParams.set("top_k", getLayerAttr(layer, "top_k").i());
|
||||
if (hasLayerAttr(layer, "code_type"))
|
||||
layerParams.set("code_type", getLayerAttr(layer, "code_type").s());
|
||||
if (hasLayerAttr(layer, "keep_top_k"))
|
||||
layerParams.set("keep_top_k", getLayerAttr(layer, "keep_top_k").i());
|
||||
if (hasLayerAttr(layer, "confidence_threshold"))
|
||||
layerParams.set("confidence_threshold", getLayerAttr(layer, "confidence_threshold").f());
|
||||
if (hasLayerAttr(layer, "loc_pred_transposed"))
|
||||
layerParams.set("loc_pred_transposed", getLayerAttr(layer, "loc_pred_transposed").b());
|
||||
|
||||
int id = dstNet.addLayer(name, "DetectionOutput", layerParams);
|
||||
layer_id[name] = id;
|
||||
for (int i = 0; i < 3; ++i)
|
||||
connect(layer_id, dstNet, parsePin(layer.input(i)), id, i);
|
||||
}
|
||||
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
|
||||
type == "Relu" || type == "Elu" || type == "Softmax" ||
|
||||
type == "Identity" || type == "Relu6")
|
||||
@ -1188,9 +1315,9 @@ Ptr<Importer> createTensorflowImporter(const String&)
|
||||
|
||||
#endif //HAVE_PROTOBUF
|
||||
|
||||
Net readNetFromTensorflow(const String &model)
|
||||
Net readNetFromTensorflow(const String &model, const String &config)
|
||||
{
|
||||
TFImporter importer(model.c_str());
|
||||
TFImporter importer(model.c_str(), config.c_str());
|
||||
Net net;
|
||||
importer.populateNet(net);
|
||||
return net;
|
||||
|
@ -52,12 +52,27 @@ bool ReadProtoFromBinaryFileTF(const char* filename, Message* proto) {
|
||||
return success;
|
||||
}
|
||||
|
||||
bool ReadProtoFromTextFileTF(const char* filename, Message* proto) {
|
||||
std::ifstream fs(filename, std::ifstream::in);
|
||||
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
|
||||
IstreamInputStream input(&fs);
|
||||
bool success = google::protobuf::TextFormat::Parse(&input, proto);
|
||||
fs.close();
|
||||
return success;
|
||||
}
|
||||
|
||||
void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
|
||||
tensorflow::GraphDef* param) {
|
||||
CHECK(ReadProtoFromBinaryFileTF(param_file, param))
|
||||
<< "Failed to parse GraphDef file: " << param_file;
|
||||
}
|
||||
|
||||
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
|
||||
tensorflow::GraphDef* param) {
|
||||
CHECK(ReadProtoFromTextFileTF(param_file, param))
|
||||
<< "Failed to parse GraphDef file: " << param_file;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -22,6 +22,9 @@ namespace dnn {
|
||||
void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
|
||||
tensorflow::GraphDef* param);
|
||||
|
||||
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
|
||||
tensorflow::GraphDef* param);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -74,14 +74,15 @@ static std::string path(const std::string& file)
|
||||
return findDataFile("dnn/tensorflow/" + file, false);
|
||||
}
|
||||
|
||||
static void runTensorFlowNet(const std::string& prefix,
|
||||
static void runTensorFlowNet(const std::string& prefix, bool hasText = false,
|
||||
double l1 = 1e-5, double lInf = 1e-4)
|
||||
{
|
||||
std::string netPath = path(prefix + "_net.pb");
|
||||
std::string netConfig = (hasText ? path(prefix + "_net.pbtxt") : "");
|
||||
std::string inpPath = path(prefix + "_in.npy");
|
||||
std::string outPath = path(prefix + "_out.npy");
|
||||
|
||||
Net net = readNetFromTensorflow(netPath);
|
||||
Net net = readNetFromTensorflow(netPath, netConfig);
|
||||
|
||||
cv::Mat input = blobFromNPY(inpPath);
|
||||
cv::Mat target = blobFromNPY(outPath);
|
||||
@ -120,6 +121,7 @@ TEST(Test_TensorFlow, batch_norm)
|
||||
{
|
||||
runTensorFlowNet("batch_norm");
|
||||
runTensorFlowNet("fused_batch_norm");
|
||||
runTensorFlowNet("batch_norm_text", true);
|
||||
}
|
||||
|
||||
TEST(Test_TensorFlow, pooling)
|
||||
@ -148,26 +150,60 @@ TEST(Test_TensorFlow, reshape)
|
||||
{
|
||||
runTensorFlowNet("shift_reshape_no_reorder");
|
||||
runTensorFlowNet("reshape_reduce");
|
||||
runTensorFlowNet("flatten", true);
|
||||
}
|
||||
|
||||
TEST(Test_TensorFlow, fp16)
|
||||
{
|
||||
const float l1 = 1e-3;
|
||||
const float lInf = 1e-2;
|
||||
runTensorFlowNet("fp16_single_conv", l1, lInf);
|
||||
runTensorFlowNet("fp16_deconvolution", l1, lInf);
|
||||
runTensorFlowNet("fp16_max_pool_odd_same", l1, lInf);
|
||||
runTensorFlowNet("fp16_padding_valid", l1, lInf);
|
||||
runTensorFlowNet("fp16_eltwise_add_mul", l1, lInf);
|
||||
runTensorFlowNet("fp16_max_pool_odd_valid", l1, lInf);
|
||||
runTensorFlowNet("fp16_pad_and_concat", l1, lInf);
|
||||
runTensorFlowNet("fp16_max_pool_even", l1, lInf);
|
||||
runTensorFlowNet("fp16_padding_same", l1, lInf);
|
||||
runTensorFlowNet("fp16_single_conv", false, l1, lInf);
|
||||
runTensorFlowNet("fp16_deconvolution", false, l1, lInf);
|
||||
runTensorFlowNet("fp16_max_pool_odd_same", false, l1, lInf);
|
||||
runTensorFlowNet("fp16_padding_valid", false, l1, lInf);
|
||||
runTensorFlowNet("fp16_eltwise_add_mul", false, l1, lInf);
|
||||
runTensorFlowNet("fp16_max_pool_odd_valid", false, l1, lInf);
|
||||
runTensorFlowNet("fp16_pad_and_concat", false, l1, lInf);
|
||||
runTensorFlowNet("fp16_max_pool_even", false, l1, lInf);
|
||||
runTensorFlowNet("fp16_padding_same", false, l1, lInf);
|
||||
}
|
||||
|
||||
// Loads MobileNet-SSD from a binary TensorFlow graph plus its text graph
// definition, runs it on a 300x300 version of street.png and compares three
// named outputs against reference blobs stored as .npy files.
TEST(Test_TensorFlow, MobileNet_SSD)
{
    std::string netPath = findDataFile("dnn/ssd_mobilenet_v1_coco.pb", false);
    std::string netConfig = findDataFile("dnn/ssd_mobilenet_v1_coco.pbtxt", false);
    std::string imgPath = findDataFile("dnn/street.png", false);

    Mat inp;
    resize(imread(imgPath), inp, Size(300, 300));
    // Scale by 1/127.5 with a mean of 127.5 per channel.
    // NOTE(review): the trailing `true` is presumably the swap-R/B flag - confirm.
    inp = blobFromImage(inp, 1.0f / 127.5, Size(), Scalar(127.5, 127.5, 127.5), true);

    std::vector<String> outNames(3);
    outNames[0] = "concat";
    outNames[1] = "concat_1";
    outNames[2] = "detection_out";

    // One reference blob per requested output, in the same order as outNames.
    std::vector<Mat> target(outNames.size());
    // size_t counter: outNames.size() is unsigned, an int index would mix signedness.
    for (size_t i = 0; i < outNames.size(); ++i)
    {
        std::string path = findDataFile("dnn/tensorflow/ssd_mobilenet_v1_coco." + outNames[i] + ".npy", false);
        target[i] = blobFromNPY(path);
    }

    Net net = readNetFromTensorflow(netPath, netConfig);
    net.setInput(inp);

    std::vector<Mat> output;
    net.forward(output, outNames);

    // Blobs are flattened to a single row before comparison; the last two
    // outputs are checked with explicitly loosened tolerances.
    normAssert(target[0].reshape(1, 1), output[0].reshape(1, 1));
    normAssert(target[1].reshape(1, 1), output[1].reshape(1, 1), "", 1e-5, 2e-4);
    normAssert(target[2].reshape(1, 1), output[2].reshape(1, 1), "", 4e-5, 1e-2);
}
|
||||
|
||||
TEST(Test_TensorFlow, lstm)
|
||||
{
|
||||
runTensorFlowNet("lstm");
|
||||
runTensorFlowNet("lstm", true);
|
||||
}
|
||||
|
||||
TEST(Test_TensorFlow, split)
|
||||
|
131
samples/dnn/mobilenet_ssd_accuracy.py
Normal file
131
samples/dnn/mobilenet_ssd_accuracy.py
Normal file
@ -0,0 +1,131 @@
|
||||
# Script to evaluate MobileNet-SSD object detection model trained in TensorFlow
|
||||
# using both TensorFlow and OpenCV. Example:
|
||||
#
|
||||
# python mobilenet_ssd_accuracy.py \
|
||||
# --weights=frozen_inference_graph.pb \
|
||||
# --prototxt=ssd_mobilenet_v1_coco.pbtxt \
|
||||
# --images=val2017 \
|
||||
# --annotations=annotations/instances_val2017.json
|
||||
#
|
||||
# Tested on COCO 2017 object detection dataset, http://cocodataset.org/#download
|
||||
import os
import cv2 as cv
import json
import argparse

parser = argparse.ArgumentParser(
    description='Evaluate MobileNet-SSD model using both TensorFlow and OpenCV. '
                'COCO evaluation framework is required: http://cocodataset.org')
parser.add_argument('--weights', required=True,
                    help='Path to frozen_inference_graph.pb of MobileNet-SSD model. '
                         'Download it at https://github.com/tensorflow/models/tree/master/research/object_detection')
parser.add_argument('--prototxt', help='Path to ssd_mobilenet_v1_coco.pbtxt from opencv_extra.', required=True)
parser.add_argument('--images', help='Path to COCO validation images directory.', required=True)
parser.add_argument('--annotations', help='Path to COCO annotations file.', required=True)
args = parser.parse_args()


def imageIdFromName(imgName):
    # The numeric image id is the file name without its extension
    # (int() already ignores leading zeros).
    return int(os.path.splitext(imgName)[0])


### Get OpenCV predictions #####################################################
net = cv.dnn.readNetFromTensorflow(args.weights, args.prototxt)

detections = []
for imgName in os.listdir(args.images):
    inp = cv.imread(os.path.join(args.images, imgName))
    rows = inp.shape[0]
    cols = inp.shape[1]
    inp = cv.resize(inp, (300, 300))

    net.setInput(cv.dnn.blobFromImage(inp, 1.0/127.5, (300, 300), (127.5, 127.5, 127.5), True))
    out = net.forward()

    for i in range(out.shape[2]):
        score = float(out[0, 0, i, 2])
        # Confidence threshold is in prototxt.
        classId = int(out[0, 0, i, 1])

        # Columns 3..6 are scaled by the original image size and turned into
        # [x, y, width, height]. Values are converted to Python floats so that
        # json.dump can serialize them.
        x = float(out[0, 0, i, 3]) * cols
        y = float(out[0, 0, i, 4]) * rows
        w = float(out[0, 0, i, 5]) * cols - x
        h = float(out[0, 0, i, 6]) * rows - y
        detections.append({
            "image_id": imageIdFromName(imgName),
            "category_id": classId,
            "bbox": [x, y, w, h],
            "score": score
        })

with open('cv_result.json', 'wt') as f:
    json.dump(detections, f)

### Get TensorFlow predictions #################################################
import tensorflow as tf

# The frozen graph is a binary protobuf, so it is read in binary mode.
with tf.gfile.FastGFile(args.weights, 'rb') as f:
    # Load the model
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # Restore session
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

    detections = []
    for imgName in os.listdir(args.images):
        inp = cv.imread(os.path.join(args.images, imgName))
        rows = inp.shape[0]
        cols = inp.shape[1]
        inp = cv.resize(inp, (300, 300))
        inp = inp[:, :, [2, 1, 0]]  # BGR2RGB
        out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
                        sess.graph.get_tensor_by_name('detection_scores:0'),
                        sess.graph.get_tensor_by_name('detection_boxes:0'),
                        sess.graph.get_tensor_by_name('detection_classes:0')],
                       feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})
        num_detections = int(out[0][0])
        for i in range(num_detections):
            classId = int(out[3][0][i])
            score = float(out[1][0][i])
            bbox = [float(v) for v in out[2][0][i]]
            if score > 0.01:
                # bbox is indexed as [ymin, xmin, ymax, xmax] here; convert to
                # [x, y, width, height] in pixels of the original image.
                x = bbox[1] * cols
                y = bbox[0] * rows
                w = bbox[3] * cols - x
                h = bbox[2] * rows - y
                detections.append({
                    "image_id": imageIdFromName(imgName),
                    "category_id": classId,
                    "bbox": [x, y, w, h],
                    "score": score
                })

with open('tf_result.json', 'wt') as f:
    json.dump(detections, f)

### Evaluation part ############################################################

# %matplotlib inline
import matplotlib.pyplot as plt
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import numpy as np
import skimage.io as io
import pylab
pylab.rcParams['figure.figsize'] = (10.0, 8.0)

annType = ['segm','bbox','keypoints']
annType = annType[1]      #specify type here
prefix = 'person_keypoints' if annType=='keypoints' else 'instances'
# Single-argument print() calls behave the same under Python 2 and Python 3.
print('Running demo for *%s* results.'%(annType))

#initialize COCO ground truth api
cocoGt=COCO(args.annotations)

#initialize COCO detections api
for resFile in ['tf_result.json', 'cv_result.json']:
    print(resFile)
    cocoDt=cocoGt.loadRes(resFile)

    cocoEval = COCOeval(cocoGt,cocoDt,annType)
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()
|
@ -1,3 +1,14 @@
|
||||
# This script is used to demonstrate MobileNet-SSD network using OpenCV deep learning module.
|
||||
#
|
||||
# It works with model taken from https://github.com/chuanqi305/MobileNet-SSD/ that
|
||||
# was trained in Caffe-SSD framework, https://github.com/weiliu89/caffe/tree/ssd.
|
||||
# Model detects objects from 20 classes.
|
||||
#
|
||||
# Also TensorFlow model from TensorFlow object detection model zoo may be used to
|
||||
# detect objects from 90 classes:
|
||||
# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
|
||||
# Text graph definition must be taken from opencv_extra:
|
||||
# https://github.com/opencv/opencv_extra/tree/master/testdata/dnn/ssd_mobilenet_v1_coco.pbtxt
|
||||
import numpy as np
|
||||
import argparse
|
||||
|
||||
@ -13,27 +24,58 @@ WHRatio = inWidth / float(inHeight)
|
||||
inScaleFactor = 0.007843
|
||||
meanVal = 127.5
|
||||
|
||||
classNames = ('background',
|
||||
'aeroplane', 'bicycle', 'bird', 'boat',
|
||||
'bottle', 'bus', 'car', 'cat', 'chair',
|
||||
'cow', 'diningtable', 'dog', 'horse',
|
||||
'motorbike', 'person', 'pottedplant',
|
||||
'sheep', 'sofa', 'train', 'tvmonitor')
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Script to run MobileNet-SSD object detection network '
|
||||
'trained either in Caffe or TensorFlow frameworks.')
|
||||
parser.add_argument("--video", help="path to video file. If empty, camera's stream will be used")
|
||||
parser.add_argument("--prototxt", default="MobileNetSSD_deploy.prototxt",
|
||||
help="path to caffe prototxt")
|
||||
parser.add_argument("-c", "--caffemodel", default="MobileNetSSD_deploy.caffemodel",
|
||||
help="path to caffemodel file, download it here: "
|
||||
"https://github.com/chuanqi305/MobileNet-SSD/")
|
||||
parser.add_argument("--thr", default=0.2, help="confidence threshold to filter out weak detections")
|
||||
help='Path to text network file: '
|
||||
'MobileNetSSD_deploy.prototxt for Caffe model or '
|
||||
'ssd_mobilenet_v1_coco.pbtxt from opencv_extra for TensorFlow model')
|
||||
parser.add_argument("--weights", default="MobileNetSSD_deploy.caffemodel",
|
||||
help='Path to weights: '
|
||||
'MobileNetSSD_deploy.caffemodel for Caffe model or '
|
||||
'frozen_inference_graph.pb from TensorFlow.')
|
||||
parser.add_argument("--num_classes", default=20, type=int,
|
||||
help="Number of classes. It's 20 for Caffe model from "
|
||||
"https://github.com/chuanqi305/MobileNet-SSD/ and 90 for "
|
||||
"TensorFlow model from https://github.com/tensorflow/models/tree/master/research/object_detection")
|
||||
parser.add_argument("--thr", default=0.2, type=float, help="confidence threshold to filter out weak detections")
|
||||
args = parser.parse_args()
|
||||
|
||||
net = cv.dnn.readNetFromCaffe(args.prototxt, args.caffemodel)
|
||||
if args.num_classes == 20:
|
||||
net = cv.dnn.readNetFromCaffe(args.prototxt, args.weights)
|
||||
swapRB = False
|
||||
classNames = { 0: 'background',
|
||||
1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat',
|
||||
5: 'bottle', 6: 'bus', 7: 'car', 8: 'cat', 9: 'chair',
|
||||
10: 'cow', 11: 'diningtable', 12: 'dog', 13: 'horse',
|
||||
14: 'motorbike', 15: 'person', 16: 'pottedplant',
|
||||
17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor' }
|
||||
else:
|
||||
assert(args.num_classes == 90)
|
||||
net = cv.dnn.readNetFromTensorflow(args.weights, args.prototxt)
|
||||
swapRB = True
|
||||
classNames = { 0: 'background',
|
||||
1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus',
|
||||
7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant',
|
||||
13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat',
|
||||
18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear',
|
||||
24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag',
|
||||
32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard',
|
||||
37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove',
|
||||
41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle',
|
||||
46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon',
|
||||
51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',
|
||||
56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut',
|
||||
61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed',
|
||||
67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
|
||||
75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven',
|
||||
80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock',
|
||||
86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush' }
|
||||
|
||||
if len(args.video):
|
||||
if args.video:
|
||||
cap = cv.VideoCapture(args.video)
|
||||
else:
|
||||
cap = cv.VideoCapture(0)
|
||||
@ -41,7 +83,7 @@ if __name__ == "__main__":
|
||||
while True:
|
||||
# Capture frame-by-frame
|
||||
ret, frame = cap.read()
|
||||
blob = cv.dnn.blobFromImage(frame, inScaleFactor, (inWidth, inHeight), meanVal, False)
|
||||
blob = cv.dnn.blobFromImage(frame, inScaleFactor, (inWidth, inHeight), (meanVal, meanVal, meanVal), swapRB)
|
||||
net.setInput(blob)
|
||||
detections = net.forward()
|
||||
|
||||
@ -74,14 +116,16 @@ if __name__ == "__main__":
|
||||
|
||||
cv.rectangle(frame, (xLeftBottom, yLeftBottom), (xRightTop, yRightTop),
|
||||
(0, 255, 0))
|
||||
label = classNames[class_id] + ": " + str(confidence)
|
||||
labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
||||
if class_id in classNames:
|
||||
label = classNames[class_id] + ": " + str(confidence)
|
||||
labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
||||
|
||||
cv.rectangle(frame, (xLeftBottom, yLeftBottom - labelSize[1]),
|
||||
(xLeftBottom + labelSize[0], yLeftBottom + baseLine),
|
||||
(255, 255, 255), cv.FILLED)
|
||||
cv.putText(frame, label, (xLeftBottom, yLeftBottom),
|
||||
cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
|
||||
yLeftBottom = max(yLeftBottom, labelSize[1])
|
||||
cv.rectangle(frame, (xLeftBottom, yLeftBottom - labelSize[1]),
|
||||
(xLeftBottom + labelSize[0], yLeftBottom + baseLine),
|
||||
(255, 255, 255), cv.FILLED)
|
||||
cv.putText(frame, label, (xLeftBottom, yLeftBottom),
|
||||
cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
|
||||
|
||||
cv.imshow("detections", frame)
|
||||
if cv.waitKey(1) >= 0:
|
||||
|
62
samples/dnn/shrink_tf_graph_weights.py
Normal file
62
samples/dnn/shrink_tf_graph_weights.py
Normal file
@ -0,0 +1,62 @@
|
||||
# This file is part of OpenCV project.
|
||||
# It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
# of this distribution and at http://opencv.org/license.html.
|
||||
#
|
||||
# Copyright (C) 2017, Intel Corporation, all rights reserved.
|
||||
# Third party copyrights are property of their respective owners.
|
||||
import tensorflow as tf
|
||||
import struct
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
parser = argparse.ArgumentParser(description='Convert weights of a frozen TensorFlow graph to fp16.')
parser.add_argument('--input', required=True, help='Path to frozen graph.')
parser.add_argument('--output', required=True, help='Path to output graph.')
parser.add_argument('--ops', default=['Conv2D', 'MatMul'], nargs='+',
                    help='List of ops which weights are converted.')
args = parser.parse_args()

# Numeric values of the TensorFlow DataType enum stored in node attributes.
DT_FLOAT = 1
DT_HALF = 19

# In frozen graphs, every node that uses weights is connected to its Const node
# through an Identity node. These usually share the Const's name with a '/read' suffix.
# Each such Identity node is replaced with a Cast node (fp16 -> fp32), and the
# Const node it reads from is re-encoded as fp16.

# Load the model. A serialized GraphDef is binary, so it must be read as bytes.
with tf.gfile.FastGFile(args.input, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Set of all inputs of the desired nodes (a set keeps the lookups below O(1)).
inputs = set()
for node in graph_def.node:
    if node.op in args.ops:
        inputs.update(node.input)

weightsNodes = set()
for node in graph_def.node:
    # Among all collected inputs keep only the float32 Identity nodes.
    if node.name in inputs and node.op == 'Identity' and node.attr['T'].type == DT_FLOAT:
        weightsNodes.add(node.input[0])

        # Replace Identity with Cast.
        node.op = 'Cast'
        node.attr['DstT'].type = DT_FLOAT
        node.attr['SrcT'].type = DT_HALF
        del node.attr['T']
        # '_class' is not guaranteed to be present; only drop it when it exists.
        if '_class' in node.attr:
            del node.attr['_class']

# Convert weights to halfs.
for node in graph_def.node:
    if node.name in weightsNodes:
        node.attr['dtype'].type = DT_HALF
        node.attr['value'].tensor.dtype = DT_HALF

        # Reinterpret the raw bytes as float32, narrow to float16 and store the
        # resulting bytes back (native byte order, as in the serialized tensor).
        floats = np.frombuffer(node.attr['value'].tensor.tensor_content, dtype=np.float32)
        node.attr['value'].tensor.tensor_content = floats.astype(np.float16).tobytes()

tf.train.write_graph(graph_def, "", args.output, as_text=False)
|
Loading…
Reference in New Issue
Block a user