Mirror of https://github.com/opencv/opencv.git
Merge pull request #20506 from JulieBar:lstm_activations
* Support activations(Sigmoid, Tanh) for LSTM
* fix warning
commit cfb36443fb (parent 9d3826c676)
modules/dnn/src/layers/recurrent_layers.cpp
@@ -80,12 +80,31 @@ static void sigmoid(const Mat &src, Mat &dst)
     cv::pow(1 + dst, -1, dst);
 }
 
+typedef void (*ActivationFunction)(const Mat &src, Mat &dst);
+static ActivationFunction get_activation_function(const String& activation) {
+    // most used activations for PyTorch and TF : Tanh, Sigmoid
+    // if you need to support more optional activations use std::map instead
+    if (activation == "Tanh")
+    {
+        return tanh;
+    }
+    else if (activation == "Sigmoid")
+    {
+        return sigmoid;
+    }
+    else
+    {
+        CV_Error(Error::StsNotImplemented,
+                 cv::format("Activation function [%s] for layer LSTM is not supported", activation.c_str()));
+    }
+}
+
 class LSTMLayerImpl CV_FINAL : public LSTMLayer
 {
     int numTimeStamps, numSamples;
     bool allocated;
 
     MatShape outTailShape;  //shape of single output sample
     MatShape outTsShape;    //shape of N output samples
 
     bool useTimestampDim;
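The comment in the added helper suggests switching to std::map once more activations are needed. Below is a minimal sketch of that table-based variant, assuming the ActivationFunction typedef and the static sigmoid/tanh helpers from this file are in scope; the relu helper and the function name are hypothetical, not part of this patch:

    #include <map>

    // Hypothetical extra activation, shown only to motivate the table.
    static void relu(const Mat &src, Mat &dst)
    {
        cv::max(src, 0, dst);  // elementwise max(x, 0)
    }

    static ActivationFunction get_activation_function_map(const String& activation)
    {
        static const std::map<String, ActivationFunction> table = {
            {"Tanh", tanh}, {"Sigmoid", sigmoid}, {"Relu", relu}
        };
        std::map<String, ActivationFunction>::const_iterator it = table.find(activation);
        if (it == table.end())
            CV_Error(Error::StsNotImplemented,
                     cv::format("Activation function [%s] for layer LSTM is not supported", activation.c_str()));
        return it->second;
    }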
@@ -95,6 +114,10 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
     bool reverse;        // If true, go in negative direction along the time axis
     bool bidirectional;  // If true, produces both forward and reversed directions along time axis
 
+    ActivationFunction f_activation;
+    ActivationFunction g_activation;
+    ActivationFunction h_activation;
+
 public:
 
     LSTMLayerImpl(const LayerParams& params)
@@ -145,6 +168,20 @@ public:
         reverse = params.get<bool>("reverse", false);
         CV_Assert(!reverse || !bidirectional);
 
+        // read activations
+        DictValue activations = params.get<DictValue>("activations", "");
+        if (activations.size() == 1) // if activations wasn't specified use default
+        {
+            f_activation = sigmoid;
+            g_activation = tanh;
+            h_activation = tanh;
+        } else {
+            CV_Assert(activations.size() == 3);
+            f_activation = get_activation_function(activations.getStringValue(0));
+            g_activation = get_activation_function(activations.getStringValue(1));
+            h_activation = get_activation_function(activations.getStringValue(2));
+        }
+
         allocated = false;
         outTailShape.clear();
     }
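For context, the three entries follow the ONNX attribute order: f is applied to the input/forget/output gates, g to the cell candidate, and h to the cell state before the output gate. A hedged sketch of selecting non-default activations when building the layer by hand (the hidden_size value here is arbitrary):

    const char* names[] = {"Tanh", "Tanh", "Tanh"};  // f, g, h
    LayerParams lp;
    lp.set("hidden_size", 128);
    lp.set("activations", DictValue::arrayString(names, 3));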
@@ -339,15 +376,15 @@ public:
                 Mat gatesIF = gates.colRange(0, 2*numOut);
                 gemm(cInternal, blobs[5], 1, gateI, 1, gateI);
                 gemm(cInternal, blobs[6], 1, gateF, 1, gateF);
-                sigmoid(gatesIF, gatesIF);
+                f_activation(gatesIF, gatesIF);
             }
             else
             {
                 Mat gatesIFO = gates.colRange(0, 3*numOut);
-                sigmoid(gatesIFO, gatesIFO);
+                f_activation(gatesIFO, gatesIFO);
             }
 
-            tanh(gateG, gateG);
+            g_activation(gateG, gateG);
 
             //compute c_t
             multiply(gateF, cInternal, gateF);  // f_t (*) c_{t-1}
@@ -362,11 +399,11 @@ public:
             if (usePeephole)
             {
                 gemm(cInternal, blobs[7], 1, gateO, 1, gateO);
-                sigmoid(gateO, gateO);
+                f_activation(gateO, gateO);
             }
 
             //compute h_t
-            tanh(cInternal, hInternal);
+            h_activation(cInternal, hInternal);
             multiply(gateO, hInternal, hInternal);
 
             //save results in output blobs
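Taken together, the two hunks above implement the standard LSTM cell in ONNX notation, where f, g, h are the three configurable activations (sigmoid, tanh, tanh by default, matching the constructor) and (*) denotes elementwise multiplication, as in the code comments:

    i_t  = f(W_i x_t + R_i h_{t-1} + b_i)
    f_t  = f(W_f x_t + R_f h_{t-1} + b_f)
    o_t  = f(W_o x_t + R_o h_{t-1} + b_o)
    c~_t = g(W_c x_t + R_c h_{t-1} + b_c)
    c_t  = f_t (*) c_{t-1} + i_t (*) c~_t
    h_t  = o_t (*) h(c_t)

With peepholes enabled, the i/f/o pre-activations additionally receive peephole-weighted cell-state terms before f is applied, which is why the usePeephole branches call f_activation after their extra gemm.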
modules/dnn/src/onnx/onnx_importer.cpp
@@ -244,6 +244,10 @@ static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protobuf::int64>& src)
     return DictValue::arrayInt(&dst[0], src.size());
 }
 
+static DictValue parseStr(const ::google::protobuf::RepeatedPtrField< ::std::string>& src) {
+    return DictValue::arrayString(src.begin(), static_cast<int>(src.size()));
+}
+
 LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
 {
     LayerParams lp;
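parseStr mirrors the integer parse() helper above: the protobuf field's begin() yields a string iterator, which DictValue::arrayString copies into an array-valued DictValue. A small usage sketch, with a std::vector<std::string> standing in for the protobuf field:

    std::vector<std::string> values = {"Sigmoid", "Tanh", "Tanh"};
    DictValue dv = DictValue::arrayString(values.begin(), (int)values.size());
    CV_Assert(dv.size() == 3 && dv.getStringValue(0) == "Sigmoid");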
@@ -301,6 +305,10 @@ LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
             CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
             lp.set("dilation", parse(attribute_proto.ints()));
         }
+        else if(attribute_name == "activations" && node_proto.op_type() == "LSTM")
+        {
+            lp.set(attribute_name, parseStr(attribute_proto.strings()));
+        }
         else if (attribute_proto.has_i())
         {
             ::google::protobuf::int64 src = attribute_proto.i();
@@ -997,18 +1005,32 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
     lstmParams.name += "/lstm";
 
     // https://pytorch.org/docs/stable/nn.html#lstm
-    CV_Assert(node_proto.input_size() == 7);
+    CV_Assert(node_proto.input_size() >= 7);
     Mat Wx = getBlob(node_proto, 1);
     Mat Wh = getBlob(node_proto, 2);
     Mat b = getBlob(node_proto, 3);
-    Mat h0 = getBlob(node_proto, 5);
-    Mat c0 = getBlob(node_proto, 6);
-
-    b = b.reshape(1, b.size[0]);
 
+    const int numHidden = lstmParams.get<int>("hidden_size");
+    const int numDirs = Wx.size[0];  // Is 1 for forward only and 2 for bidirectional LSTM.
+    const int numFeatures = Wx.size[2];
+
+    Mat h0, c0;
+    if (!node_proto.input(5).empty()) {
+        h0 = getBlob(node_proto, 5);
+        h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
+    } else {
+        // initial_h attribute can be empty in case of keras2onnx producer. fill it with zeros
+        h0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1);
+    }
+    if (!node_proto.input(6).empty()) {
+        c0 = getBlob(node_proto, 6);
+        c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
+    } else {
+        // initial_c attribute can be empty in case of keras2onnx producer. fill it with zeros
+        c0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1);
+    }
+
+    b = b.reshape(1, b.size[0]);
     Mat bx = b.colRange(0, b.cols / 2);
     Mat bh = b.colRange(b.cols / 2, b.cols);
     b = bx + bh;
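For orientation, the seven ONNX LSTM inputs indexed above are, per the ONNX operator spec: 0 = X (input sequence), 1 = W (input weights, [num_directions, 4*hidden_size, input_size]), 2 = R (recurrent weights, [num_directions, 4*hidden_size, hidden_size]), 3 = B (concatenated Wb and Rb biases, hence the bx/bh split in the hunk above), 4 = sequence_lens, 5 = initial_h and 6 = initial_c (both [num_directions, batch_size, hidden_size]). Inputs 5 and 6 are optional, which is what the empty-name checks handle.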
@@ -1036,8 +1058,7 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
     }
     Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
     Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
-    h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
-    c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
 
     lstmParams.blobs.resize(5);
     lstmParams.blobs[0] = Wh;
@@ -1045,6 +1066,9 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
     lstmParams.blobs[2] = b;
     lstmParams.blobs[3] = h0;
     lstmParams.blobs[4] = c0;
 
+    // read direction attribute
+    lstmParams.set("reverse", lstmParams.get<String>("direction", "") == "reverse");
+    lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
+
     node_proto.set_output(0, lstmParams.name);  // set different name so output shapes will be registered on that name
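The ONNX direction attribute takes one of "forward" (the default), "reverse", or "bidirectional"; the two set() calls above translate it into the boolean reverse/bidirectional parameters that LSTMLayerImpl's constructor validates with CV_Assert(!reverse || !bidirectional).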
modules/dnn/test/test_onnx_importer.cpp
@@ -665,6 +665,11 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax)
    testONNXModels("split_max");
 }
 
+TEST_P(Test_ONNX_layers, LSTM_Activations)
+{
+    testONNXModels("lstm_cntk_tanh", pb, 0, 0, false, false);
+}
+
 TEST_P(Test_ONNX_layers, LSTM)
 {
     testONNXModels("lstm", npy, 0, 0, false, false);
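A note on the test call, judging from how the surrounding tests use the helper: the arguments after the model name appear to select the reference-data format (pb here, npy for the existing LSTM test), then override the default l1/lInf tolerances (0 meaning "use the defaults"), and finally set the useSoftmax and checkNoFallbacks flags; the model and reference files are expected in the opencv_extra test-data repository.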