Mirror of https://github.com/opencv/opencv.git
Gemm: support transA and transB, and allow the first input to be constant.
parent 25ac77e010
commit 0d56524b72
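For reference, ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op(X) is X transposed when the corresponding transA/transB attribute is 1; this commit teaches the DNN InnerProduct path to honor those attributes and to accept a constant first input. Below is a minimal standalone sketch of the op(A) * op(B) part only (alpha, beta and the bias C are omitted); it is an illustration of the semantics, not the OpenCV implementation:

#include <cstdio>
#include <vector>

// Naive reference for Y = op(A) * op(B); op(X) = X^T when the trans flag is set.
// A is stored row-major as rowsA x colsA, B as rowsB x colsB.
static std::vector<float> naiveGemm(const std::vector<float>& A, int rowsA, int colsA, bool transA,
                                    const std::vector<float>& B, int rowsB, int colsB, bool transB)
{
    const int M = transA ? colsA : rowsA;   // rows of op(A)
    const int K = transA ? rowsA : colsA;   // shared inner dimension
    const int N = transB ? rowsB : colsB;   // cols of op(B)
    std::vector<float> Y(M * N, 0.f);
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j)
            for (int k = 0; k < K; ++k)
            {
                const float a = transA ? A[k * colsA + i] : A[i * colsA + k];
                const float b = transB ? B[j * colsB + k] : B[k * colsB + j];
                Y[i * N + j] += a * b;
            }
    return Y;
}

int main()
{
    // A is given transposed as 3x2, so op(A) is 2x3; B is 3x2; the result is 2x2.
    std::vector<float> A = {1, 4, 2, 5, 3, 6};      // 3x2, i.e. [[1,4],[2,5],[3,6]]
    std::vector<float> B = {7, 8, 9, 10, 11, 12};   // 3x2
    std::vector<float> Y = naiveGemm(A, 3, 2, /*transA=*/true, B, 3, 2, /*transB=*/false);
    std::printf("%g %g\n%g %g\n", Y[0], Y[1], Y[2], Y[3]);   // 58 64 / 139 154
    return 0;
}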
modules/dnn/src/layers/fully_connected_layer.cpp

@@ -80,6 +80,9 @@ public:
     FullyConnectedLayerImpl(const LayerParams& params)
     {
         setParamsFrom(params);
+        transA = params.get<bool>("transA", false);
+        transB = params.get<bool>("transB", false);
+
         bias = params.get<bool>("bias_term", true);
         axis = params.get<int>("axis", 1);
         if (!blobs.empty())
@@ -116,30 +119,48 @@ public:
                          std::vector<MatShape> &) const CV_OVERRIDE
     {
         int numOutput, cAxis;
+
+        std::vector<MatShape> inputsTmp;
+        inputsTmp.assign(inputs.begin(), inputs.end());
+
         if (blobs.empty())
         {
-            CV_CheckEQ(inputs.size(), (size_t)2, "");
-            numOutput = inputs[1].back();
-            cAxis = inputs[0].size() - 1;
-            int dims = inputs[0].size();
-            CV_CheckEQ(inputs[1].size(), (size_t)dims, "");
+            CV_CheckEQ(inputsTmp.size(), (size_t)2, "");
+
+            if (transA)
+            {
+                CV_CheckEQ(inputsTmp[0].size(), (size_t)2, "");
+                std::swap(inputsTmp[0][0], inputsTmp[0][1]);
+            }
+
+            if (transB)
+            {
+                CV_CheckEQ(inputsTmp[1].size(), (size_t)2, "");
+                std::swap(inputsTmp[1][0], inputsTmp[1][1]);
+            }
+
+            numOutput = inputsTmp[1].back();
+            cAxis = inputsTmp[0].size() - 1;
+            int dims = inputsTmp[0].size();
+            CV_CheckEQ(inputsTmp[1].size(), (size_t)dims, "");
             CV_CheckGE(dims, 2, "");
             for (int i = 0; i < dims - 2; i++)
-                CV_CheckEQ(inputs[0][i], inputs[1][i], "");
-            CV_CheckEQ(inputs[0].back(), inputs[1][dims - 2], "");
+                CV_CheckEQ(inputsTmp[0][i], inputsTmp[1][i], "");
+            CV_CheckEQ(inputsTmp[0].back(), inputsTmp[1][dims - 2], "");
         }
         else
         {
-            CV_CheckEQ(inputs.size(), (size_t)1, "");
+            CV_Assert(!transA && !transB);
+            CV_CheckEQ(inputsTmp.size(), (size_t)1, "");
             CV_CheckEQ(blobs[0].dims, 2, "");
             numOutput = blobs[0].size[0];
             CV_Assert(!bias || (size_t)numOutput == blobs[1].total());
-            cAxis = normalize_axis(axis, inputs[0]);
+            cAxis = normalize_axis(axis, inputsTmp[0]);
         }

         MatShape outShape(cAxis + 1);
         for (int i = 0; i < cAxis; ++i)
-            outShape[i] = inputs[0][i];
+            outShape[i] = inputsTmp[0][i];
         outShape.back() = numOutput;

         outputs.resize(1, outShape);
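The shape inference above boils down to a simple 2-D rule: swap an input's two dimensions when its trans flag is set, check that the inner dimensions agree, and build the output from the remaining outer dimensions. A standalone sketch of that rule (illustrative only, the helper name is made up):

#include <array>
#include <cassert>
#include <cstdio>
#include <utility>

// Output shape of a 2-D Gemm: swap each operand's dims when its trans flag is
// set, require the inner dims to match, and take (rows of A') x (cols of B').
static std::array<int, 2> gemmOutShape(std::array<int, 2> a, std::array<int, 2> b,
                                       bool transA, bool transB)
{
    if (transA) std::swap(a[0], a[1]);
    if (transB) std::swap(b[0], b[1]);
    assert(a[1] == b[0]);
    return {a[0], b[1]};
}

int main()
{
    // A: 4x3 with transA -> 3x4; B: 5x4 with transB -> 4x5; output: 3x5.
    const std::array<int, 2> out = gemmOutShape({4, 3}, {5, 4}, true, true);
    std::printf("%d x %d\n", out[0], out[1]);
    return 0;
}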
@@ -148,15 +169,15 @@ public:

     virtual bool supportBackend(int backendId) CV_OVERRIDE
     {
+        bool tranAorB = transA || transB;
 #ifdef HAVE_INF_ENGINE
         if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
-            return axis == 1;
+            return axis == 1 && !tranAorB;
 #endif

         return backendId == DNN_BACKEND_OPENCV ||
-               backendId == DNN_BACKEND_CUDA ||
-               (backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1) ||
-               (backendId == DNN_BACKEND_WEBNN && axis == 1);
+               (backendId == DNN_BACKEND_CUDA && !tranAorB) ||
+               (backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1 && !tranAorB) ||
+               (backendId == DNN_BACKEND_WEBNN && axis == 1 && !tranAorB);
     }

     virtual bool setActivation(const Ptr<ActivationLayer>& layer) CV_OVERRIDE
@@ -497,6 +518,7 @@ public:

         if (!blobs.empty())
         {
+            CV_Assert(!transA && !transB);
             int axisCan = normalize_axis(axis, input[0].dims);
             int outerSize = input[0].total(0, axisCan);

@@ -511,15 +533,30 @@ public:
         }
         else
         {
-            float* inpData = input[0].ptr<float>();
-            float* weightData = input[1].ptr<float>();
+            Mat input0 = input[0];
+            Mat input1 = input[1];
+
+            if (transA)
+            {
+                CV_Assert(input0.dims == 2);
+                input0 = input0.t();
+            }
+
+            if (transB)
+            {
+                CV_Assert(input1.dims == 2);
+                input1 = input1.t();
+            }
+
+            float* inpData = input0.ptr<float>();
+            float* weightData = input1.ptr<float>();
             float* outData = output[0].ptr<float>();

             int dims = output[0].dims;
             int numSlice = output[0].total() / output[0].total(dims - 2);
-            int m = input[0].size[dims - 2];
-            int n = input[0].size[dims - 1];
-            int k = input[1].size[dims - 1];
+            int m = input0.size[dims - 2];
+            int n = input0.size[dims - 1];
+            int k = input1.size[dims - 1];
             for (int i = 0; i < numSlice; i++)
             {
                 Mat inpSlice(m, n, CV_32F, inpData);
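A note on the pattern used above: cv::Mat::t() returns a lazy transpose expression, and assigning it back to a cv::Mat materializes the transposed copy, so the later ptr<float>() calls walk contiguous transposed data. A minimal sketch of that behavior:

#include <opencv2/core.hpp>
#include <cstdio>

int main()
{
    cv::Mat a = (cv::Mat_<float>(2, 3) << 1, 2, 3,
                                          4, 5, 6);
    cv::Mat at = a.t();   // the MatExpr is evaluated into a real 3x2 Mat here
    std::printf("%d x %d, row 0: %.0f %.0f\n",
                at.rows, at.cols, at.at<float>(0, 0), at.at<float>(0, 1));
    return 0;
}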
@@ -716,6 +753,7 @@ public:

     bool bias;
     Mat weightsMat, biasMat;
+    bool transA, transB;
     Ptr<ActivationLayer> activ;
 };

modules/dnn/src/onnx/onnx_importer.cpp

@@ -2000,18 +2000,9 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
 {
     CV_Assert(node_proto.input_size() >= 2);
     layerParams.type = "InnerProduct";
-    Mat weights = getBlob(node_proto, 1);
+    int transA = layerParams.get<int>("transA", 0);
+    layerParams.set("transA", transA == 1);

-    if (!layerParams.get<int>("transB", 0))
-    {
-        transpose(weights, weights);
-    }
-    layerParams.blobs.push_back(weights);
-
-    if (node_proto.input_size() == 3) {
-        Mat bias = getBlob(node_proto, 2);
-        layerParams.blobs.push_back(bias);
-    }
     if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
     {
         Mat inputBuf = getBlob(node_proto, 0);
@@ -2026,7 +2017,43 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
         addLayer(constParams, proto);
     }

-    layerParams.set("num_output", layerParams.blobs[0].size[0]);
+    int transB = layerParams.get<int>("transB", 0);
+    if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
+    {
+        Mat weights = getBlob(node_proto, 1);
+
+        if (transA == 0) // optimized branch; for now, we can only optimize the Gemm when transA == 0.
+        {
+            if (transB == 0)
+            {
+                transpose(weights, weights);
+            }
+            layerParams.set("transB", false);
+            layerParams.blobs.push_back(weights);
+            layerParams.set("num_output", layerParams.blobs[0].size[0]);
+        }
+        else // non-optimized branch. TODO: optimize when transA == 1.
+        {
+            LayerParams constParams;
+            constParams.name = node_proto.input(1);
+            constParams.type = "Const";
+            constParams.blobs.push_back(weights);
+
+            opencv_onnx::NodeProto proto;
+            proto.add_output(constParams.name);
+            addLayer(constParams, proto);
+            layerParams.set("transB", transB == 1);
+        }
+    }
+    else
+        layerParams.set("transB", transB == 1);
+
+    if (node_proto.input_size() == 3)
+    {
+        Mat bias = getBlob(node_proto, 2);
+        layerParams.blobs.push_back(bias);
+    }
+
     layerParams.set("bias_term", node_proto.input_size() == 3);
     addLayer(layerParams, node_proto);
 }
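With the importer change above, a Gemm node carrying transA/transB attributes, possibly with a constant first input, loads through the usual API. A hypothetical usage sketch (the model file name is a placeholder, not one of the test models):

#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    // "gemm_transAB.onnx" is a placeholder name for any ONNX graph with a Gemm node.
    cv::dnn::Net net = cv::dnn::readNetFromONNX("gemm_transAB.onnx");

    cv::Mat input(3, 4, CV_32F, cv::Scalar(1.0f));   // shape must match the model's input
    net.setInput(input);

    cv::Mat out = net.forward();
    std::cout << "output: " << out.rows << " x " << out.cols << std::endl;
    return 0;
}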
modules/dnn/test/test_onnx_importer.cpp

@@ -1829,6 +1829,7 @@ TEST_P(Test_ONNX_layers, Gemm)
 {
     testONNXModels("gemm_no_transB");
     testONNXModels("gemm_transB_0");
+    testONNXModels("gemm_first_const");
 }

 TEST_P(Test_ONNX_layers, Quantized_Convolution)