mirror of
https://github.com/opencv/opencv.git
synced 2025-06-13 04:52:53 +08:00
Added shared weights for MatMul
This commit is contained in:
parent
0f968e3b6d
commit
38a49f92ab
@ -1228,8 +1228,18 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
|
|||||||
|
|
||||||
int kernel_blob_index = -1;
|
int kernel_blob_index = -1;
|
||||||
const tensorflow::TensorProto& kernelTensor = getConstBlob(layer, value_id, -1, &kernel_blob_index);
|
const tensorflow::TensorProto& kernelTensor = getConstBlob(layer, value_id, -1, &kernel_blob_index);
|
||||||
blobFromTensor(kernelTensor, layerParams.blobs[0]);
|
const String kernelTensorName = layer.input(kernel_blob_index);
|
||||||
releaseTensor(const_cast<tensorflow::TensorProto*>(&kernelTensor));
|
std::map<String, Mat>::iterator sharedWeightsIt = sharedWeights.find(kernelTensorName);
|
||||||
|
if (sharedWeightsIt == sharedWeights.end())
|
||||||
|
{
|
||||||
|
blobFromTensor(kernelTensor, layerParams.blobs[0]);
|
||||||
|
releaseTensor(const_cast<tensorflow::TensorProto*>(&kernelTensor));
|
||||||
|
sharedWeights[kernelTensorName] = layerParams.blobs[0];
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
layerParams.blobs[0] = sharedWeightsIt->second;
|
||||||
|
}
|
||||||
|
|
||||||
if (kernel_blob_index == 1) { // In this case output is computed by x*W formula - W should be transposed
|
if (kernel_blob_index == 1) { // In this case output is computed by x*W formula - W should be transposed
|
||||||
Mat data = layerParams.blobs[0].t();
|
Mat data = layerParams.blobs[0].t();
|
||||||
|
Loading…
Reference in New Issue
Block a user