mirror of
https://github.com/opencv/opencv.git
synced 2025-06-08 01:53:19 +08:00
Merge pull request #20442 from JulieBar:gru_layer
* Add initialization and inference for GRU layer * fix issues found on review
This commit is contained in:
parent
ba539eb9aa
commit
e1cafa3834
@ -165,6 +165,40 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
int outputNameToIndex(const String& outputName) CV_OVERRIDE;
|
||||
};
|
||||
|
||||
/** @brief GRU recurrent one-layer
|
||||
*
|
||||
* Accepts input sequence and computes the hidden state at every time step for each element in the batch.
|
||||
*
|
||||
* - input[0] containing the features of the input sequence.
|
||||
* input[0] should have shape [`T`, `N`, `data_dims`] where `T` is sequence length, `N` is batch size, `data_dims` is input size
|
||||
* - output would have shape [`T`, `N`, `D` * `hidden_size`] where `D = 2` if layer is bidirectional otherwise `D = 1`
|
||||
*
|
||||
* Depends on the following attributes:
|
||||
* - hidden_size - Number of neurons in the hidden layer
|
||||
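* - bidirectional - if true, the sequence is processed in both forward and reverse time order (the ONNX importer derives it from the `direction` attribute); defaults to false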
|
||||
*
|
||||
* The hidden state @f$ h_t @f$ is computed by the following formulas:
|
||||
*
|
||||
@f{eqnarray*}{
|
||||
r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
|
||||
z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
|
||||
n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\
|
||||
h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)} \\
|
||||
@f}
|
||||
* Where @f$x_t@f$ is current input, @f$h_{(t-1)}@f$ is previous or initial hidden state.
|
||||
*
|
||||
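* @f$W_{i?}@f$ and @f$W_{h?}@f$ are learned weight matrices, @f$b_{i?}@f$ and @f$b_{h?}@f$ are learned bias vectors:
|
||||
* @f$W_{i?} \in R^{N_h \times N_x}@f$, @f$W_{h?} \in R^{N_h \times N_h}@f$, @f$b_{i?}, b_{h?} \in R^{N_h}@f$, where @f$N_x@f$ is the input size and @f$N_h@f$ is the hidden size.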
|
||||
*
|
||||
* @f$\odot@f$ is per-element multiply operation.
|
||||
*/
|
||||
class CV_EXPORTS GRULayer : public Layer
|
||||
{
|
||||
public:
|
||||
/** Creates instance of GRU layer
 *
 * @param params layer parameters; the implementation in this change reads the boolean
 *               `bidirectional` entry (default false) and, when blobs are supplied,
 *               expects them in the order [Wh, Wx, bias, h0].
 */
|
||||
static Ptr<GRULayer> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
/** @brief Classical recurrent layer
|
||||
|
||||
Accepts two inputs @f$x_t@f$ and @f$h_{t-1}@f$ and compute two outputs @f$o_t@f$ and @f$h_t@f$.
|
||||
|
@ -139,6 +139,7 @@ void initializeLayerFactory()
|
||||
CV_DNN_REGISTER_LAYER_CLASS(FlowWarp, FlowWarpLayer);
|
||||
|
||||
CV_DNN_REGISTER_LAYER_CLASS(LSTM, LSTMLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(GRU, GRULayer);
|
||||
}
|
||||
|
||||
CV__DNN_INLINE_NS_END
|
||||
|
@ -563,5 +563,214 @@ CV_EXPORTS_W Ptr<RNNLayer> RNNLayer::create(const LayerParams& params)
|
||||
return Ptr<RNNLayer>(new RNNLayerImpl(params));
|
||||
}
|
||||
|
||||
/**
 * GRU layer implementation (single layer, optionally bidirectional).
 *
 * Learned parameters are taken from `blobs`:
 *   - blobs[0] (Wh):   hidden-to-hidden weights, numDirs * 3 * numOut rows, numOut columns
 *   - blobs[1] (Wx):   input-to-hidden weights, same number of rows as Wh
 *   - blobs[2] (bias): flattened to one row of 2 * Wh.rows values; per direction the first
 *                      half is added to the input projection, the second half to the
 *                      hidden projection
 *   - blobs[3] (h0):   initial hidden state, numOut columns
 *
 * Within one direction the weight rows are consumed as three consecutive groups of numOut
 * rows: update gate (z), reset gate (r), candidate state (n).
 */
class GRULayerImpl CV_FINAL : public GRULayer
{
    int numTimeStamps, numSamples;
    bool allocated;

    MatShape outTailShape;  // shape of single output sample
    MatShape outTsShape;    // shape of N output samples
    bool bidirectional;     // If true, produces both forward and reversed directions along time axis

public:

    /** Reads the `bidirectional` flag and validates the weight blobs when they are present. */
    GRULayerImpl(const LayerParams& params) : numTimeStamps(0), numSamples(0)
    {
        setParamsFrom(params);

        bidirectional = params.get<bool>("bidirectional", false);
        if (!blobs.empty())
        {
            // Wh, Wx, bias and h0 are all dereferenced below and in forward(),
            // so four blobs are required (the previous check of >= 3 let blobs[3] go out of range).
            CV_Assert(blobs.size() >= 4);

            blobs[2] = blobs[2].reshape(1, 1);

            const Mat& Wh = blobs[0];
            const Mat& Wx = blobs[1];
            const Mat& bias = blobs[2];
            const Mat& hInternal = blobs[3];
            CV_CheckEQ(Wh.dims, 2, "");
            CV_CheckEQ(Wx.dims, 2, "");
            CV_CheckEQ(Wh.rows, Wx.rows, "");
            CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional)) * 3 * Wh.cols, "");
            CV_CheckEQ(Wh.rows * 2, static_cast<int>(bias.total()), "");
            CV_CheckEQ(hInternal.cols, Wh.cols, "");
            CV_CheckTypeEQ(Wh.type(), Wx.type(), "");
            CV_CheckTypeEQ(Wx.type(), bias.type(), "");
        }

        allocated = false;
        outTailShape.clear();
    }

    /**
     * Computes the output shape [inp[0], inp[1], tail..., last * numDirs] and the shapes of
     * the six scratch buffers used by forward().
     *
     * @return false (as in the original implementation).
     */
    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int requiredOutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Assert(inputs.size() == 1);
        // The constructor accepts an empty blob list; fail loudly instead of indexing past the end.
        CV_Assert(blobs.size() >= 4);
        const MatShape& inp0 = inputs[0];

        const Mat &Wh = blobs[0], &Wx = blobs[1];
        int _numOut = Wh.size[1];
        int _numInp = Wx.size[1];
        MatShape outTailShape_(outTailShape), outResShape;

        if (!outTailShape_.empty())
            CV_Assert(total(outTailShape_) == _numOut);
        else
            outTailShape_.assign(1, _numOut);

        int _numSamples;
        CV_Assert(inp0.size() >= 2 && total(inp0, 2) == _numInp);
        _numSamples = inp0[1];
        outResShape.push_back(inp0[0]);

        outResShape.push_back(_numSamples);
        outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
        outResShape.back() *= (1 + static_cast<int>(bidirectional));

        outputs.assign(1, outResShape);

        internals.assign(1, shape(_numSamples, _numOut));     // [0] hInternal
        internals.push_back(shape(_numSamples, 1));           // [1] dummyOnes
        internals.push_back(shape(_numSamples, 2 * _numOut)); // [2] gates (z and r)
        internals.push_back(shape(_numSamples, 2 * _numOut)); // [3] reserved; forward() no longer reads it
        internals.push_back(shape(_numSamples, 1 * _numOut)); // [4] n_t
        internals.push_back(shape(_numSamples, _numOut));     // [5] ones

        return false;
    }

    /** Caches sequence length, batch size and the per-time-step output shape from the input. */
    void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE
    {
        std::vector<Mat> input;
        inputs_arr.getMatVector(input);

        CV_Assert(input.size() == 1);
        CV_Assert(blobs.size() >= 4);
        const Mat& inp0 = input[0];

        Mat &Wh = blobs[0], &Wx = blobs[1];
        int numOut = Wh.size[1];
        int numInp = Wx.size[1];

        if (!outTailShape.empty())
            CV_Assert(total(outTailShape) == numOut);
        else
            outTailShape.assign(1, numOut);

        CV_Assert(inp0.dims >= 2 && static_cast<int>(inp0.total(2)) == numInp);
        numTimeStamps = inp0.size[0];
        numSamples = inp0.size[1];

        outTsShape.clear();
        outTsShape.push_back(numSamples);
        outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());
        outTsShape.back() *= (1 + static_cast<int>(bidirectional));

        allocated = true;
    }

    /**
     * Runs the recurrence over all time steps, once per direction, and writes the hidden
     * state of every step into the output. For the second direction the time steps are
     * visited in reverse order and the result goes to the second half of the output columns.
     */
    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        if (inputs_arr.depth() == CV_16S)
        {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> input, output, internals;
        inputs_arr.getMatVector(input);
        outputs_arr.getMatVector(output);
        internals_arr.getMatVector(internals);

        const int numDirs = 1 + static_cast<int>(bidirectional);

        // Constant helper matrices: filled once, never written afterwards.
        Mat dummyOnes = internals[1], ones = internals[5];
        dummyOnes.setTo(1.);
        ones.setTo(1.);

        for (int i = 0; i < numDirs; ++i)
        {
            // Slice out the parameters belonging to direction i.
            const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
            const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
            const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
            const Mat &h_0 = blobs[3].rowRange(i * blobs[3].rows / numDirs, (i + 1) * blobs[3].rows / numDirs);

            const Mat &bx = bias.colRange(0, bias.cols / 2);          // added to x * Wx
            const Mat &bh = bias.colRange(bias.cols / 2, bias.cols);  // added to h * Wh

            Mat hInternal = internals[0], gates = internals[2], n_t = internals[4];
            h_0.copyTo(hInternal);

            int numOut = Wh.size[1];
            const Mat& wx_rz = Wx.rowRange(0, 2 * numOut);
            const Mat& wh_rz = Wh.rowRange(0, 2 * numOut);
            // Both gate biases are only ever used as a sum, so add them once per direction.
            const Mat b_rz = bx.colRange(0, 2 * numOut) + bh.colRange(0, 2 * numOut);
            const Mat& wx_n = Wx.rowRange(2 * numOut, 3 * numOut);
            const Mat& wh_n = Wh.rowRange(2 * numOut, 3 * numOut);
            const Mat& b_in = bx.colRange(2 * numOut, 3 * numOut);
            const Mat& b_hn = bh.colRange(2 * numOut, 3 * numOut);

            int numSamplesTotal = numTimeStamps * numSamples;
            Mat xTs = input[0].reshape(1, numSamplesTotal);

            Mat hOutTs = output[0].reshape(1, numSamplesTotal);
            hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);

            int tsStart, tsEnd, tsInc;
            if (i == 1) {
                // reverse direction: walk the sequence from the last step to the first
                tsStart = numTimeStamps - 1;
                tsEnd = -1;
                tsInc = -1;
            }
            else {
                tsStart = 0;
                tsEnd = numTimeStamps;
                tsInc = 1;
            }
            for (int ts = tsStart; ts != tsEnd; ts += tsInc)
            {
                Range curRowRange(ts * numSamples, (ts + 1) * numSamples);
                Mat xCurr = xTs.rowRange(curRowRange);

                // calculate r_t = sigmoid(x * Wx_r + h_(t-1) * Wh_r + b_r)
                // calculate z_t = sigmoid(x * Wx_z + h_(t-1) * Wh_z + b_z)
                gemm(xCurr, wx_rz, 1, gates, 0, gates, GEMM_2_T);      // x * Wx_rz
                gemm(hInternal, wh_rz, 1, gates, 1, gates, GEMM_2_T);  // + h_(t-1) * Wh_rz
                gemm(dummyOnes, b_rz, 1, gates, 1, gates);             // + b_rz
                sigmoid(gates, gates);                                 // sigmoid()

                // first half of the gate columns is z, second half is r
                Mat z = gates.colRange(0, gates.cols / 2);
                Mat r = gates.colRange(gates.cols / 2, gates.cols);

                // calculate n_t = tanh(r (*) (h_(t-1) * Wh_n + b_hn) + x * Wx_n + b_in)
                gemm(hInternal, wh_n, 1, n_t, 0, n_t, GEMM_2_T);  // h_(t-1) * Wh_n
                gemm(dummyOnes, b_hn, 1, n_t, 1, n_t);            // + b_hn
                multiply(r, n_t, n_t);                            // r (*) (h_(t-1) * Wh_n + b_hn)

                gemm(xCurr, wx_n, 1, n_t, 1, n_t, GEMM_2_T);      // + x * Wx_n
                gemm(dummyOnes, b_in, 1, n_t, 1, n_t);            // + b_in
                tanh(n_t, n_t);                                   // tanh()

                // compute next h_t = z (*) h_(t-1) + (1 - z) (*) n_t
                multiply(z, hInternal, hInternal);  // z (*) h_{t-1}
                subtract(ones, z, z);               // 1 - z
                multiply(z, n_t, z);                // (1 - z) * n
                add(z, hInternal, hInternal);       // z (*) h_(t-1) + (1 - z) (*) n_t

                // save results in output blobs
                hInternal.copyTo(hOutTs.rowRange(curRowRange));
            }
        }
    }
};
|
||||
|
||||
/** Factory for the GRU layer: builds a GRULayerImpl from @p params. */
Ptr<GRULayer> GRULayer::create(const LayerParams& params)
{
    return Ptr<GRULayer>(new GRULayerImpl(params));
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -145,6 +145,7 @@ private:
|
||||
void parseNeg (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseConstant (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseLSTM (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseGRU (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseImageScaler (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseClip (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseLeakyRelu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
@ -582,6 +583,7 @@ const std::set<String>& ONNXImporter::getSupportedTypes()
|
||||
"Neg",
|
||||
"Constant",
|
||||
"LSTM",
|
||||
"GRU",
|
||||
"ImageScaler",
|
||||
"Clip",
|
||||
"LeakyRelu",
|
||||
@ -1239,6 +1241,46 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
// Imports an ONNX GRU node as two layers: a "<name>/gru" layer that carries the weights,
// followed by a Reshape (registered under the node's original name) that inserts a
// singleton axis at position 1 of the GRU output shape.
void ONNXImporter::parseGRU(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
    opencv_onnx::NodeProto node = node_proto_;  // mutable copy: inputs/outputs are rewired below

    LayerParams gru = layerParams;
    gru.name += "/gru";

    // https://pytorch.org/docs/stable/generated/torch.nn.GRU.html?highlight=gru#
    CV_Assert(node.input_size() == 6);
    Mat inputWeights  = getBlob(node, 1);
    Mat hiddenWeights = getBlob(node, 2);
    Mat biases        = getBlob(node, 3);
    Mat initialState  = getBlob(node, 5);

    // Merge the two leading axes of every tensor into the row dimension (bias keeps its first axis).
    inputWeights  = inputWeights.reshape(1, inputWeights.size[0] * inputWeights.size[1]);
    hiddenWeights = hiddenWeights.reshape(1, hiddenWeights.size[0] * hiddenWeights.size[1]);
    initialState  = initialState.reshape(1, initialState.size[0] * initialState.size[1]);
    biases        = biases.reshape(1, biases.size[0]);

    const bool isBidirectional = gru.get<String>("direction", "") == "bidirectional";

    // Blob order expected by the GRU layer: hidden weights, input weights, bias, initial state.
    gru.blobs.clear();
    gru.blobs.reserve(4);
    gru.blobs.push_back(hiddenWeights);
    gru.blobs.push_back(inputWeights);
    gru.blobs.push_back(biases);
    gru.blobs.push_back(initialState);
    gru.set("bidirectional", isBidirectional);

    // Register the GRU under its own name so that its output shape is stored under that name.
    node.set_output(0, gru.name);
    addLayer(gru, node);

    // Mimic the ONNX output layout by inserting an extra axis of size 1 after the first one.
    MatShape reshapedDims = outShapes[node.output(0)];
    reshapedDims.insert(reshapedDims.begin() + 1, 1);

    layerParams.type = "Reshape";
    layerParams.set("dim", DictValue::arrayInt(&reshapedDims[0], reshapedDims.size()));
    node.set_input(0, gru.name);           // the Reshape consumes the GRU output
    node.set_output(0, layerParams.name);  // and keeps the original node name
    addLayer(layerParams, node);
}
|
||||
|
||||
void ONNXImporter::parseImageScaler(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
|
||||
@ -2358,6 +2400,7 @@ const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
|
||||
dispatch["Neg"] = &ONNXImporter::parseNeg;
|
||||
dispatch["Constant"] = &ONNXImporter::parseConstant;
|
||||
dispatch["LSTM"] = &ONNXImporter::parseLSTM;
|
||||
dispatch["GRU"] = &ONNXImporter::parseGRU;
|
||||
dispatch["ImageScaler"] = &ONNXImporter::parseImageScaler;
|
||||
dispatch["Clip"] = &ONNXImporter::parseClip;
|
||||
dispatch["LeakyRelu"] = &ONNXImporter::parseLeakyRelu;
|
||||
|
@ -596,6 +596,35 @@ TEST(Layer_LSTM_Test_Accuracy_with_, HiddenParams)
|
||||
normAssert(h_t_reference, outputs[0]);
|
||||
}
|
||||
|
||||
// Runs a standalone, unidirectional GRULayer on stored weights and input and compares
// the result with the stored reference output.
TEST(Layer_GRU_Test_Accuracy_with_, Pytorch)
{
    Mat inputWeights  = blobFromNPY(_tf("gru.W.npy"));
    Mat hiddenWeights = blobFromNPY(_tf("gru.R.npy"));
    Mat biases        = blobFromNPY(_tf("gru.B.npy"));
    Mat initialState  = blobFromNPY(_tf("gru.h0.npy"));

    // Merge the two leading axes into rows, mirroring what the ONNX importer does.
    inputWeights  = inputWeights.reshape(1, inputWeights.size[0] * inputWeights.size[1]);
    hiddenWeights = hiddenWeights.reshape(1, hiddenWeights.size[0] * hiddenWeights.size[1]);
    initialState  = initialState.reshape(1, initialState.size[0] * initialState.size[1]);
    biases        = biases.reshape(1, biases.size[0]);

    // Blob order expected by the layer: hidden weights, input weights, bias, initial state.
    LayerParams params;
    params.blobs.push_back(hiddenWeights);
    params.blobs.push_back(inputWeights);
    params.blobs.push_back(biases);
    params.blobs.push_back(initialState);
    params.set("bidirectional", false);
    Ptr<GRULayer> layer = GRULayer::create(params);

    Mat sequence = blobFromNPY(_tf("gru.input.npy"));
    std::vector<Mat> inputs(1, sequence), outputs;
    runLayer(layer, inputs, outputs);

    Mat expected = blobFromNPY(_tf("gru.output.npy"));
    normAssert(expected, outputs[0]);
}
|
||||
|
||||
TEST(Layer_RNN_Test_Accuracy_with_, CaffeRecurrent)
|
||||
{
|
||||
Ptr<RNNLayer> layer = RNNLayer::create(LayerParams());
|
||||
|
@ -720,6 +720,16 @@ TEST_P(Test_ONNX_layers, LSTM_hidden_bidirectional)
|
||||
testONNXModels("hidden_lstm_bi", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
// Runs the "gru" ONNX test model through testONNXModels.
// NOTE(review): the trailing arguments mirror the LSTM tests above; their meaning is
// defined by testONNXModels, which is not visible in this diff - confirm before changing.
TEST_P(Test_ONNX_layers, GRU)
|
||||
{
|
||||
testONNXModels("gru", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
// Runs the "gru_bi" ONNX test model through testONNXModels (presumably the bidirectional
// variant of the GRU model, judging by the test and model names - confirm against the test data).
TEST_P(Test_ONNX_layers, GRU_bidirectional)
|
||||
{
|
||||
testONNXModels("gru_bi", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Pad2d_Unfused)
|
||||
{
|
||||
testONNXModels("ReflectionPad2d");
|
||||
|
Loading…
Reference in New Issue
Block a user