Merge pull request #9285 from arrybn:issue_9223

This commit is contained in:
Alexander Alekhin 2017-08-02 17:51:46 +00:00
commit 1ce9ffcc7f
5 changed files with 69 additions and 0 deletions

View File

@ -349,6 +349,12 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
static Ptr<ChannelsPReLULayer> create(const LayerParams& params);
};
class CV_EXPORTS ELULayer : public ActivationLayer
{
public:
static Ptr<ELULayer> create(const LayerParams &params);
};
class CV_EXPORTS TanHLayer : public ActivationLayer
{
public:

View File

@ -96,6 +96,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(ChannelsPReLU, ChannelsPReLULayer);
CV_DNN_REGISTER_LAYER_CLASS(Sigmoid, SigmoidLayer);
CV_DNN_REGISTER_LAYER_CLASS(TanH, TanHLayer);
CV_DNN_REGISTER_LAYER_CLASS(ELU, ELULayer);
CV_DNN_REGISTER_LAYER_CLASS(BNLL, BNLLLayer);
CV_DNN_REGISTER_LAYER_CLASS(AbsVal, AbsLayer);
CV_DNN_REGISTER_LAYER_CLASS(Power, PowerLayer);

View File

@ -302,6 +302,35 @@ struct SigmoidFunctor
int64 getFLOPSPerElement() const { return 3; }
};
struct ELUFunctor
{
typedef ELULayer Layer;
explicit ELUFunctor() {}
void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const
{
for( int cn = cn0; cn < cn1; cn++, srcptr += planeSize, dstptr += planeSize )
{
for(int i = 0; i < len; i++ )
{
float x = srcptr[i];
dstptr[i] = x >= 0.f ? x : exp(x) - 1;
}
}
}
#ifdef HAVE_HALIDE
void attachHalide(const Halide::Expr& input, Halide::Func& top)
{
Halide::Var x("x"), y("y"), c("c"), n("n");
top(x, y, c, n) = select(input >= 0.0f, input, exp(input) - 1);
}
#endif // HAVE_HALIDE
int64 getFLOPSPerElement() const { return 2; }
};
struct AbsValFunctor
{
typedef AbsLayer Layer;
@ -504,6 +533,14 @@ Ptr<SigmoidLayer> SigmoidLayer::create(const LayerParams& params)
return l;
}
Ptr<ELULayer> ELULayer::create(const LayerParams& params)
{
Ptr<ELULayer> l(new ElementWiseLayer<ELUFunctor>(ELUFunctor()));
l->setParamsFrom(params);
return l;
}
Ptr<AbsLayer> AbsLayer::create(const LayerParams& params)
{
Ptr<AbsLayer> l(new ElementWiseLayer<AbsValFunctor>());

View File

@ -677,6 +677,13 @@ void TFImporter::populateNet(Net dstNet)
connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, layer.input_size());
}
else if (type == "Elu")
{
int id = dstNet.addLayer(name, "ELU", layerParams);
layer_id[name] = id;
connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, layer.input_size());
}
else if (type == "MaxPool")
{
layerParams.set("pool", "max");

View File

@ -268,11 +268,29 @@ static void test_Reshape_Split_Slice_layers()
normAssert(input, output);
}
TEST(Layer_Test_Reshape_Split_Slice, Accuracy)
{
test_Reshape_Split_Slice_layers();
}
TEST(Layer_Conv_Elu, Accuracy)
{
Net net;
{
Ptr<Importer> importer = createTensorflowImporter(_tf("layer_elu_model.pb"));
ASSERT_TRUE(importer != NULL);
importer->populateNet(net);
}
Mat inp = blobFromNPY(_tf("layer_elu_in.npy"));
Mat ref = blobFromNPY(_tf("layer_elu_out.npy"));
net.setInput(inp, "input");
Mat out = net.forward();
normAssert(ref, out);
}
class Layer_LSTM_Test : public ::testing::Test
{
public: