mirror of
https://github.com/opencv/opencv.git
synced 2024-11-27 20:50:25 +08:00
Merge pull request #24060 from TolyaTalamanov:at/advanced-device-selection-onnxrt-directml
G-API: Advanced device selection for ONNX DirectML Execution Provider #24060 ### Overview Extend `cv::gapi::onnx::ep::DirectML` to accept `adapter name` as `ctor` parameter in order to select execution device by `name`. E.g: ``` pp.cfgAddExecutionProvider(cv::gapi::onnx::ep::DirectML("Intel Graphics")); ``` ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [ ] I agree to contribute to the project under Apache 2 License. - [ ] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [ ] The PR is proposed to the proper branch - [ ] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [ ] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
024dfd54af
commit
0e151e3c88
@ -411,6 +411,9 @@ OCV_OPTION(WITH_OPENCLAMDBLAS "Include AMD OpenCL BLAS library support" ON
|
||||
OCV_OPTION(WITH_DIRECTX "Include DirectX support" ON
|
||||
VISIBLE_IF WIN32 AND NOT WINRT
|
||||
VERIFY HAVE_DIRECTX)
|
||||
OCV_OPTION(WITH_DIRECTML "Include DirectML support" ON
|
||||
VISIBLE_IF WIN32 AND NOT WINRT
|
||||
VERIFY HAVE_DIRECTML)
|
||||
OCV_OPTION(WITH_OPENCL_D3D11_NV "Include NVIDIA OpenCL D3D11 support" WITH_DIRECTX
|
||||
VISIBLE_IF WIN32 AND NOT WINRT
|
||||
VERIFY HAVE_OPENCL_D3D11_NV)
|
||||
@ -848,6 +851,10 @@ endif()
|
||||
if(WITH_DIRECTX)
|
||||
include(cmake/OpenCVDetectDirectX.cmake)
|
||||
endif()
|
||||
# --- DirectML ---
|
||||
if(WITH_DIRECTML)
|
||||
include(cmake/OpenCVDetectDirectML.cmake)
|
||||
endif()
|
||||
|
||||
if(WITH_VTK)
|
||||
include(cmake/OpenCVDetectVTK.cmake)
|
||||
|
13
cmake/OpenCVDetectDirectML.cmake
Normal file
13
cmake/OpenCVDetectDirectML.cmake
Normal file
@ -0,0 +1,13 @@
|
||||
if(WIN32)
|
||||
try_compile(__VALID_DIRECTML
|
||||
"${OpenCV_BINARY_DIR}"
|
||||
"${OpenCV_SOURCE_DIR}/cmake/checks/directml.cpp"
|
||||
LINK_LIBRARIES d3d12 dxcore directml
|
||||
OUTPUT_VARIABLE TRY_OUT
|
||||
)
|
||||
if(NOT __VALID_DIRECTML)
|
||||
message(STATUS "No support for DirectML (d3d12, dxcore, directml libs are required)")
|
||||
return()
|
||||
endif()
|
||||
set(HAVE_DIRECTML ON)
|
||||
endif()
|
38
cmake/checks/directml.cpp
Normal file
38
cmake/checks/directml.cpp
Normal file
@ -0,0 +1,38 @@
|
||||
#include <initguid.h>
|
||||
|
||||
#include <d3d11.h>
|
||||
#include <dxgi1_2.h>
|
||||
#include <dxgi1_4.h>
|
||||
#include <dxgi.h>
|
||||
#include <dxcore.h>
|
||||
#include <dxcore_interface.h>
|
||||
#include <d3d12.h>
|
||||
#include <directml.h>
|
||||
|
||||
int main(int /*argc*/, char** /*argv*/)
|
||||
{
|
||||
IDXCoreAdapterFactory* factory;
|
||||
DXCoreCreateAdapterFactory(__uuidof(IDXCoreAdapterFactory), (void**)&factory);
|
||||
|
||||
IDXCoreAdapterList* adapterList;
|
||||
const GUID dxGUIDs[] = { DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
|
||||
factory->CreateAdapterList(ARRAYSIZE(dxGUIDs), dxGUIDs, __uuidof(IDXCoreAdapterList), (void**)&adapterList);
|
||||
|
||||
IDXCoreAdapter* adapter;
|
||||
adapterList->GetAdapter(0u, __uuidof(IDXCoreAdapter), (void**)&adapter);
|
||||
|
||||
D3D_FEATURE_LEVEL d3dFeatureLevel = D3D_FEATURE_LEVEL_1_0_CORE;
|
||||
ID3D12Device* d3d12Device = NULL;
|
||||
D3D12CreateDevice((IUnknown*)adapter, d3dFeatureLevel, __uuidof(ID3D11Device), (void**)&d3d12Device);
|
||||
|
||||
D3D12_COMMAND_LIST_TYPE commandQueueType = D3D12_COMMAND_LIST_TYPE_COMPUTE;
|
||||
ID3D12CommandQueue* cmdQueue;
|
||||
D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {};
|
||||
commandQueueDesc.Type = commandQueueType;
|
||||
|
||||
d3d12Device->CreateCommandQueue(&commandQueueDesc, __uuidof(ID3D12CommandQueue), (void**)&cmdQueue);
|
||||
IDMLDevice* dmlDevice;
|
||||
DMLCreateDevice(d3d12Device, DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dmlDevice));
|
||||
|
||||
return 0;
|
||||
}
|
@ -367,6 +367,10 @@ if(WIN32)
|
||||
ocv_target_link_libraries(${the_module} PRIVATE wsock32 ws2_32)
|
||||
endif()
|
||||
|
||||
if(HAVE_DIRECTML)
|
||||
ocv_target_compile_definitions(${the_module} PRIVATE HAVE_DIRECTML=1)
|
||||
endif()
|
||||
|
||||
if(HAVE_ONNX)
|
||||
ocv_target_link_libraries(${the_module} PRIVATE ${ONNX_LIBRARY})
|
||||
ocv_target_compile_definitions(${the_module} PRIVATE HAVE_ONNX=1)
|
||||
|
@ -189,7 +189,16 @@ public:
|
||||
GAPI_WRAP
|
||||
explicit DirectML(const int device_id) : ddesc(device_id) { };
|
||||
|
||||
using DeviceDesc = cv::util::variant<int>;
|
||||
/** @brief Class constructor.
|
||||
|
||||
Constructs DirectML parameters based on adapter name.
|
||||
|
||||
@param adapter_name Target adapter_name to use.
|
||||
*/
|
||||
GAPI_WRAP
|
||||
explicit DirectML(const std::string &adapter_name) : ddesc(adapter_name) { };
|
||||
|
||||
using DeviceDesc = cv::util::variant<int, std::string>;
|
||||
DeviceDesc ddesc;
|
||||
};
|
||||
|
||||
|
@ -13,10 +13,40 @@
|
||||
#ifdef HAVE_ONNX_DML
|
||||
#include "../providers/dml/dml_provider_factory.h"
|
||||
|
||||
#ifdef HAVE_DIRECTML
|
||||
|
||||
#undef WINVER
|
||||
#define WINVER 0x0A00
|
||||
#undef _WIN32_WINNT
|
||||
#define _WIN32_WINNT 0x0A00
|
||||
|
||||
#include <initguid.h>
|
||||
|
||||
#include <d3d11.h>
|
||||
#include <dxgi1_2.h>
|
||||
#include <dxgi1_4.h>
|
||||
#include <dxgi.h>
|
||||
#include <dxcore.h>
|
||||
#include <dxcore_interface.h>
|
||||
#include <d3d12.h>
|
||||
#include <directml.h>
|
||||
|
||||
#pragma comment (lib, "d3d11.lib")
|
||||
#pragma comment (lib, "d3d12.lib")
|
||||
#pragma comment (lib, "dxgi.lib")
|
||||
#pragma comment (lib, "dxcore.lib")
|
||||
#pragma comment (lib, "directml.lib")
|
||||
|
||||
#endif // HAVE_DIRECTML
|
||||
|
||||
static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions *session_options,
|
||||
const std::string &adapter_name);
|
||||
|
||||
void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions *session_options,
|
||||
const cv::gapi::onnx::ep::DirectML &dml_ep) {
|
||||
namespace ep = cv::gapi::onnx::ep;
|
||||
GAPI_Assert(cv::util::holds_alternative<int>(dml_ep.ddesc));
|
||||
switch (dml_ep.ddesc.index()) {
|
||||
case ep::DirectML::DeviceDesc::index_of<int>(): {
|
||||
const int device_id = cv::util::get<int>(dml_ep.ddesc);
|
||||
try {
|
||||
OrtSessionOptionsAppendExecutionProvider_DML(*session_options, device_id);
|
||||
@ -26,8 +56,215 @@ void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions *session_optio
|
||||
<< " Execution Provider: " << e.what();
|
||||
cv::util::throw_error(std::runtime_error(ss.str()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ep::DirectML::DeviceDesc::index_of<std::string>(): {
|
||||
const std::string adapter_name = cv::util::get<std::string>(dml_ep.ddesc);
|
||||
addDMLExecutionProviderWithAdapterName(session_options, adapter_name);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GAPI_Assert(false && "Invalid DirectML device description");
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef HAVE_DIRECTML
|
||||
|
||||
#define THROW_IF_FAILED(hr, error_msg) \
|
||||
{ \
|
||||
if ((hr) != S_OK) \
|
||||
throw std::runtime_error(error_msg); \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void release(T *ptr) {
|
||||
if (ptr) {
|
||||
ptr->Release();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using ComPtrGuard = std::unique_ptr<T, decltype(&release<T>)>;
|
||||
|
||||
template <typename T>
|
||||
ComPtrGuard<T> make_com_ptr(T *ptr) {
|
||||
return ComPtrGuard<T>{ptr, &release<T>};
|
||||
}
|
||||
|
||||
struct AdapterDesc {
|
||||
ComPtrGuard<IDXCoreAdapter> ptr;
|
||||
std::string description;
|
||||
};
|
||||
|
||||
static std::vector<AdapterDesc> getAvailableAdapters() {
|
||||
std::vector<AdapterDesc> all_adapters;
|
||||
|
||||
IDXCoreAdapterFactory* factory_ptr;
|
||||
GAPI_LOG_DEBUG(nullptr, "Create IDXCoreAdapterFactory");
|
||||
THROW_IF_FAILED(
|
||||
DXCoreCreateAdapterFactory(
|
||||
__uuidof(IDXCoreAdapterFactory), (void**)&factory_ptr),
|
||||
"Failed to create IDXCoreAdapterFactory");
|
||||
auto factory = make_com_ptr<IDXCoreAdapterFactory>(factory_ptr);
|
||||
|
||||
IDXCoreAdapterList* adapter_list_ptr;
|
||||
const GUID dxGUIDs[] = { DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
|
||||
GAPI_LOG_DEBUG(nullptr, "CreateAdapterList");
|
||||
THROW_IF_FAILED(
|
||||
factory->CreateAdapterList(
|
||||
ARRAYSIZE(dxGUIDs), dxGUIDs, __uuidof(IDXCoreAdapterList), (void**)&adapter_list_ptr),
|
||||
"Failed to create IDXCoreAdapterList");
|
||||
auto adapter_list = make_com_ptr<IDXCoreAdapterList>(adapter_list_ptr);
|
||||
|
||||
for (UINT i = 0; i < adapter_list->GetAdapterCount(); i++)
|
||||
{
|
||||
IDXCoreAdapter* curr_adapter_ptr;
|
||||
GAPI_LOG_DEBUG(nullptr, "GetAdapter");
|
||||
THROW_IF_FAILED(
|
||||
adapter_list->GetAdapter(
|
||||
i, __uuidof(IDXCoreAdapter), (void**)&curr_adapter_ptr),
|
||||
"Failed to obtain IDXCoreAdapter"
|
||||
);
|
||||
auto curr_adapter = make_com_ptr<IDXCoreAdapter>(curr_adapter_ptr);
|
||||
|
||||
bool is_hardware = false;
|
||||
curr_adapter->GetProperty(DXCoreAdapterProperty::IsHardware, &is_hardware);
|
||||
// NB: Filter out if not hardware adapter.
|
||||
if (!is_hardware) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t desc_size = 0u;
|
||||
char description[256];
|
||||
curr_adapter->GetPropertySize(DXCoreAdapterProperty::DriverDescription, &desc_size);
|
||||
curr_adapter->GetProperty(DXCoreAdapterProperty::DriverDescription, desc_size, &description);
|
||||
all_adapters.push_back(AdapterDesc{std::move(curr_adapter), description});
|
||||
}
|
||||
return all_adapters;
|
||||
};
|
||||
|
||||
struct DMLDeviceInfo {
|
||||
ComPtrGuard<IDMLDevice> device;
|
||||
ComPtrGuard<ID3D12CommandQueue> cmd_queue;
|
||||
};
|
||||
|
||||
static DMLDeviceInfo createDMLInfo(IDXCoreAdapter* adapter) {
|
||||
auto pAdapter = make_com_ptr<IUnknown>(adapter);
|
||||
D3D_FEATURE_LEVEL d3dFeatureLevel = D3D_FEATURE_LEVEL_1_0_CORE;
|
||||
if (adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS))
|
||||
{
|
||||
GAPI_LOG_INFO(nullptr, "DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS is supported");
|
||||
d3dFeatureLevel = D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_11_0;
|
||||
|
||||
IDXGIFactory4* dxgiFactory4;
|
||||
GAPI_LOG_DEBUG(nullptr, "CreateDXGIFactory2");
|
||||
THROW_IF_FAILED(
|
||||
CreateDXGIFactory2(0, __uuidof(IDXGIFactory4), (void**)&dxgiFactory4),
|
||||
"Failed to create IDXGIFactory4"
|
||||
);
|
||||
// If DXGI factory creation was successful then get the IDXGIAdapter from the LUID
|
||||
// acquired from the selectedAdapter
|
||||
LUID adapterLuid;
|
||||
IDXGIAdapter* spDxgiAdapter;
|
||||
|
||||
GAPI_LOG_DEBUG(nullptr, "Get DXCoreAdapterProperty::InstanceLuid property");
|
||||
THROW_IF_FAILED(
|
||||
adapter->GetProperty(DXCoreAdapterProperty::InstanceLuid, &adapterLuid),
|
||||
"Failed to get DXCoreAdapterProperty::InstanceLuid property");
|
||||
|
||||
GAPI_LOG_DEBUG(nullptr, "Get IDXGIAdapter by luid");
|
||||
THROW_IF_FAILED(
|
||||
dxgiFactory4->EnumAdapterByLuid(
|
||||
adapterLuid, __uuidof(IDXGIAdapter), (void**)&spDxgiAdapter),
|
||||
"Failed to get IDXGIAdapter");
|
||||
pAdapter = make_com_ptr<IUnknown>(spDxgiAdapter);
|
||||
} else {
|
||||
GAPI_LOG_INFO(nullptr, "DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS isn't supported");
|
||||
}
|
||||
|
||||
ID3D12Device* d3d12_device_ptr;
|
||||
GAPI_LOG_DEBUG(nullptr, "Create D3D12Device");
|
||||
THROW_IF_FAILED(
|
||||
D3D12CreateDevice(
|
||||
pAdapter.get(), d3dFeatureLevel, __uuidof(ID3D12Device), (void**)&d3d12_device_ptr),
|
||||
"Failed to create ID3D12Device");
|
||||
auto d3d12_device = make_com_ptr<ID3D12Device>(d3d12_device_ptr);
|
||||
|
||||
D3D12_COMMAND_LIST_TYPE commandQueueType = D3D12_COMMAND_LIST_TYPE_COMPUTE;
|
||||
ID3D12CommandQueue* cmd_queue_ptr;
|
||||
D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {};
|
||||
commandQueueDesc.Type = commandQueueType;
|
||||
GAPI_LOG_DEBUG(nullptr, "Create D3D12CommandQueue");
|
||||
THROW_IF_FAILED(
|
||||
d3d12_device->CreateCommandQueue(
|
||||
&commandQueueDesc, __uuidof(ID3D12CommandQueue), (void**)&cmd_queue_ptr),
|
||||
"Failed to create D3D12CommandQueue"
|
||||
);
|
||||
GAPI_LOG_DEBUG(nullptr, "Create D3D12CommandQueue - successful");
|
||||
auto cmd_queue = make_com_ptr<ID3D12CommandQueue>(cmd_queue_ptr);
|
||||
|
||||
IDMLDevice* dml_device_ptr;
|
||||
GAPI_LOG_DEBUG(nullptr, "Create DirectML device");
|
||||
THROW_IF_FAILED(
|
||||
DMLCreateDevice(
|
||||
d3d12_device.get(), DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dml_device_ptr)),
|
||||
"Failed to create IDMLDevice");
|
||||
GAPI_LOG_DEBUG(nullptr, "Create DirectML device - successful");
|
||||
auto dml_device = make_com_ptr<IDMLDevice>(dml_device_ptr);
|
||||
|
||||
return {std::move(dml_device), std::move(cmd_queue)};
|
||||
};
|
||||
|
||||
static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions *session_options,
|
||||
const std::string &adapter_name) {
|
||||
auto all_adapters = getAvailableAdapters();
|
||||
|
||||
std::vector<AdapterDesc> selected_adapters;
|
||||
std::stringstream log_msg;
|
||||
for (auto&& adapter : all_adapters) {
|
||||
log_msg << adapter.description << std::endl;
|
||||
if (std::strstr(adapter.description.c_str(), adapter_name.c_str())) {
|
||||
selected_adapters.emplace_back(std::move(adapter));
|
||||
}
|
||||
}
|
||||
GAPI_LOG_INFO(NULL, "\nAvailable DirectML adapters:\n" << log_msg.str());
|
||||
|
||||
if (selected_adapters.empty()) {
|
||||
std::stringstream error_msg;
|
||||
error_msg << "ONNX Backend: No DirectML adapters found match to \"" << adapter_name << "\"";
|
||||
cv::util::throw_error(std::runtime_error(error_msg.str()));
|
||||
} else if (selected_adapters.size() > 1) {
|
||||
std::stringstream error_msg;
|
||||
error_msg << "ONNX Backend: More than one adapter matches to \"" << adapter_name << "\":\n";
|
||||
for (const auto &selected_adapter : selected_adapters) {
|
||||
error_msg << selected_adapter.description << "\n";
|
||||
}
|
||||
cv::util::throw_error(std::runtime_error(error_msg.str()));
|
||||
}
|
||||
|
||||
GAPI_LOG_INFO(NULL, "Selected device: " << selected_adapters.front().description);
|
||||
auto dml = createDMLInfo(selected_adapters.front().ptr.get());
|
||||
try {
|
||||
OrtSessionOptionsAppendExecutionProviderEx_DML(
|
||||
*session_options, dml.device.release(), dml.cmd_queue.release());
|
||||
} catch (const std::exception &e) {
|
||||
std::stringstream ss;
|
||||
ss << "ONNX Backend: Failed to enable DirectML"
|
||||
<< " Execution Provider: " << e.what();
|
||||
cv::util::throw_error(std::runtime_error(ss.str()));
|
||||
}
|
||||
}
|
||||
|
||||
#else // HAVE_DIRECTML
|
||||
|
||||
static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions*, const std::string&) {
|
||||
std::stringstream ss;
|
||||
ss << "ONNX Backend: Failed to add DirectML Execution Provider with adapter name."
|
||||
<< " DirectML support is required.";
|
||||
cv::util::throw_error(std::runtime_error(ss.str()));
|
||||
}
|
||||
|
||||
#endif // HAVE_DIRECTML
|
||||
#else // HAVE_ONNX_DML
|
||||
|
||||
void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions*,
|
||||
|
Loading…
Reference in New Issue
Block a user