// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_RECURRENT_CELLS_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_RECURRENT_CELLS_HPP

#include "../../op_cuda.hpp"

#include "../csl/cudnn.hpp"
#include "../csl/tensor_ops.hpp"
#include "../csl/cudnn/recurrent.hpp"

namespace cv { namespace dnn { namespace cuda4dnn {

    /* shape/topology parameters forwarded to the cuDNN RNN descriptor */
    struct RNNConfiguration
    {
        int seqLength;
        int numLayers;
        int hiddenSize;
        int inputSize;
        int miniBatch;
        bool bidirectional;
    };

    /* CUDA backend node that performs LSTM inference through cuDNN (csl::LSTM) */
    template <class T>
    class LSTMOp final : public CUDABackendNode
    {
    public:
        using wrapper_type = GetCUDABackendWrapperType<T>;

        LSTMOp(csl::Stream stream_, csl::cudnn::Handle handle, const Mat& filters, const Mat& h0,
               const Mat& c0, const RNNConfiguration& config)
            : stream(std::move(stream_))
        {
            typename csl::LSTM<T>::params_type params{
                {filters.total(), 1, 1}, // reshape
                config.seqLength,
                config.numLayers,
                config.hiddenSize,
                config.inputSize,
                config.miniBatch,
                config.bidirectional,
                0.0, /* dropout */
                csl::cudnn::RNNDescriptor<T>::RNNMode::LSTM
            };

            lstm = csl::LSTM<T>(handle, params);

            /* cuDNN expects the weights as a flat (total, 1, 1) blob */
            auto correct_shape_filters = filters.reshape(1, {static_cast<int>(filters.total()), 1, 1});
            filtersTensor = csl::makeTensorHeader<T>(correct_shape_filters);
            csl::copyMatToTensor<T>(correct_shape_filters, filtersTensor, stream);

            h0Tensor = csl::makeTensorHeader<T>(h0);
            csl::copyMatToTensor<T>(h0, h0Tensor, stream);

            c0Tensor = csl::makeTensorHeader<T>(c0);
            csl::copyMatToTensor<T>(c0, c0Tensor, stream);

            csl::WorkspaceBuilder builder;
            builder.require(lstm.get_workspace_memory_in_bytes());
        }

        void forward(const std::vector<cv::Ptr<BackendWrapper>>& inputs,
                     const std::vector<cv::Ptr<BackendWrapper>>& outputs,
                     csl::Workspace& workspace) override
        {
            CV_Assert(inputs.size() == 1 && !outputs.empty());

            auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
            auto input = input_wrapper->getView();

            auto y_output_wrapper = outputs[0].dynamicCast<wrapper_type>();
            auto y_output = y_output_wrapper->getSpan();

            /* the optional second output receives the final cell state */
            Ptr<wrapper_type> yc_output_wrapper =
                outputs.size() == 2 ? outputs[1].dynamicCast<wrapper_type>() : Ptr<wrapper_type>();
            csl::TensorSpan<T> yc_output =
                yc_output_wrapper.empty() ? csl::TensorSpan<T>() : yc_output_wrapper->getSpan();

            csl::WorkspaceAllocator allocator(workspace);

            lstm.inference(input, y_output, yc_output, filtersTensor, h0Tensor, c0Tensor,
                           allocator.get_instance());
        }

        std::size_t get_workspace_memory_in_bytes() const noexcept override
        {
            return lstm.get_workspace_memory_in_bytes();
        }

    private:
        csl::LSTM<T> lstm;
        csl::Stream stream;
        csl::Tensor<T> filtersTensor;
        csl::Tensor<T> h0Tensor;
        csl::Tensor<T> c0Tensor;
    };

}}} /* namespace cv::dnn::cuda4dnn */

#endif // OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_RECURRENT_CELLS_HPP
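
/* A minimal usage sketch (not part of the original header, kept as a comment so the file
 * still compiles unchanged): the layer's CUDA backend is expected to construct the node
 * roughly as below. The identifiers `stream`, `handle`, `filters`, `h0`, `c0` and the
 * RNNConfiguration values are placeholders/assumptions, not names from this header.
 *
 *     using namespace cv::dnn::cuda4dnn;
 *     RNNConfiguration config{seqLength, 1, hiddenSize, inputSize, miniBatch, false};
 *     cv::Ptr<CUDABackendNode> node =
 *         cv::makePtr<LSTMOp<float>>(stream, handle, filters, h0, c0, config);
 *     // node->forward(inputs, outputs, workspace) is then driven by the DNN engine,
 *     // with a csl::Workspace sized from node->get_workspace_memory_in_bytes().
 */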