mirror of
https://github.com/opencv/opencv.git
synced 2025-06-13 04:52:53 +08:00
Add CUDA support for LSTM.
Co-authored-by: Julia Bareeva <jbareeva@gmail.com>
This commit is contained in:
parent
be38d4ea93
commit
abebbf04b1
@ -287,6 +287,51 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
|||||||
cudnnTensorDescriptor_t descriptor;
|
cudnnTensorDescriptor_t descriptor;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** An array of number fully packed tensor descriptors
|
||||||
|
*
|
||||||
|
* @tparam T type of elements in the tensor
|
||||||
|
*/
|
||||||
|
template<class T>
|
||||||
|
class TensorDescriptorsArray
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
TensorDescriptorsArray() noexcept = default;
|
||||||
|
TensorDescriptorsArray(const TensorDescriptorsArray&) = delete;
|
||||||
|
TensorDescriptorsArray(TensorDescriptorsArray&& other) noexcept
|
||||||
|
: descriptors{std::move(other.descriptors)} {}
|
||||||
|
|
||||||
|
TensorDescriptorsArray(int seqLength, std::array<int, 3> dims)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < seqLength; ++i)
|
||||||
|
{
|
||||||
|
descriptors.emplace_back(dims);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~TensorDescriptorsArray() noexcept = default;
|
||||||
|
|
||||||
|
TensorDescriptorsArray& operator=(const TensorDescriptorsArray&) = delete;
|
||||||
|
TensorDescriptorsArray& operator=(TensorDescriptorsArray&& other) noexcept
|
||||||
|
{
|
||||||
|
descriptors = std::move(other.descriptors);
|
||||||
|
return *this;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<cudnnTensorDescriptor_t> get() const noexcept
|
||||||
|
{
|
||||||
|
std::vector<cudnnTensorDescriptor_t> descPtrs;
|
||||||
|
descPtrs.reserve(descriptors.size());
|
||||||
|
for (auto& desc : descriptors)
|
||||||
|
{
|
||||||
|
descPtrs.push_back(desc.get());
|
||||||
|
}
|
||||||
|
return descPtrs;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<TensorDescriptor<T>> descriptors;
|
||||||
|
};
|
||||||
|
|
||||||
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
|
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
|
||||||
|
|
||||||
#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_HPP */
|
#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_HPP */
|
||||||
|
195
modules/dnn/src/cuda4dnn/csl/cudnn/recurrent.hpp
Normal file
195
modules/dnn/src/cuda4dnn/csl/cudnn/recurrent.hpp
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
// This file is part of OpenCV project.
|
||||||
|
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||||
|
// of this distribution and at http://opencv.org/license.html.
|
||||||
|
|
||||||
|
#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_RECURRENT_HPP
|
||||||
|
#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_RECURRENT_HPP
|
||||||
|
|
||||||
|
#include "cudnn.hpp"
|
||||||
|
#include <cudnn.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
|
||||||
|
|
||||||
|
/**
|
||||||
|
*/
|
||||||
|
class DropoutDescriptor
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
DropoutDescriptor() noexcept = default;
|
||||||
|
DropoutDescriptor(const DropoutDescriptor &) = delete;
|
||||||
|
DropoutDescriptor(DropoutDescriptor &&other) noexcept : descriptor{other.descriptor}
|
||||||
|
{
|
||||||
|
states = std::move(other.states);
|
||||||
|
other.descriptor = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*/
|
||||||
|
DropoutDescriptor(const Handle &handle, float dropout)
|
||||||
|
{
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnCreateDropoutDescriptor(&descriptor));
|
||||||
|
|
||||||
|
// we need additional memory for dropout descriptor
|
||||||
|
size_t stateSize;
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnDropoutGetStatesSize(handle.get(), &stateSize));
|
||||||
|
states.reset(stateSize);
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
auto seed = 1234ull; // Pick a seed.
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnSetDropoutDescriptor(descriptor, handle.get(), dropout,
|
||||||
|
states.get().get(), stateSize, seed));
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnDestroyDropoutDescriptor(descriptor));
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~DropoutDescriptor() noexcept
|
||||||
|
{
|
||||||
|
if (descriptor)
|
||||||
|
{
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnDestroyDropoutDescriptor(descriptor));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DropoutDescriptor &operator=(const DropoutDescriptor &) = delete;
|
||||||
|
DropoutDescriptor &operator=(DropoutDescriptor &&other) noexcept
|
||||||
|
{
|
||||||
|
descriptor = other.descriptor;
|
||||||
|
states = std::move(other.states);
|
||||||
|
other.descriptor = nullptr;
|
||||||
|
return *this;
|
||||||
|
};
|
||||||
|
|
||||||
|
cudnnDropoutDescriptor_t get() const noexcept { return descriptor; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
cudnnDropoutDescriptor_t descriptor{nullptr};
|
||||||
|
|
||||||
|
using value_type = typename ManagedPtr<char>::element_type;
|
||||||
|
ManagedPtr<value_type> states;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
*/
|
||||||
|
template<class T>
|
||||||
|
class RNNDescriptor
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
enum class RNNMode
|
||||||
|
{
|
||||||
|
RNN_RELU,
|
||||||
|
RNN_TANH,
|
||||||
|
LSTM,
|
||||||
|
GRU
|
||||||
|
};
|
||||||
|
|
||||||
|
RNNDescriptor() noexcept = default;
|
||||||
|
RNNDescriptor(const RNNDescriptor &) = delete;
|
||||||
|
RNNDescriptor(RNNDescriptor &&other) noexcept : descriptor{other.descriptor}
|
||||||
|
{
|
||||||
|
other.descriptor = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*/
|
||||||
|
RNNDescriptor(const Handle &handle, RNNMode mode, int hidden_size, int num_layers,
|
||||||
|
bool bidirectional, const DropoutDescriptor &dropoutDesc)
|
||||||
|
{
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnCreateRNNDescriptor(&descriptor));
|
||||||
|
const auto rnn_mode = [mode] {
|
||||||
|
switch (mode)
|
||||||
|
{
|
||||||
|
case RNNMode::RNN_RELU:
|
||||||
|
return CUDNN_RNN_RELU;
|
||||||
|
case RNNMode::RNN_TANH:
|
||||||
|
return CUDNN_RNN_TANH;
|
||||||
|
case RNNMode::LSTM:
|
||||||
|
return CUDNN_LSTM;
|
||||||
|
case RNNMode::GRU:
|
||||||
|
return CUDNN_GRU;
|
||||||
|
default:
|
||||||
|
return CUDNN_LSTM;
|
||||||
|
}
|
||||||
|
}();
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnSetRNNDescriptor_v6(
|
||||||
|
handle.get(), descriptor, hidden_size, num_layers, dropoutDesc.get(),
|
||||||
|
CUDNN_LINEAR_INPUT, bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL,
|
||||||
|
rnn_mode,
|
||||||
|
algo, //CUDNN_RNN_ALGO_STANDARD,
|
||||||
|
detail::get_data_type<T>()));
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnDestroyRNNDescriptor(descriptor));
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~RNNDescriptor() noexcept
|
||||||
|
{
|
||||||
|
if (descriptor)
|
||||||
|
{
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnDestroyRNNDescriptor(descriptor));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RNNDescriptor &operator=(const RNNDescriptor &) = delete;
|
||||||
|
RNNDescriptor &operator=(RNNDescriptor &&other) noexcept
|
||||||
|
{
|
||||||
|
descriptor = other.descriptor;
|
||||||
|
other.descriptor = nullptr;
|
||||||
|
return *this;
|
||||||
|
};
|
||||||
|
|
||||||
|
cudnnRNNDescriptor_t get() const noexcept { return descriptor; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
cudnnRNNDescriptor_t descriptor{nullptr};
|
||||||
|
cudnnRNNMode_t mode{CUDNN_LSTM};
|
||||||
|
// support only one algo for a while
|
||||||
|
cudnnRNNAlgo_t algo{CUDNN_RNN_ALGO_STANDARD};
|
||||||
|
};
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
size_t getRNNWorkspaceSize(const Handle &handle, const RNNDescriptor<T> &rnnDesc,
|
||||||
|
const int seqLength, const TensorDescriptorsArray<T> &inputDesc)
|
||||||
|
{
|
||||||
|
size_t workSize;
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnGetRNNWorkspaceSize(handle.get(), rnnDesc.get(), seqLength,
|
||||||
|
inputDesc.get().data(), &workSize));
|
||||||
|
return workSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
void LSTMForward(const Handle &handle, const RNNDescriptor<T> &rnnDesc,
|
||||||
|
const FilterDescriptor<T> &filterDesc, DevicePtr<const T> filterPtr,
|
||||||
|
const TensorDescriptorsArray<T> &inputDesc, DevicePtr<const T> inputPtr,
|
||||||
|
const TensorDescriptor<T> &initialHDesc, DevicePtr<const T> initialH,
|
||||||
|
const TensorDescriptor<T> &initialCDesc, DevicePtr<const T> initialC,
|
||||||
|
const int seqLength, const TensorDescriptorsArray<T> &outputDesc,
|
||||||
|
DevicePtr<T> yOutputPtr, DevicePtr<T> ycOutputPtr, WorkspaceInstance workspace)
|
||||||
|
{
|
||||||
|
CV_Assert(handle);
|
||||||
|
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnRNNForwardInference(handle.get(), rnnDesc.get(), seqLength,
|
||||||
|
inputDesc.get().data(), inputPtr.get(), // input sequence
|
||||||
|
initialHDesc.get(), initialH.get(),
|
||||||
|
initialCDesc.get(), initialC.get(), // hidden
|
||||||
|
filterDesc.get(), filterPtr.get(), // weights
|
||||||
|
outputDesc.get().data(), yOutputPtr.get(), // output
|
||||||
|
nullptr, nullptr,
|
||||||
|
initialCDesc.get(), ycOutputPtr.get(),
|
||||||
|
static_cast<void*>(workspace.get()), workspace.size_in_bytes()));
|
||||||
|
}
|
||||||
|
|
||||||
|
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
|
||||||
|
|
||||||
|
#endif //OPENCV_DNN_CUDA4DNN_CSL_CUDNN_RECURRENT_HPP
|
@ -18,6 +18,7 @@
|
|||||||
#include "cudnn/softmax.hpp"
|
#include "cudnn/softmax.hpp"
|
||||||
#include "cudnn/transform.hpp"
|
#include "cudnn/transform.hpp"
|
||||||
#include "cudnn/transpose_convolution.hpp"
|
#include "cudnn/transpose_convolution.hpp"
|
||||||
|
#include "cudnn/recurrent.hpp"
|
||||||
|
|
||||||
#include <opencv2/core.hpp>
|
#include <opencv2/core.hpp>
|
||||||
|
|
||||||
@ -472,6 +473,90 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
|||||||
TensorTransformDescriptor transDesc;
|
TensorTransformDescriptor transDesc;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
class LSTM
|
||||||
|
{
|
||||||
|
using TensorDescriptor = cudnn::TensorDescriptor<T>;
|
||||||
|
using DropoutDescriptor = cudnn::DropoutDescriptor;
|
||||||
|
using RNNDescriptor = cudnn::RNNDescriptor<T>;
|
||||||
|
using FilterDescriptor = cudnn::FilterDescriptor<T>;
|
||||||
|
using TensorDescriptorsArray = cudnn::TensorDescriptorsArray<T>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using RNNMode = typename RNNDescriptor::RNNMode;
|
||||||
|
|
||||||
|
struct params_type
|
||||||
|
{
|
||||||
|
std::vector<std::size_t> weights_shape;
|
||||||
|
|
||||||
|
int seqLength;
|
||||||
|
int numLayers;
|
||||||
|
int hiddenSize;
|
||||||
|
int inputSize;
|
||||||
|
int miniBatch;
|
||||||
|
bool bidirectional;
|
||||||
|
|
||||||
|
float dropout;
|
||||||
|
RNNMode type;
|
||||||
|
};
|
||||||
|
|
||||||
|
LSTM() = default;
|
||||||
|
LSTM(const LSTM&) = delete;
|
||||||
|
LSTM(LSTM&&) = default;
|
||||||
|
LSTM(cudnn::Handle handle, const params_type& params)
|
||||||
|
: cudnnHandle(std::move(handle)), seqLength{params.seqLength},
|
||||||
|
inputDesc(seqLength, {params.miniBatch, params.inputSize, 1}),
|
||||||
|
outputDesc(seqLength,
|
||||||
|
{params.miniBatch,
|
||||||
|
params.bidirectional ? params.hiddenSize * 2 : params.hiddenSize,
|
||||||
|
1})
|
||||||
|
{
|
||||||
|
dropoutDesc = DropoutDescriptor(cudnnHandle, params.dropout);
|
||||||
|
filterDesc = FilterDescriptor(params.weights_shape);
|
||||||
|
rnnDesc = RNNDescriptor(cudnnHandle, params.type, params.hiddenSize,
|
||||||
|
params.numLayers, params.bidirectional, dropoutDesc);
|
||||||
|
|
||||||
|
int num_direction = params.bidirectional ? 2 : 1;
|
||||||
|
h0TensorDesc = TensorDescriptor(
|
||||||
|
{num_direction, params.miniBatch, params.hiddenSize});
|
||||||
|
c0TensorDesc = TensorDescriptor(
|
||||||
|
{num_direction, params.miniBatch, params.hiddenSize});
|
||||||
|
|
||||||
|
// Get amount of work space required to execute the RNN described by rnnDesc
|
||||||
|
// with input dimensions defined by inputDesc
|
||||||
|
csl::WorkspaceBuilder builder;
|
||||||
|
builder.require(cudnn::getRNNWorkspaceSize<T>(cudnnHandle, rnnDesc, seqLength, inputDesc));
|
||||||
|
scratch_mem_in_bytes = builder.required_workspace_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
LSTM& operator=(const LSTM&) = delete;
|
||||||
|
LSTM& operator=(LSTM&&) = default;
|
||||||
|
|
||||||
|
void inference(TensorView<T> input, TensorSpan<T> y_output, TensorSpan<T> yc_output, TensorView<T> filters,
|
||||||
|
TensorView<T> h0, TensorView<T> c0, WorkspaceInstance workspace)
|
||||||
|
{
|
||||||
|
cudnn::LSTMForward<T>(cudnnHandle, rnnDesc, filterDesc, filters.get(), inputDesc,
|
||||||
|
input.get(), h0TensorDesc, h0.get(), c0TensorDesc, c0.get(),
|
||||||
|
seqLength, outputDesc, y_output.get(), yc_output.get(), workspace);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t get_workspace_memory_in_bytes() const noexcept { return scratch_mem_in_bytes; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
cudnn::Handle cudnnHandle;
|
||||||
|
std::size_t scratch_mem_in_bytes{0};
|
||||||
|
int seqLength;
|
||||||
|
|
||||||
|
RNNDescriptor rnnDesc;
|
||||||
|
DropoutDescriptor dropoutDesc;
|
||||||
|
|
||||||
|
FilterDescriptor filterDesc;
|
||||||
|
TensorDescriptor h0TensorDesc, c0TensorDesc;
|
||||||
|
|
||||||
|
TensorDescriptorsArray inputDesc;
|
||||||
|
TensorDescriptorsArray outputDesc;
|
||||||
|
};
|
||||||
|
|
||||||
}}}} /* namespace cv::dnn::cuda4dnn::csl */
|
}}}} /* namespace cv::dnn::cuda4dnn::csl */
|
||||||
|
|
||||||
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_TENSOR_OPS_HPP */
|
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_TENSOR_OPS_HPP */
|
||||||
|
97
modules/dnn/src/cuda4dnn/primitives/recurrent_cells.hpp
Normal file
97
modules/dnn/src/cuda4dnn/primitives/recurrent_cells.hpp
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
// This file is part of OpenCV project.
|
||||||
|
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||||
|
// of this distribution and at http://opencv.org/license.html.
|
||||||
|
|
||||||
|
#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CELLS_HPP
|
||||||
|
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CELLS_HPP
|
||||||
|
|
||||||
|
#include "../../op_cuda.hpp"
|
||||||
|
|
||||||
|
#include "../csl/cudnn.hpp"
|
||||||
|
#include "../csl/tensor_ops.hpp"
|
||||||
|
#include "../csl/cudnn/recurrent.hpp"
|
||||||
|
|
||||||
|
namespace cv { namespace dnn { namespace cuda4dnn {
|
||||||
|
|
||||||
|
struct RNNConfiguration
|
||||||
|
{
|
||||||
|
int seqLength;
|
||||||
|
int numLayers;
|
||||||
|
int hiddenSize;
|
||||||
|
int inputSize;
|
||||||
|
int miniBatch;
|
||||||
|
bool bidirectional;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
class LSTMOp final : public CUDABackendNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
using wrapper_type = GetCUDABackendWrapperType<T>;
|
||||||
|
|
||||||
|
LSTMOp(csl::Stream stream_, csl::cudnn::Handle handle, const Mat& filters, const Mat& h0,
|
||||||
|
const Mat& c0, const RNNConfiguration& config)
|
||||||
|
: stream(std::move(stream_))
|
||||||
|
{
|
||||||
|
typename csl::LSTM<T>::params_type params{
|
||||||
|
{filters.total(), 1, 1}, // reshape
|
||||||
|
config.seqLength,
|
||||||
|
config.numLayers,
|
||||||
|
config.hiddenSize,
|
||||||
|
config.inputSize,
|
||||||
|
config.miniBatch,
|
||||||
|
config.bidirectional,
|
||||||
|
0.0, /* dropout */
|
||||||
|
csl::cudnn::RNNDescriptor<T>::RNNMode::LSTM
|
||||||
|
};
|
||||||
|
|
||||||
|
lstm = csl::LSTM<T>(handle, params);
|
||||||
|
auto correct_shape_filters = filters.reshape(1, {static_cast<int>(filters.total()), 1, 1});
|
||||||
|
filtersTensor = csl::makeTensorHeader<T>(correct_shape_filters);
|
||||||
|
csl::copyMatToTensor<T>(correct_shape_filters, filtersTensor, stream);
|
||||||
|
|
||||||
|
h0Tensor = csl::makeTensorHeader<T>(h0);
|
||||||
|
csl::copyMatToTensor<T>(h0, h0Tensor, stream);
|
||||||
|
|
||||||
|
c0Tensor = csl::makeTensorHeader<T>(c0);
|
||||||
|
csl::copyMatToTensor<T>(c0, c0Tensor, stream);
|
||||||
|
|
||||||
|
csl::WorkspaceBuilder builder;
|
||||||
|
builder.require<T>(lstm.get_workspace_memory_in_bytes());
|
||||||
|
}
|
||||||
|
|
||||||
|
void forward(const std::vector<cv::Ptr<BackendWrapper>>& inputs,
|
||||||
|
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
|
||||||
|
csl::Workspace& workspace) override
|
||||||
|
{
|
||||||
|
CV_Assert(inputs.size() == 1 && !outputs.empty());
|
||||||
|
|
||||||
|
auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
|
||||||
|
auto input = input_wrapper->getView();
|
||||||
|
|
||||||
|
auto y_output_wrapper = outputs[0].dynamicCast<wrapper_type>();
|
||||||
|
auto y_output = y_output_wrapper->getSpan();
|
||||||
|
|
||||||
|
Ptr<wrapper_type> yc_output_wrapper = outputs.size() == 2 ? outputs[1].dynamicCast<wrapper_type>() : Ptr<wrapper_type>();
|
||||||
|
csl::TensorSpan<T> yc_output = yc_output_wrapper.empty() ? csl::TensorSpan<T>() : yc_output_wrapper->getSpan();
|
||||||
|
|
||||||
|
csl::WorkspaceAllocator allocator(workspace);
|
||||||
|
lstm.inference(input, y_output, yc_output, filtersTensor, h0Tensor, c0Tensor, allocator.get_instance());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t get_workspace_memory_in_bytes() const noexcept override
|
||||||
|
{
|
||||||
|
return lstm.get_workspace_memory_in_bytes();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
csl::LSTM<T> lstm;
|
||||||
|
csl::Stream stream;
|
||||||
|
csl::Tensor<T> filtersTensor;
|
||||||
|
csl::Tensor<T> h0Tensor;
|
||||||
|
csl::Tensor<T> c0Tensor;
|
||||||
|
};
|
||||||
|
|
||||||
|
}}} /* namespace cv::dnn::cuda4dnn */
|
||||||
|
|
||||||
|
#endif //OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_RECURRENT_CELLS_HPP
|
@ -42,10 +42,14 @@
|
|||||||
|
|
||||||
#include "../precomp.hpp"
|
#include "../precomp.hpp"
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <iterator>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <opencv2/dnn/shape_utils.hpp>
|
#include <opencv2/dnn/shape_utils.hpp>
|
||||||
|
|
||||||
|
#ifdef HAVE_CUDA
|
||||||
|
#include "../cuda4dnn/primitives/recurrent_cells.hpp"
|
||||||
|
using namespace cv::dnn::cuda4dnn;
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "layers_common.hpp"
|
#include "layers_common.hpp"
|
||||||
|
|
||||||
namespace cv
|
namespace cv
|
||||||
@ -119,6 +123,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
|
|||||||
ActivationFunction f_activation;
|
ActivationFunction f_activation;
|
||||||
ActivationFunction g_activation;
|
ActivationFunction g_activation;
|
||||||
ActivationFunction h_activation;
|
ActivationFunction h_activation;
|
||||||
|
bool isDefaultActivations{true};
|
||||||
|
|
||||||
#if CV_TRY_AVX
|
#if CV_TRY_AVX
|
||||||
bool useAVX;
|
bool useAVX;
|
||||||
@ -202,11 +207,15 @@ public:
|
|||||||
f_activation = sigmoid;
|
f_activation = sigmoid;
|
||||||
g_activation = tanh;
|
g_activation = tanh;
|
||||||
h_activation = tanh;
|
h_activation = tanh;
|
||||||
|
isDefaultActivations = true;
|
||||||
} else {
|
} else {
|
||||||
CV_Assert(activations.size() == 3);
|
CV_Assert(activations.size() == 3);
|
||||||
f_activation = get_activation_function(activations.getStringValue(0));
|
f_activation = get_activation_function(activations.getStringValue(0));
|
||||||
g_activation = get_activation_function(activations.getStringValue(1));
|
g_activation = get_activation_function(activations.getStringValue(1));
|
||||||
h_activation = get_activation_function(activations.getStringValue(2));
|
h_activation = get_activation_function(activations.getStringValue(2));
|
||||||
|
isDefaultActivations = activations.getStringValue(0) == "Sigmoid"
|
||||||
|
&& activations.getStringValue(1) == "Tanh"
|
||||||
|
&& activations.getStringValue(2) == "Tanh";
|
||||||
}
|
}
|
||||||
|
|
||||||
allocated = false;
|
allocated = false;
|
||||||
@ -245,6 +254,12 @@ public:
|
|||||||
blobs[2] = Mat(bias.clone()).reshape(1, 1);
|
blobs[2] = Mat(bias.clone()).reshape(1, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool supportBackend(int backendId) CV_OVERRIDE
|
||||||
|
{
|
||||||
|
return backendId == DNN_BACKEND_OPENCV
|
||||||
|
|| (backendId == DNN_BACKEND_CUDA && isDefaultActivations && !reverse && !usePeephole);
|
||||||
|
}
|
||||||
|
|
||||||
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||||
const int requiredOutputs,
|
const int requiredOutputs,
|
||||||
std::vector<MatShape> &outputs,
|
std::vector<MatShape> &outputs,
|
||||||
@ -582,29 +597,8 @@ public:
|
|||||||
cOut = cOut.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
|
cOut = cOut.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
|
||||||
|
|
||||||
// permute to {0, 2, 1, 3};
|
// permute to {0, 2, 1, 3};
|
||||||
std::vector<int> newShape = shape(cOut);
|
cv::Mat newCellState;
|
||||||
std::swap(newShape[1], newShape[2]);
|
cv::transposeND(cOut, {0, 2, 1, 3}, newCellState);
|
||||||
cv::Mat newCellState(newShape, CV_32FC1);
|
|
||||||
const float* src = cOut.ptr<const float>();
|
|
||||||
float* dst = newCellState.ptr<float>();
|
|
||||||
size_t sj = newCellState.size[3];
|
|
||||||
size_t sk = newCellState.size[2] * sj;
|
|
||||||
size_t si = newCellState.size[1] * sk;
|
|
||||||
for (size_t i = 0; i < newCellState.size[0]; i++)
|
|
||||||
{
|
|
||||||
for (size_t j = 0; j < newCellState.size[2]; j++)
|
|
||||||
{
|
|
||||||
for (size_t k = 0; k < newCellState.size[1]; k++)
|
|
||||||
{
|
|
||||||
std::memcpy(dst, src, sizeof(float) * newCellState.size[3]);
|
|
||||||
src += cOut.size[3];
|
|
||||||
dst += sk;
|
|
||||||
}
|
|
||||||
dst = dst + sj - si;
|
|
||||||
}
|
|
||||||
dst = dst + si - sk;
|
|
||||||
}
|
|
||||||
|
|
||||||
cOut = newCellState;
|
cOut = newCellState;
|
||||||
|
|
||||||
if (numDirs == 1)
|
if (numDirs == 1)
|
||||||
@ -637,6 +631,77 @@ public:
|
|||||||
cOut = cOut.reshape(1, sizeof(finalShape)/sizeof(finalShape[0]), finalShape);
|
cOut = cOut.reshape(1, sizeof(finalShape)/sizeof(finalShape[0]), finalShape);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef HAVE_CUDA
|
||||||
|
Ptr<BackendNode> initCUDA(void *context_, const std::vector<Ptr<BackendWrapper>> &inputs,
|
||||||
|
const std::vector<Ptr<BackendWrapper>> &outputs) override
|
||||||
|
{
|
||||||
|
const int numDirs = 1 + static_cast<int>(bidirectional);
|
||||||
|
auto toIFCO = [numDirs] (Mat& in) {
|
||||||
|
int first = in.size[0];
|
||||||
|
int rest = in.total() / first / 4;
|
||||||
|
// every weight blob contains weights for Input, Output, Forget and Cell gates
|
||||||
|
Mat m = in.reshape(1, {first, 4, rest});
|
||||||
|
Mat outputGate = m.col(1);
|
||||||
|
Mat forgetGate = m.col(2);
|
||||||
|
Mat cellGate = m.col(3);
|
||||||
|
// IOFC -> IFOC
|
||||||
|
std::swap_ranges(outputGate.begin<float>(), outputGate.end<float>(), forgetGate.begin<float>());
|
||||||
|
std::swap(outputGate, forgetGate);
|
||||||
|
// IFOC -> IFCO
|
||||||
|
std::swap_ranges(outputGate.begin<float>(), outputGate.end<float>(), cellGate.begin<float>());
|
||||||
|
in = in.reshape(1, numDirs);
|
||||||
|
};
|
||||||
|
|
||||||
|
Mat& b = originalBlobs[2];
|
||||||
|
// B is a concatenation of biases for Wh and Wx
|
||||||
|
b = b.reshape(1, originalBlobs[2].size[0]*2);
|
||||||
|
|
||||||
|
for (auto& m : originalBlobs)
|
||||||
|
{
|
||||||
|
toIFCO(m);
|
||||||
|
}
|
||||||
|
|
||||||
|
b = b.reshape(1, static_cast<int>(b.total()));
|
||||||
|
|
||||||
|
Mat ordered_weights;
|
||||||
|
// Wx_f, Wh_f, [Wx_b, Wh_b,] b
|
||||||
|
for (int i = 0; i < numDirs; ++i)
|
||||||
|
{
|
||||||
|
for (size_t j = 0; j < 2; ++j) // Wx, Wh
|
||||||
|
{
|
||||||
|
Mat oneDirection = originalBlobs[j].row(i);
|
||||||
|
ordered_weights.push_back(oneDirection.reshape(1, static_cast<int>(oneDirection.total())));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ordered_weights.push_back(b);
|
||||||
|
|
||||||
|
// Pass hidden states as is
|
||||||
|
Mat h0 = blobs[3];
|
||||||
|
Mat c0 = blobs[4];
|
||||||
|
|
||||||
|
CV_Assert(!inputs.empty());
|
||||||
|
auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
|
||||||
|
auto input_shape = input_wrapper->getShape();
|
||||||
|
|
||||||
|
RNNConfiguration config
|
||||||
|
{
|
||||||
|
input_shape[0], // seqLength;
|
||||||
|
1, // numLayers;
|
||||||
|
numHidden, // hiddenSize;
|
||||||
|
input_shape[2], // inputSize;
|
||||||
|
input_shape[1], // miniBatch;
|
||||||
|
bidirectional
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
auto *context = reinterpret_cast<cuda4dnn::csl::CSLContext *>(context_);
|
||||||
|
return make_cuda_node<cuda4dnn::LSTMOp>(preferableTarget, std::move(context->stream),
|
||||||
|
std::move(context->cudnn_handle),
|
||||||
|
ordered_weights, h0, c0,
|
||||||
|
config);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
Ptr<LSTMLayer> LSTMLayer::create(const LayerParams& params)
|
Ptr<LSTMLayer> LSTMLayer::create(const LayerParams& params)
|
||||||
|
@ -1574,8 +1574,6 @@ void transformBlobs(std::vector<Mat>& blobs)
|
|||||||
cudaWorkaround.push_back(b.clone());
|
cudaWorkaround.push_back(b.clone());
|
||||||
|
|
||||||
const int numHidden = Wh.size[2];
|
const int numHidden = Wh.size[2];
|
||||||
const int numDirs = Wx.size[0]; // Is 1 for forward only and 2 for bidirectional LSTM.
|
|
||||||
const int numFeatures = Wx.size[2];
|
|
||||||
|
|
||||||
Mat h0 = blobs[3];
|
Mat h0 = blobs[3];
|
||||||
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
|
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
|
||||||
@ -1587,30 +1585,20 @@ void transformBlobs(std::vector<Mat>& blobs)
|
|||||||
Mat bh = b.colRange(b.cols / 2, b.cols);
|
Mat bh = b.colRange(b.cols / 2, b.cols);
|
||||||
b = bx + bh;
|
b = bx + bh;
|
||||||
|
|
||||||
// b is numDirs X numHidden*3
|
auto toIFOC = [] (Mat& in) {
|
||||||
CV_CheckLE(numHidden * 3, b.cols, "Bias data should have at least 3x hidden_size columns");
|
int first = in.size[0];
|
||||||
|
int rest = in.total() / first / 4;
|
||||||
|
// every weight blob contains weights for Input, Output, Forget and Cell gates
|
||||||
|
Mat m = in.reshape(1, {first, 4, rest});
|
||||||
|
Mat outputGate = m.col(1);
|
||||||
|
Mat forgetGate = m.col(2);
|
||||||
|
std::swap_ranges(outputGate.begin<float>(), outputGate.end<float>(), forgetGate.begin<float>());
|
||||||
|
};
|
||||||
|
|
||||||
|
toIFOC(Wx);
|
||||||
|
toIFOC(Wh);
|
||||||
|
toIFOC(b);
|
||||||
|
|
||||||
// IFGO->IGFO
|
|
||||||
for (int k = 0; k < numDirs; ++k)
|
|
||||||
{
|
|
||||||
float* WxData = Wx.ptr<float>(k);
|
|
||||||
float* WhData = Wh.ptr<float>(k);
|
|
||||||
float* biasData = b.ptr<float>(k);
|
|
||||||
for (int j = 0; j < numHidden; ++j)
|
|
||||||
{
|
|
||||||
for (int i = 0; i < numFeatures; ++i)
|
|
||||||
{
|
|
||||||
std::swap(WxData[(numHidden + j) * numFeatures + i],
|
|
||||||
WxData[(numHidden * 2 + j) * numFeatures + i]);
|
|
||||||
}
|
|
||||||
for (int i = 0; i < numHidden; ++i)
|
|
||||||
{
|
|
||||||
std::swap(WhData[(numHidden + j) * numHidden + i],
|
|
||||||
WhData[(numHidden * 2 + j) * numHidden + i]);
|
|
||||||
}
|
|
||||||
std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
|
Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
|
||||||
Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
|
Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user