// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP
#define OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP

#include <cstddef>
#include <type_traits>
#include <utility>

/* The performance of many kernels is highly dependent on the tensor rank. Instead of having
 * one kernel which can work with the maximally ranked tensors, we make one kernel for each supported
 * tensor rank. This is to ensure that the requirements of the maximally ranked tensors do not take a
 * toll on the performance of the operation for low ranked tensors. Hence, many kernels take the tensor
 * rank as a template parameter.
 *
 * The kernel is a template and we have different instantiations for each rank. This causes the following pattern
 * to arise frequently:
 *
 *      if(rank == 3)
 *          kernel<T, 3>();
 *      else if(rank == 2)
 *          kernel<T, 2>();
 *      else
 *          kernel<T, 1>();
 *
 * The rank is a runtime variable. To facilitate creation of such structures, we use GENERATE_KERNEL_DISPATCHER.
 * This macro creates a function which selects the correct kernel instantiation at runtime.
 *
 * Example:
 *
 *      // function which sets up the kernel and launches it
 *      template <class T, std::size_t Rank>
 *      void launch_some_kernel(...);
 *
 *      // creates the dispatcher named "some_dispatcher" which invokes the correct instantiation of "launch_some_kernel"
 *      GENERATE_KERNEL_DISPATCHER(some_dispatcher, launch_some_kernel);
 *
 *      // internal API function
 *      template <class T>
 *      void some(...) {
 *          // ...
 *          auto rank = input.rank();
 *          some_dispatcher<T, 1, 3>(rank, ...); // tries ranks 1, 2 and 3
 *      }
 */

/*
 * GENERATE_KERNEL_DISPATCHER(name, func)
 *
 * Macro parameters:
 *   name     name of the dispatcher function template that is generated
 *   func     function template that requires runtime selection; it is invoked as `func<T, rank>(args...)`
 *
 * Template parameters of the generated dispatcher `name<T, start, end>(selector, args...)`:
 *   T        first template parameter to `func`
 *   start    starting rank
 *   end      ending rank (inclusive); must satisfy start <= end
 *
 * Function parameters of the generated dispatcher:
 *   selector runtime value choosing which instantiation of `func` to run
 *   args     perfectly forwarded to the selected `func<T, selector>`
 *
 * Executes func<T, selector>(args...) based on the runtime `selector` argument given `selector` lies
 * within the range [start, end]. If outside the range (including negative values), no instantiation
 * of `func` is executed and the call silently does nothing.
 *
 * Implementation notes:
 *   - Two overloads are generated and chosen with std::enable_if: a terminal one for start == end and
 *     a recursive one for start < end which tests `start` and otherwise forwards to
 *     name<T, start + 1, end>. Every rank in [start, end] is therefore instantiated at compile time.
 *   - The recursive overload requires start < end (not merely start != end) so that a reversed range
 *     is rejected with a "no matching function" error instead of recursing until the template
 *     instantiation depth limit is hit.
 *   - `selector` is an int while the ranks are std::size_t; the sign is checked before the cast so
 *     the comparison is never a mixed signed/unsigned one.
 *   - The generated functions are `static`.
 *   - The macro definition does not end in a semicolon; a trailing semicolon at the point of use
 *     is permitted but not required.
 */
#define GENERATE_KERNEL_DISPATCHER(name,func)                                           \
    template <class T, std::size_t start, std::size_t end, class... Args> static        \
    typename std::enable_if<start == end, void>                                         \
    ::type name(int selector, Args&& ...args) {                                         \
        if(selector >= 0 && static_cast<std::size_t>(selector) == start)                \
            func<T, start>(std::forward<Args>(args)...);                                \
    }                                                                                   \
                                                                                        \
    template <class T, std::size_t start, std::size_t end, class... Args> static        \
    typename std::enable_if<(start < end), void>                                        \
    ::type name(int selector, Args&& ...args) {                                         \
        if(selector >= 0 && static_cast<std::size_t>(selector) == start)                \
            func<T, start>(std::forward<Args>(args)...);                                \
        else                                                                            \
            name<T, start + 1, end, Args...>(selector, std::forward<Args>(args)...);    \
    }

#endif /* OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP */