Merge pull request #24060 from TolyaTalamanov:at/advanced-device-selection-onnxrt-directml

G-API: Advanced device selection for ONNX DirectML Execution Provider #24060

### Overview
Extend `cv::gapi::onnx::ep::DirectML` to accept `adapter name` as `ctor` parameter in order to select execution device by `name`.
E.g:
```
pp.cfgAddExecutionProvider(cv::gapi::onnx::ep::DirectML("Intel Graphics"));
```

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [ ] I agree to contribute to the project under Apache 2 License.
- [ ] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [ ] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [ ] The feature is well documented and sample code can be built with the project CMake
pull/24562/head
Anatoliy Talamanov 1 year ago committed by GitHub
parent 024dfd54af
commit 0e151e3c88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      CMakeLists.txt
  2. 13
      cmake/OpenCVDetectDirectML.cmake
  3. 38
      cmake/checks/directml.cpp
  4. 4
      modules/gapi/CMakeLists.txt
  5. 11
      modules/gapi/include/opencv2/gapi/infer/onnx.hpp
  6. 243
      modules/gapi/src/backends/onnx/dml_ep.cpp

@ -411,6 +411,9 @@ OCV_OPTION(WITH_OPENCLAMDBLAS "Include AMD OpenCL BLAS library support" ON
OCV_OPTION(WITH_DIRECTX "Include DirectX support" ON OCV_OPTION(WITH_DIRECTX "Include DirectX support" ON
VISIBLE_IF WIN32 AND NOT WINRT VISIBLE_IF WIN32 AND NOT WINRT
VERIFY HAVE_DIRECTX) VERIFY HAVE_DIRECTX)
OCV_OPTION(WITH_DIRECTML "Include DirectML support" ON
VISIBLE_IF WIN32 AND NOT WINRT
VERIFY HAVE_DIRECTML)
OCV_OPTION(WITH_OPENCL_D3D11_NV "Include NVIDIA OpenCL D3D11 support" WITH_DIRECTX OCV_OPTION(WITH_OPENCL_D3D11_NV "Include NVIDIA OpenCL D3D11 support" WITH_DIRECTX
VISIBLE_IF WIN32 AND NOT WINRT VISIBLE_IF WIN32 AND NOT WINRT
VERIFY HAVE_OPENCL_D3D11_NV) VERIFY HAVE_OPENCL_D3D11_NV)
@ -848,6 +851,10 @@ endif()
if(WITH_DIRECTX) if(WITH_DIRECTX)
include(cmake/OpenCVDetectDirectX.cmake) include(cmake/OpenCVDetectDirectX.cmake)
endif() endif()
# --- DirectML ---
if(WITH_DIRECTML)
include(cmake/OpenCVDetectDirectML.cmake)
endif()
if(WITH_VTK) if(WITH_VTK)
include(cmake/OpenCVDetectVTK.cmake) include(cmake/OpenCVDetectVTK.cmake)

@ -0,0 +1,13 @@
if(WIN32)
try_compile(__VALID_DIRECTML
"${OpenCV_BINARY_DIR}"
"${OpenCV_SOURCE_DIR}/cmake/checks/directml.cpp"
LINK_LIBRARIES d3d12 dxcore directml
OUTPUT_VARIABLE TRY_OUT
)
if(NOT __VALID_DIRECTML)
message(STATUS "No support for DirectML (d3d12, dxcore, directml libs are required)")
return()
endif()
set(HAVE_DIRECTML ON)
endif()

@ -0,0 +1,38 @@
#include <initguid.h>
#include <d3d11.h>
#include <dxgi1_2.h>
#include <dxgi1_4.h>
#include <dxgi.h>
#include <dxcore.h>
#include <dxcore_interface.h>
#include <d3d12.h>
#include <directml.h>
int main(int /*argc*/, char** /*argv*/)
{
IDXCoreAdapterFactory* factory;
DXCoreCreateAdapterFactory(__uuidof(IDXCoreAdapterFactory), (void**)&factory);
IDXCoreAdapterList* adapterList;
const GUID dxGUIDs[] = { DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
factory->CreateAdapterList(ARRAYSIZE(dxGUIDs), dxGUIDs, __uuidof(IDXCoreAdapterList), (void**)&adapterList);
IDXCoreAdapter* adapter;
adapterList->GetAdapter(0u, __uuidof(IDXCoreAdapter), (void**)&adapter);
D3D_FEATURE_LEVEL d3dFeatureLevel = D3D_FEATURE_LEVEL_1_0_CORE;
ID3D12Device* d3d12Device = NULL;
D3D12CreateDevice((IUnknown*)adapter, d3dFeatureLevel, __uuidof(ID3D11Device), (void**)&d3d12Device);
D3D12_COMMAND_LIST_TYPE commandQueueType = D3D12_COMMAND_LIST_TYPE_COMPUTE;
ID3D12CommandQueue* cmdQueue;
D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {};
commandQueueDesc.Type = commandQueueType;
d3d12Device->CreateCommandQueue(&commandQueueDesc, __uuidof(ID3D12CommandQueue), (void**)&cmdQueue);
IDMLDevice* dmlDevice;
DMLCreateDevice(d3d12Device, DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dmlDevice));
return 0;
}

@ -367,6 +367,10 @@ if(WIN32)
ocv_target_link_libraries(${the_module} PRIVATE wsock32 ws2_32) ocv_target_link_libraries(${the_module} PRIVATE wsock32 ws2_32)
endif() endif()
if(HAVE_DIRECTML)
ocv_target_compile_definitions(${the_module} PRIVATE HAVE_DIRECTML=1)
endif()
if(HAVE_ONNX) if(HAVE_ONNX)
ocv_target_link_libraries(${the_module} PRIVATE ${ONNX_LIBRARY}) ocv_target_link_libraries(${the_module} PRIVATE ${ONNX_LIBRARY})
ocv_target_compile_definitions(${the_module} PRIVATE HAVE_ONNX=1) ocv_target_compile_definitions(${the_module} PRIVATE HAVE_ONNX=1)

@ -189,7 +189,16 @@ public:
GAPI_WRAP GAPI_WRAP
explicit DirectML(const int device_id) : ddesc(device_id) { }; explicit DirectML(const int device_id) : ddesc(device_id) { };
using DeviceDesc = cv::util::variant<int>; /** @brief Class constructor.
Constructs DirectML parameters based on adapter name.
@param adapter_name Target adapter_name to use.
*/
GAPI_WRAP
explicit DirectML(const std::string &adapter_name) : ddesc(adapter_name) { };
using DeviceDesc = cv::util::variant<int, std::string>;
DeviceDesc ddesc; DeviceDesc ddesc;
}; };

@ -13,13 +13,240 @@
#ifdef HAVE_ONNX_DML #ifdef HAVE_ONNX_DML
#include "../providers/dml/dml_provider_factory.h" #include "../providers/dml/dml_provider_factory.h"
#ifdef HAVE_DIRECTML
#undef WINVER
#define WINVER 0x0A00
#undef _WIN32_WINNT
#define _WIN32_WINNT 0x0A00
#include <initguid.h>
#include <d3d11.h>
#include <dxgi1_2.h>
#include <dxgi1_4.h>
#include <dxgi.h>
#include <dxcore.h>
#include <dxcore_interface.h>
#include <d3d12.h>
#include <directml.h>
#pragma comment (lib, "d3d11.lib")
#pragma comment (lib, "d3d12.lib")
#pragma comment (lib, "dxgi.lib")
#pragma comment (lib, "dxcore.lib")
#pragma comment (lib, "directml.lib")
#endif // HAVE_DIRECTML
static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions *session_options,
const std::string &adapter_name);
void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions *session_options, void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions *session_options,
const cv::gapi::onnx::ep::DirectML &dml_ep) { const cv::gapi::onnx::ep::DirectML &dml_ep) {
namespace ep = cv::gapi::onnx::ep; namespace ep = cv::gapi::onnx::ep;
GAPI_Assert(cv::util::holds_alternative<int>(dml_ep.ddesc)); switch (dml_ep.ddesc.index()) {
const int device_id = cv::util::get<int>(dml_ep.ddesc); case ep::DirectML::DeviceDesc::index_of<int>(): {
const int device_id = cv::util::get<int>(dml_ep.ddesc);
try {
OrtSessionOptionsAppendExecutionProvider_DML(*session_options, device_id);
} catch (const std::exception &e) {
std::stringstream ss;
ss << "ONNX Backend: Failed to enable DirectML"
<< " Execution Provider: " << e.what();
cv::util::throw_error(std::runtime_error(ss.str()));
}
break;
}
case ep::DirectML::DeviceDesc::index_of<std::string>(): {
const std::string adapter_name = cv::util::get<std::string>(dml_ep.ddesc);
addDMLExecutionProviderWithAdapterName(session_options, adapter_name);
break;
}
default:
GAPI_Assert(false && "Invalid DirectML device description");
}
}
#ifdef HAVE_DIRECTML
#define THROW_IF_FAILED(hr, error_msg) \
{ \
if ((hr) != S_OK) \
throw std::runtime_error(error_msg); \
}
template <typename T>
void release(T *ptr) {
if (ptr) {
ptr->Release();
}
}
template <typename T>
using ComPtrGuard = std::unique_ptr<T, decltype(&release<T>)>;
template <typename T>
ComPtrGuard<T> make_com_ptr(T *ptr) {
return ComPtrGuard<T>{ptr, &release<T>};
}
struct AdapterDesc {
ComPtrGuard<IDXCoreAdapter> ptr;
std::string description;
};
static std::vector<AdapterDesc> getAvailableAdapters() {
std::vector<AdapterDesc> all_adapters;
IDXCoreAdapterFactory* factory_ptr;
GAPI_LOG_DEBUG(nullptr, "Create IDXCoreAdapterFactory");
THROW_IF_FAILED(
DXCoreCreateAdapterFactory(
__uuidof(IDXCoreAdapterFactory), (void**)&factory_ptr),
"Failed to create IDXCoreAdapterFactory");
auto factory = make_com_ptr<IDXCoreAdapterFactory>(factory_ptr);
IDXCoreAdapterList* adapter_list_ptr;
const GUID dxGUIDs[] = { DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
GAPI_LOG_DEBUG(nullptr, "CreateAdapterList");
THROW_IF_FAILED(
factory->CreateAdapterList(
ARRAYSIZE(dxGUIDs), dxGUIDs, __uuidof(IDXCoreAdapterList), (void**)&adapter_list_ptr),
"Failed to create IDXCoreAdapterList");
auto adapter_list = make_com_ptr<IDXCoreAdapterList>(adapter_list_ptr);
for (UINT i = 0; i < adapter_list->GetAdapterCount(); i++)
{
IDXCoreAdapter* curr_adapter_ptr;
GAPI_LOG_DEBUG(nullptr, "GetAdapter");
THROW_IF_FAILED(
adapter_list->GetAdapter(
i, __uuidof(IDXCoreAdapter), (void**)&curr_adapter_ptr),
"Failed to obtain IDXCoreAdapter"
);
auto curr_adapter = make_com_ptr<IDXCoreAdapter>(curr_adapter_ptr);
bool is_hardware = false;
curr_adapter->GetProperty(DXCoreAdapterProperty::IsHardware, &is_hardware);
// NB: Filter out if not hardware adapter.
if (!is_hardware) {
continue;
}
size_t desc_size = 0u;
char description[256];
curr_adapter->GetPropertySize(DXCoreAdapterProperty::DriverDescription, &desc_size);
curr_adapter->GetProperty(DXCoreAdapterProperty::DriverDescription, desc_size, &description);
all_adapters.push_back(AdapterDesc{std::move(curr_adapter), description});
}
return all_adapters;
};
struct DMLDeviceInfo {
ComPtrGuard<IDMLDevice> device;
ComPtrGuard<ID3D12CommandQueue> cmd_queue;
};
static DMLDeviceInfo createDMLInfo(IDXCoreAdapter* adapter) {
auto pAdapter = make_com_ptr<IUnknown>(adapter);
D3D_FEATURE_LEVEL d3dFeatureLevel = D3D_FEATURE_LEVEL_1_0_CORE;
if (adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS))
{
GAPI_LOG_INFO(nullptr, "DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS is supported");
d3dFeatureLevel = D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_11_0;
IDXGIFactory4* dxgiFactory4;
GAPI_LOG_DEBUG(nullptr, "CreateDXGIFactory2");
THROW_IF_FAILED(
CreateDXGIFactory2(0, __uuidof(IDXGIFactory4), (void**)&dxgiFactory4),
"Failed to create IDXGIFactory4"
);
// If DXGI factory creation was successful then get the IDXGIAdapter from the LUID
// acquired from the selectedAdapter
LUID adapterLuid;
IDXGIAdapter* spDxgiAdapter;
GAPI_LOG_DEBUG(nullptr, "Get DXCoreAdapterProperty::InstanceLuid property");
THROW_IF_FAILED(
adapter->GetProperty(DXCoreAdapterProperty::InstanceLuid, &adapterLuid),
"Failed to get DXCoreAdapterProperty::InstanceLuid property");
GAPI_LOG_DEBUG(nullptr, "Get IDXGIAdapter by luid");
THROW_IF_FAILED(
dxgiFactory4->EnumAdapterByLuid(
adapterLuid, __uuidof(IDXGIAdapter), (void**)&spDxgiAdapter),
"Failed to get IDXGIAdapter");
pAdapter = make_com_ptr<IUnknown>(spDxgiAdapter);
} else {
GAPI_LOG_INFO(nullptr, "DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS isn't supported");
}
ID3D12Device* d3d12_device_ptr;
GAPI_LOG_DEBUG(nullptr, "Create D3D12Device");
THROW_IF_FAILED(
D3D12CreateDevice(
pAdapter.get(), d3dFeatureLevel, __uuidof(ID3D12Device), (void**)&d3d12_device_ptr),
"Failed to create ID3D12Device");
auto d3d12_device = make_com_ptr<ID3D12Device>(d3d12_device_ptr);
D3D12_COMMAND_LIST_TYPE commandQueueType = D3D12_COMMAND_LIST_TYPE_COMPUTE;
ID3D12CommandQueue* cmd_queue_ptr;
D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {};
commandQueueDesc.Type = commandQueueType;
GAPI_LOG_DEBUG(nullptr, "Create D3D12CommandQueue");
THROW_IF_FAILED(
d3d12_device->CreateCommandQueue(
&commandQueueDesc, __uuidof(ID3D12CommandQueue), (void**)&cmd_queue_ptr),
"Failed to create D3D12CommandQueue"
);
GAPI_LOG_DEBUG(nullptr, "Create D3D12CommandQueue - successful");
auto cmd_queue = make_com_ptr<ID3D12CommandQueue>(cmd_queue_ptr);
IDMLDevice* dml_device_ptr;
GAPI_LOG_DEBUG(nullptr, "Create DirectML device");
THROW_IF_FAILED(
DMLCreateDevice(
d3d12_device.get(), DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dml_device_ptr)),
"Failed to create IDMLDevice");
GAPI_LOG_DEBUG(nullptr, "Create DirectML device - successful");
auto dml_device = make_com_ptr<IDMLDevice>(dml_device_ptr);
return {std::move(dml_device), std::move(cmd_queue)};
};
static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions *session_options,
const std::string &adapter_name) {
auto all_adapters = getAvailableAdapters();
std::vector<AdapterDesc> selected_adapters;
std::stringstream log_msg;
for (auto&& adapter : all_adapters) {
log_msg << adapter.description << std::endl;
if (std::strstr(adapter.description.c_str(), adapter_name.c_str())) {
selected_adapters.emplace_back(std::move(adapter));
}
}
GAPI_LOG_INFO(NULL, "\nAvailable DirectML adapters:\n" << log_msg.str());
if (selected_adapters.empty()) {
std::stringstream error_msg;
error_msg << "ONNX Backend: No DirectML adapters found match to \"" << adapter_name << "\"";
cv::util::throw_error(std::runtime_error(error_msg.str()));
} else if (selected_adapters.size() > 1) {
std::stringstream error_msg;
error_msg << "ONNX Backend: More than one adapter matches to \"" << adapter_name << "\":\n";
for (const auto &selected_adapter : selected_adapters) {
error_msg << selected_adapter.description << "\n";
}
cv::util::throw_error(std::runtime_error(error_msg.str()));
}
GAPI_LOG_INFO(NULL, "Selected device: " << selected_adapters.front().description);
auto dml = createDMLInfo(selected_adapters.front().ptr.get());
try { try {
OrtSessionOptionsAppendExecutionProvider_DML(*session_options, device_id); OrtSessionOptionsAppendExecutionProviderEx_DML(
*session_options, dml.device.release(), dml.cmd_queue.release());
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::stringstream ss; std::stringstream ss;
ss << "ONNX Backend: Failed to enable DirectML" ss << "ONNX Backend: Failed to enable DirectML"
@ -28,6 +255,16 @@ void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions *session_optio
} }
} }
#else // HAVE_DIRECTML
static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions*, const std::string&) {
std::stringstream ss;
ss << "ONNX Backend: Failed to add DirectML Execution Provider with adapter name."
<< " DirectML support is required.";
cv::util::throw_error(std::runtime_error(ss.str()));
}
#endif // HAVE_DIRECTML
#else // HAVE_ONNX_DML #else // HAVE_ONNX_DML
void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions*, void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions*,

Loading…
Cancel
Save