Merge pull request #20506 from JulieBar:lstm_activations

* Support activations(Sigmoid, Tanh) for LSTM

* fix warning
Julia Bareeva 2021-08-13 15:41:00 +03:00 committed by GitHub
parent 9d3826c676
commit cfb36443fb
3 changed files with 79 additions and 13 deletions

modules/dnn/src/layers/recurrent_layers.cpp

@ -80,12 +80,31 @@ static void sigmoid(const Mat &src, Mat &dst)
cv::pow(1 + dst, -1, dst);
}
+typedef void (*ActivationFunction)(const Mat &src, Mat &dst);
+static ActivationFunction get_activation_function(const String& activation) {
+    // the most used activations for PyTorch and TF: Tanh, Sigmoid
+    // if you need to support more activations, use a std::map instead
+    if (activation == "Tanh")
+    {
+        return tanh;
+    }
+    else if (activation == "Sigmoid")
+    {
+        return sigmoid;
+    }
+    else
+    {
+        CV_Error(Error::StsNotImplemented,
+                 cv::format("Activation function [%s] for layer LSTM is not supported", activation.c_str()));
+    }
+}
class LSTMLayerImpl CV_FINAL : public LSTMLayer
{
int numTimeStamps, numSamples;
bool allocated;
MatShape outTailShape; //shape of single output sample
MatShape outTsShape; //shape of N output samples
bool useTimestampDim;
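
The dispatcher returns a plain function pointer, so the per-timestep loop below pays no string comparison once the layer is constructed. A minimal usage sketch (hypothetical driver code, not part of the patch; it assumes the static tanh/sigmoid helpers and get_activation_function defined above):

```cpp
#include <opencv2/core.hpp>

// Sketch: resolve an activation by its ONNX name and apply it element-wise.
static void apply_named_activation(const cv::String& name, const cv::Mat& src, cv::Mat& dst)
{
    ActivationFunction act = get_activation_function(name);  // raises CV_Error for unknown names
    act(src, dst);  // e.g. "Sigmoid" computes dst = 1 / (1 + exp(-src))
}
```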
@ -95,6 +114,10 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
bool reverse; // If true, go in negative direction along the time axis
bool bidirectional; // If true, produces both forward and reversed directions along time axis
+    ActivationFunction f_activation;
+    ActivationFunction g_activation;
+    ActivationFunction h_activation;
public:
LSTMLayerImpl(const LayerParams& params)
@ -145,6 +168,20 @@ public:
reverse = params.get<bool>("reverse", false);
CV_Assert(!reverse || !bidirectional);
+        // read activations
+        DictValue activations = params.get<DictValue>("activations", "");
+        if (activations.size() == 1) // size() == 1 means the default "" was returned: activations were not specified
+        {
+            f_activation = sigmoid;
+            g_activation = tanh;
+            h_activation = tanh;
+        } else {
+            CV_Assert(activations.size() == 3);
+            f_activation = get_activation_function(activations.getStringValue(0));
+            g_activation = get_activation_function(activations.getStringValue(1));
+            h_activation = get_activation_function(activations.getStringValue(2));
+        }
allocated = false;
outTailShape.clear();
}
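
On the construction side the contract is: omit "activations" entirely to get the defaults (Sigmoid for the gates, Tanh for cell and output), or pass exactly three names in the ONNX order [f, g, h]. A hedged sketch of the explicit case (all values illustrative):

```cpp
#include <opencv2/dnn.hpp>
#include <string>
#include <vector>

// Sketch: LayerParams for an LSTM with explicit activations.
cv::dnn::LayerParams makeLstmParams()
{
    std::vector<std::string> acts = {"Sigmoid", "Tanh", "Tanh"};  // f, g, h
    cv::dnn::LayerParams lp;
    lp.type = "LSTM";
    lp.set("hidden_size", 64);  // illustrative size
    lp.set("activations", cv::dnn::DictValue::arrayString(acts.begin(), (int)acts.size()));
    return lp;
}
```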
@ -339,15 +376,15 @@ public:
Mat gatesIF = gates.colRange(0, 2*numOut);
gemm(cInternal, blobs[5], 1, gateI, 1, gateI);
gemm(cInternal, blobs[6], 1, gateF, 1, gateF);
-sigmoid(gatesIF, gatesIF);
+f_activation(gatesIF, gatesIF);
}
else
{
Mat gatesIFO = gates.colRange(0, 3*numOut);
-sigmoid(gatesIFO, gatesIFO);
+f_activation(gatesIFO, gatesIFO);
}
-tanh(gateG, gateG);
+g_activation(gateG, gateG);
//compute c_t
multiply(gateF, cInternal, gateF); // f_t (*) c_{t-1}
@ -362,11 +399,11 @@ public:
if (usePeephole)
{
gemm(cInternal, blobs[7], 1, gateO, 1, gateO);
-sigmoid(gateO, gateO);
+f_activation(gateO, gateO);
}
//compute h_t
-tanh(cInternal, hInternal);
+h_activation(cInternal, hInternal);
multiply(gateO, hInternal, hInternal);
//save results in output blobs
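
For reference, the cell update these two hunks implement, with the configurable activations in the positions the ONNX LSTM operator assigns them (peephole terms omitted; \odot is element-wise multiplication):

```latex
\begin{aligned}
i_t &= f(x_t W_i^\top + h_{t-1} R_i^\top + b_i), &
f_t &= f(x_t W_f^\top + h_{t-1} R_f^\top + b_f), &
o_t &= f(x_t W_o^\top + h_{t-1} R_o^\top + b_o), \\
\tilde{c}_t &= g(x_t W_c^\top + h_{t-1} R_c^\top + b_c), &
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, &
h_t &= o_t \odot h(c_t).
\end{aligned}
```

With the defaults f = Sigmoid and g = h = Tanh this reduces exactly to the previous hard-coded behaviour.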

modules/dnn/src/onnx/onnx_importer.cpp

@ -244,6 +244,10 @@ static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protob
return DictValue::arrayInt(&dst[0], src.size());
}
+static DictValue parseStr(const ::google::protobuf::RepeatedPtrField< ::std::string>& src) {
+    return DictValue::arrayString(src.begin(), static_cast<int>(src.size()));
+}
LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
{
LayerParams lp;
@ -301,6 +305,10 @@ LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_prot
CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
lp.set("dilation", parse(attribute_proto.ints()));
}
+        else if (attribute_name == "activations" && node_proto.op_type() == "LSTM")
+        {
+            lp.set(attribute_name, parseStr(attribute_proto.strings()));
+        }
else if (attribute_proto.has_i())
{
::google::protobuf::int64 src = attribute_proto.i();
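
End to end, an ONNX LSTM node carrying activations = ["Sigmoid", "Tanh", "Tanh"] now reaches the layer as a three-element string array. A sketch of what the layer-side code reads back (variable name illustrative):

```cpp
// Sketch: consuming the parsed attribute on the layer side.
cv::dnn::DictValue acts = layerParams.get<cv::dnn::DictValue>("activations");
CV_Assert(acts.size() == 3);
cv::String f = acts.getStringValue(0);  // gates i, f, o  -> "Sigmoid"
cv::String g = acts.getStringValue(1);  // cell candidate -> "Tanh"
cv::String h = acts.getStringValue(2);  // output h_t     -> "Tanh"
```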
@ -997,18 +1005,32 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr
lstmParams.name += "/lstm";
// https://pytorch.org/docs/stable/nn.html#lstm
-CV_Assert(node_proto.input_size() == 7);
+CV_Assert(node_proto.input_size() >= 7);
Mat Wx = getBlob(node_proto, 1);
Mat Wh = getBlob(node_proto, 2);
Mat b = getBlob(node_proto, 3);
-Mat h0 = getBlob(node_proto, 5);
-Mat c0 = getBlob(node_proto, 6);
-b = b.reshape(1, b.size[0]);
const int numHidden = lstmParams.get<int>("hidden_size");
const int numDirs = Wx.size[0]; // Is 1 for forward only and 2 for bidirectional LSTM.
const int numFeatures = Wx.size[2];
+Mat h0, c0;
+if (!node_proto.input(5).empty()) {
+    h0 = getBlob(node_proto, 5);
+    h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
+} else {
+    // the initial_h input can be empty for some producers (e.g. keras2onnx); fill it with zeros
+    h0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1);
+}
+if (!node_proto.input(6).empty()) {
+    c0 = getBlob(node_proto, 6);
+    c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
+} else {
+    // the initial_c input can be empty for some producers (e.g. keras2onnx); fill it with zeros
+    c0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1);
+}
+b = b.reshape(1, b.size[0]);
Mat bx = b.colRange(0, b.cols / 2);
Mat bh = b.colRange(b.cols / 2, b.cols);
b = bx + bh;
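
Two details in this hunk are easy to miss. The initial_h/initial_c inputs may legitimately be empty strings (some producers such as keras2onnx emit them that way), so zero tensors are substituted. And ONNX stores the bias B as the concatenation of the input-projection half and the recurrent half; since every gate adds both, the importer folds them into a single vector:

```latex
B = \big[\, b_x \,\Vert\, b_h \,\big] \in \mathbb{R}^{8H}, \qquad
x W^\top + h R^\top + b_x + b_h \;=\; x W^\top + h R^\top + (b_x + b_h),
```

which is exactly the b = bx + bh line above (H is hidden_size, so each half holds 4H entries, one block per gate).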
@ -1036,8 +1058,7 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr
}
Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
-h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
-c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
lstmParams.blobs.resize(5);
lstmParams.blobs[0] = Wh;
@ -1045,6 +1066,9 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr
lstmParams.blobs[2] = b;
lstmParams.blobs[3] = h0;
lstmParams.blobs[4] = c0;
+// read the direction attribute
+lstmParams.set("reverse", lstmParams.get<String>("direction", "") == "reverse");
+lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
node_proto.set_output(0, lstmParams.name); // set different name so output shapes will be registered on that name
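
The ONNX direction attribute thus maps onto two boolean layer params (the LSTM layer asserts they are never both set):

direction        reverse  bidirectional
forward/absent   false    false
reverse          true     false
bidirectional    false    true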

modules/dnn/test/test_onnx_importer.cpp

@ -665,6 +665,11 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax)
testONNXModels("split_max");
}
+TEST_P(Test_ONNX_layers, LSTM_Activations)
+{
+    testONNXModels("lstm_cntk_tanh", pb, 0, 0, false, false);
+}
TEST_P(Test_ONNX_layers, LSTM)
{
testONNXModels("lstm", npy, 0, 0, false, false);