Merge pull request #25880 from Jamim:fix/cuda-no-fp16
Fix CUDA for old GPUs without FP16 support #25880
Fixes #21461
~This is a build-time solution that reflects https://github.com/opencv/opencv/blob/4.10.0/modules/dnn/src/cuda4dnn/init.hpp#L68-L82.~
~We shouldn't add an invalid target while building with `CUDA_ARCH_BIN` < 53.~
_(please see [this discussion](https://github.com/opencv/opencv/pull/25880#discussion_r1668074505))_
This is a run-time solution that basically reverts [these lines](d0fe6ad109 (diff-757c5ab6ddf2f99cdd09f851e3cf17abff203aff4107d908c7ad3d0466f39604L245-R245)).
I've debugged these changes, [coupled with other fixes](https://github.com/gentoo/gentoo/pull/37479), on [Gentoo Linux](https://www.gentoo.org/), and the [related tests passed](https://github.com/user-attachments/files/16135391/opencv-4.10.0.20240708-224733.log.gz) on my laptop with a `GeForce GTX 960M`.
Alternative solution:
- #21462
_Best regards!_
### Pull Request Readiness Checklist
- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [ ] `n/a` There is accuracy test, performance test and test data in opencv_extra repository, if applicable
- [ ] `n/a` The feature is well documented and sample code can be built with the project CMake
commit 35ca2f78d6 (parent b964943517)
```diff
@@ -15,7 +15,7 @@
 
 namespace cv { namespace dnn { namespace cuda4dnn {
 
-    void checkVersions()
+    inline void checkVersions()
     {
         // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#programming-model
         // cuDNN API Compatibility
@@ -44,21 +44,23 @@ namespace cv { namespace dnn { namespace cuda4dnn {
         }
     }
 
-    int getDeviceCount()
+    inline int getDeviceCount()
     {
         return cuda::getCudaEnabledDeviceCount();
     }
 
-    int getDevice()
+    inline int getDevice()
     {
         int device_id = -1;
         CUDA4DNN_CHECK_CUDA(cudaGetDevice(&device_id));
         return device_id;
     }
 
-    bool isDeviceCompatible()
+    inline bool isDeviceCompatible(int device_id = -1)
     {
-        int device_id = getDevice();
+        if (device_id < 0)
+            device_id = getDevice();
+
         if (device_id < 0)
             return false;
 
@@ -76,9 +78,11 @@ namespace cv { namespace dnn { namespace cuda4dnn {
         return false;
     }
 
-    bool doesDeviceSupportFP16()
+    inline bool doesDeviceSupportFP16(int device_id = -1)
    {
-        int device_id = getDevice();
+        if (device_id < 0)
+            device_id = getDevice();
+
         if (device_id < 0)
             return false;
 
@@ -87,9 +91,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
         CUDA4DNN_CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id));
 
         int version = major * 10 + minor;
-        if (version < 53)
-            return false;
-        return true;
+        return (version >= 53);
     }
 
 }}} /* namespace cv::dnn::cuda4dnn */
```
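For readers less familiar with the CUDA side, here is a minimal standalone sketch (not part of the patch) of the run-time query these helpers rely on. FP16 inference requires a device with compute capability 5.3 or newer, which is exactly the `version >= 53` check above; the 5.3 threshold and the attribute names below come from the CUDA runtime API.

```cpp
// Standalone sketch: enumerate CUDA devices and report FP16 eligibility.
// Assumes the CUDA runtime headers/library are available.
#include <cuda_runtime.h>
#include <cstdio>

int main()
{
    int count = 0;
    if (cudaGetDeviceCount(&count) != cudaSuccess)
        return 1;
    for (int dev = 0; dev < count; ++dev)
    {
        int major = 0, minor = 0;
        cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev);
        cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, dev);
        const int version = major * 10 + minor;  // same encoding as cuda4dnn
        std::printf("device %d: compute capability %d.%d -> FP16 %s\n",
                    dev, major, minor,
                    version >= 53 ? "supported" : "not supported");
    }
    return 0;
}
```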
```diff
@@ -10,6 +10,10 @@
 #include "backend.hpp"
 #include "factory.hpp"
 
+#ifdef HAVE_CUDA
+#include "cuda4dnn/init.hpp"
+#endif
+
 namespace cv {
 namespace dnn {
 CV__DNN_INLINE_NS_BEGIN
@@ -242,6 +246,16 @@ void Net::Impl::setPreferableTarget(int targetId)
 #endif
         }
 
+        if (IS_DNN_CUDA_TARGET(targetId))
+        {
+            preferableTarget = DNN_TARGET_CPU;
+#ifdef HAVE_CUDA
+            if (cuda4dnn::doesDeviceSupportFP16() && targetId == DNN_TARGET_CUDA_FP16)
+                preferableTarget = DNN_TARGET_CUDA_FP16;
+            else
+                preferableTarget = DNN_TARGET_CUDA;
+#endif
+        }
 #if !defined(__arm64__) || !__arm64__
         if (targetId == DNN_TARGET_CPU_FP16)
         {
```
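From the caller's side, the effect of this hunk is that requesting the FP16 target on a GPU without FP16 support no longer leaves the net on an invalid target; the request is quietly downgraded to the plain CUDA target. A rough usage sketch (the model file name is only a placeholder):

```cpp
#include <opencv2/dnn.hpp>

int main()
{
    // "model.onnx" is a placeholder; use any network the DNN module can read.
    cv::dnn::Net net = cv::dnn::readNet("model.onnx");
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
    // On a pre-FP16 GPU (compute capability < 5.3, e.g. a GTX 960M) this request
    // is now downgraded to DNN_TARGET_CUDA instead of keeping an unusable target.
    net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA_FP16);
    return 0;
}
```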
```diff
@@ -18,6 +18,10 @@
 #include "backend.hpp"
 #include "factory.hpp"
 
+#ifdef HAVE_CUDA
+#include "cuda4dnn/init.hpp"
+#endif
+
 namespace cv {
 namespace dnn {
 CV__DNN_INLINE_NS_BEGIN
@@ -118,10 +122,28 @@ private:
 #endif
 
 #ifdef HAVE_CUDA
-        if (haveCUDA())
+        cuda4dnn::checkVersions();
+
+        bool hasCudaCompatible = false;
+        bool hasCudaFP16 = false;
+        for (int i = 0; i < cuda4dnn::getDeviceCount(); i++)
+        {
+            if (cuda4dnn::isDeviceCompatible(i))
+            {
+                hasCudaCompatible = true;
+                if (cuda4dnn::doesDeviceSupportFP16(i))
+                {
+                    hasCudaFP16 = true;
+                    break; // we already have all we need here
+                }
+            }
+        }
+
+        if (hasCudaCompatible)
         {
             backends.push_back(std::make_pair(DNN_BACKEND_CUDA, DNN_TARGET_CUDA));
-            backends.push_back(std::make_pair(DNN_BACKEND_CUDA, DNN_TARGET_CUDA_FP16));
+            if (hasCudaFP16)
+                backends.push_back(std::make_pair(DNN_BACKEND_CUDA, DNN_TARGET_CUDA_FP16));
         }
 #endif
 
```
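With this hunk, `DNN_TARGET_CUDA_FP16` is only registered when at least one compatible device actually supports FP16, so the advertised targets match the hardware. A small sketch of how the result can be inspected, assuming a CUDA-enabled build of OpenCV:

```cpp
#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    // Lists the targets registered for the CUDA backend; after this patch,
    // DNN_TARGET_CUDA_FP16 is absent when no device has compute capability >= 5.3.
    const std::vector<cv::dnn::Target> targets =
        cv::dnn::getAvailableTargets(cv::dnn::DNN_BACKEND_CUDA);
    for (cv::dnn::Target t : targets)
        std::cout << "CUDA backend target id: " << static_cast<int>(t) << std::endl;
    return 0;
}
```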
```diff
@@ -211,7 +211,7 @@ public:
         if ((!l->supportBackend(backend) || l->preferableTarget != target) && !fused)
         {
             hasFallbacks = true;
-            std::cout << "FALLBACK: Layer [" << l->type << "]:[" << l->name << "] is expected to has backend implementation" << endl;
+            std::cout << "FALLBACK: Layer [" << l->type << "]:[" << l->name << "] is expected to have backend implementation" << endl;
         }
     }
     if (hasFallbacks && raiseError)
@@ -1016,7 +1016,7 @@ public:
         if ((!l->supportBackend(backend) || l->preferableTarget != target) && !fused)
         {
             hasFallbacks = true;
-            std::cout << "FALLBACK: Layer [" << l->type << "]:[" << l->name << "] is expected to has backend implementation" << endl;
+            std::cout << "FALLBACK: Layer [" << l->type << "]:[" << l->name << "] is expected to have backend implementation" << endl;
         }
     }
     return hasFallbacks;
```