diff --git a/CMakeLists.txt b/CMakeLists.txt index 49554c15c6..49c93d2406 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -411,6 +411,9 @@ OCV_OPTION(WITH_OPENCLAMDBLAS "Include AMD OpenCL BLAS library support" ON OCV_OPTION(WITH_DIRECTX "Include DirectX support" ON VISIBLE_IF WIN32 AND NOT WINRT VERIFY HAVE_DIRECTX) +OCV_OPTION(WITH_DIRECTML "Include DirectML support" ON + VISIBLE_IF WIN32 AND NOT WINRT + VERIFY HAVE_DIRECTML) OCV_OPTION(WITH_OPENCL_D3D11_NV "Include NVIDIA OpenCL D3D11 support" WITH_DIRECTX VISIBLE_IF WIN32 AND NOT WINRT VERIFY HAVE_OPENCL_D3D11_NV) @@ -848,6 +851,10 @@ endif() if(WITH_DIRECTX) include(cmake/OpenCVDetectDirectX.cmake) endif() +# --- DirectML --- +if(WITH_DIRECTML) + include(cmake/OpenCVDetectDirectML.cmake) +endif() if(WITH_VTK) include(cmake/OpenCVDetectVTK.cmake) diff --git a/cmake/OpenCVDetectDirectML.cmake b/cmake/OpenCVDetectDirectML.cmake new file mode 100644 index 0000000000..0fc71eca03 --- /dev/null +++ b/cmake/OpenCVDetectDirectML.cmake @@ -0,0 +1,13 @@ +if(WIN32) + try_compile(__VALID_DIRECTML + "${OpenCV_BINARY_DIR}" + "${OpenCV_SOURCE_DIR}/cmake/checks/directml.cpp" + LINK_LIBRARIES d3d12 dxcore directml + OUTPUT_VARIABLE TRY_OUT + ) + if(NOT __VALID_DIRECTML) + message(STATUS "No support for DirectML (d3d12, dxcore, directml libs are required)") + return() + endif() + set(HAVE_DIRECTML ON) +endif() diff --git a/cmake/checks/directml.cpp b/cmake/checks/directml.cpp new file mode 100644 index 0000000000..1cf62b8fad --- /dev/null +++ b/cmake/checks/directml.cpp @@ -0,0 +1,38 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +int main(int /*argc*/, char** /*argv*/) +{ + IDXCoreAdapterFactory* factory; + DXCoreCreateAdapterFactory(__uuidof(IDXCoreAdapterFactory), (void**)&factory); + + IDXCoreAdapterList* adapterList; + const GUID dxGUIDs[] = { DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE }; + factory->CreateAdapterList(ARRAYSIZE(dxGUIDs), dxGUIDs, __uuidof(IDXCoreAdapterList), (void**)&adapterList); + + IDXCoreAdapter* adapter; + adapterList->GetAdapter(0u, __uuidof(IDXCoreAdapter), (void**)&adapter); + + D3D_FEATURE_LEVEL d3dFeatureLevel = D3D_FEATURE_LEVEL_1_0_CORE; + ID3D12Device* d3d12Device = NULL; + D3D12CreateDevice((IUnknown*)adapter, d3dFeatureLevel, __uuidof(ID3D11Device), (void**)&d3d12Device); + + D3D12_COMMAND_LIST_TYPE commandQueueType = D3D12_COMMAND_LIST_TYPE_COMPUTE; + ID3D12CommandQueue* cmdQueue; + D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {}; + commandQueueDesc.Type = commandQueueType; + + d3d12Device->CreateCommandQueue(&commandQueueDesc, __uuidof(ID3D12CommandQueue), (void**)&cmdQueue); + IDMLDevice* dmlDevice; + DMLCreateDevice(d3d12Device, DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dmlDevice)); + + return 0; +} \ No newline at end of file diff --git a/modules/gapi/CMakeLists.txt b/modules/gapi/CMakeLists.txt index 2caeb02ae2..a8714a99f5 100644 --- a/modules/gapi/CMakeLists.txt +++ b/modules/gapi/CMakeLists.txt @@ -367,6 +367,10 @@ if(WIN32) ocv_target_link_libraries(${the_module} PRIVATE wsock32 ws2_32) endif() +if(HAVE_DIRECTML) + ocv_target_compile_definitions(${the_module} PRIVATE HAVE_DIRECTML=1) +endif() + if(HAVE_ONNX) ocv_target_link_libraries(${the_module} PRIVATE ${ONNX_LIBRARY}) ocv_target_compile_definitions(${the_module} PRIVATE HAVE_ONNX=1) diff --git a/modules/gapi/include/opencv2/gapi/infer/onnx.hpp b/modules/gapi/include/opencv2/gapi/infer/onnx.hpp index 4efb750439..ae160ac3e5 100644 --- a/modules/gapi/include/opencv2/gapi/infer/onnx.hpp +++ b/modules/gapi/include/opencv2/gapi/infer/onnx.hpp @@ -189,7 +189,16 @@ public: GAPI_WRAP explicit DirectML(const int device_id) : ddesc(device_id) { }; - using DeviceDesc = cv::util::variant; + /** @brief Class constructor. + + Constructs DirectML parameters based on adapter name. + + @param adapter_name Target adapter_name to use. + */ + GAPI_WRAP + explicit DirectML(const std::string &adapter_name) : ddesc(adapter_name) { }; + + using DeviceDesc = cv::util::variant; DeviceDesc ddesc; }; diff --git a/modules/gapi/src/backends/onnx/dml_ep.cpp b/modules/gapi/src/backends/onnx/dml_ep.cpp index 7f59e1f3d6..671fa2dbcb 100644 --- a/modules/gapi/src/backends/onnx/dml_ep.cpp +++ b/modules/gapi/src/backends/onnx/dml_ep.cpp @@ -13,13 +13,240 @@ #ifdef HAVE_ONNX_DML #include "../providers/dml/dml_provider_factory.h" +#ifdef HAVE_DIRECTML + +#undef WINVER +#define WINVER 0x0A00 +#undef _WIN32_WINNT +#define _WIN32_WINNT 0x0A00 + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#pragma comment (lib, "d3d11.lib") +#pragma comment (lib, "d3d12.lib") +#pragma comment (lib, "dxgi.lib") +#pragma comment (lib, "dxcore.lib") +#pragma comment (lib, "directml.lib") + +#endif // HAVE_DIRECTML + +static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions *session_options, + const std::string &adapter_name); + void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions *session_options, const cv::gapi::onnx::ep::DirectML &dml_ep) { namespace ep = cv::gapi::onnx::ep; - GAPI_Assert(cv::util::holds_alternative(dml_ep.ddesc)); - const int device_id = cv::util::get(dml_ep.ddesc); + switch (dml_ep.ddesc.index()) { + case ep::DirectML::DeviceDesc::index_of(): { + const int device_id = cv::util::get(dml_ep.ddesc); + try { + OrtSessionOptionsAppendExecutionProvider_DML(*session_options, device_id); + } catch (const std::exception &e) { + std::stringstream ss; + ss << "ONNX Backend: Failed to enable DirectML" + << " Execution Provider: " << e.what(); + cv::util::throw_error(std::runtime_error(ss.str())); + } + break; + } + case ep::DirectML::DeviceDesc::index_of(): { + const std::string adapter_name = cv::util::get(dml_ep.ddesc); + addDMLExecutionProviderWithAdapterName(session_options, adapter_name); + break; + } + default: + GAPI_Assert(false && "Invalid DirectML device description"); + } +} + +#ifdef HAVE_DIRECTML + +#define THROW_IF_FAILED(hr, error_msg) \ +{ \ + if ((hr) != S_OK) \ + throw std::runtime_error(error_msg); \ +} + +template +void release(T *ptr) { + if (ptr) { + ptr->Release(); + } +} + +template +using ComPtrGuard = std::unique_ptr)>; + +template +ComPtrGuard make_com_ptr(T *ptr) { + return ComPtrGuard{ptr, &release}; +} + +struct AdapterDesc { + ComPtrGuard ptr; + std::string description; +}; + +static std::vector getAvailableAdapters() { + std::vector all_adapters; + + IDXCoreAdapterFactory* factory_ptr; + GAPI_LOG_DEBUG(nullptr, "Create IDXCoreAdapterFactory"); + THROW_IF_FAILED( + DXCoreCreateAdapterFactory( + __uuidof(IDXCoreAdapterFactory), (void**)&factory_ptr), + "Failed to create IDXCoreAdapterFactory"); + auto factory = make_com_ptr(factory_ptr); + + IDXCoreAdapterList* adapter_list_ptr; + const GUID dxGUIDs[] = { DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE }; + GAPI_LOG_DEBUG(nullptr, "CreateAdapterList"); + THROW_IF_FAILED( + factory->CreateAdapterList( + ARRAYSIZE(dxGUIDs), dxGUIDs, __uuidof(IDXCoreAdapterList), (void**)&adapter_list_ptr), + "Failed to create IDXCoreAdapterList"); + auto adapter_list = make_com_ptr(adapter_list_ptr); + + for (UINT i = 0; i < adapter_list->GetAdapterCount(); i++) + { + IDXCoreAdapter* curr_adapter_ptr; + GAPI_LOG_DEBUG(nullptr, "GetAdapter"); + THROW_IF_FAILED( + adapter_list->GetAdapter( + i, __uuidof(IDXCoreAdapter), (void**)&curr_adapter_ptr), + "Failed to obtain IDXCoreAdapter" + ); + auto curr_adapter = make_com_ptr(curr_adapter_ptr); + + bool is_hardware = false; + curr_adapter->GetProperty(DXCoreAdapterProperty::IsHardware, &is_hardware); + // NB: Filter out if not hardware adapter. + if (!is_hardware) { + continue; + } + + size_t desc_size = 0u; + char description[256]; + curr_adapter->GetPropertySize(DXCoreAdapterProperty::DriverDescription, &desc_size); + curr_adapter->GetProperty(DXCoreAdapterProperty::DriverDescription, desc_size, &description); + all_adapters.push_back(AdapterDesc{std::move(curr_adapter), description}); + } + return all_adapters; +}; + +struct DMLDeviceInfo { + ComPtrGuard device; + ComPtrGuard cmd_queue; +}; + +static DMLDeviceInfo createDMLInfo(IDXCoreAdapter* adapter) { + auto pAdapter = make_com_ptr(adapter); + D3D_FEATURE_LEVEL d3dFeatureLevel = D3D_FEATURE_LEVEL_1_0_CORE; + if (adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)) + { + GAPI_LOG_INFO(nullptr, "DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS is supported"); + d3dFeatureLevel = D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_11_0; + + IDXGIFactory4* dxgiFactory4; + GAPI_LOG_DEBUG(nullptr, "CreateDXGIFactory2"); + THROW_IF_FAILED( + CreateDXGIFactory2(0, __uuidof(IDXGIFactory4), (void**)&dxgiFactory4), + "Failed to create IDXGIFactory4" + ); + // If DXGI factory creation was successful then get the IDXGIAdapter from the LUID + // acquired from the selectedAdapter + LUID adapterLuid; + IDXGIAdapter* spDxgiAdapter; + + GAPI_LOG_DEBUG(nullptr, "Get DXCoreAdapterProperty::InstanceLuid property"); + THROW_IF_FAILED( + adapter->GetProperty(DXCoreAdapterProperty::InstanceLuid, &adapterLuid), + "Failed to get DXCoreAdapterProperty::InstanceLuid property"); + + GAPI_LOG_DEBUG(nullptr, "Get IDXGIAdapter by luid"); + THROW_IF_FAILED( + dxgiFactory4->EnumAdapterByLuid( + adapterLuid, __uuidof(IDXGIAdapter), (void**)&spDxgiAdapter), + "Failed to get IDXGIAdapter"); + pAdapter = make_com_ptr(spDxgiAdapter); + } else { + GAPI_LOG_INFO(nullptr, "DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS isn't supported"); + } + + ID3D12Device* d3d12_device_ptr; + GAPI_LOG_DEBUG(nullptr, "Create D3D12Device"); + THROW_IF_FAILED( + D3D12CreateDevice( + pAdapter.get(), d3dFeatureLevel, __uuidof(ID3D12Device), (void**)&d3d12_device_ptr), + "Failed to create ID3D12Device"); + auto d3d12_device = make_com_ptr(d3d12_device_ptr); + + D3D12_COMMAND_LIST_TYPE commandQueueType = D3D12_COMMAND_LIST_TYPE_COMPUTE; + ID3D12CommandQueue* cmd_queue_ptr; + D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {}; + commandQueueDesc.Type = commandQueueType; + GAPI_LOG_DEBUG(nullptr, "Create D3D12CommandQueue"); + THROW_IF_FAILED( + d3d12_device->CreateCommandQueue( + &commandQueueDesc, __uuidof(ID3D12CommandQueue), (void**)&cmd_queue_ptr), + "Failed to create D3D12CommandQueue" + ); + GAPI_LOG_DEBUG(nullptr, "Create D3D12CommandQueue - successful"); + auto cmd_queue = make_com_ptr(cmd_queue_ptr); + + IDMLDevice* dml_device_ptr; + GAPI_LOG_DEBUG(nullptr, "Create DirectML device"); + THROW_IF_FAILED( + DMLCreateDevice( + d3d12_device.get(), DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dml_device_ptr)), + "Failed to create IDMLDevice"); + GAPI_LOG_DEBUG(nullptr, "Create DirectML device - successful"); + auto dml_device = make_com_ptr(dml_device_ptr); + + return {std::move(dml_device), std::move(cmd_queue)}; +}; + +static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions *session_options, + const std::string &adapter_name) { + auto all_adapters = getAvailableAdapters(); + + std::vector selected_adapters; + std::stringstream log_msg; + for (auto&& adapter : all_adapters) { + log_msg << adapter.description << std::endl; + if (std::strstr(adapter.description.c_str(), adapter_name.c_str())) { + selected_adapters.emplace_back(std::move(adapter)); + } + } + GAPI_LOG_INFO(NULL, "\nAvailable DirectML adapters:\n" << log_msg.str()); + + if (selected_adapters.empty()) { + std::stringstream error_msg; + error_msg << "ONNX Backend: No DirectML adapters found match to \"" << adapter_name << "\""; + cv::util::throw_error(std::runtime_error(error_msg.str())); + } else if (selected_adapters.size() > 1) { + std::stringstream error_msg; + error_msg << "ONNX Backend: More than one adapter matches to \"" << adapter_name << "\":\n"; + for (const auto &selected_adapter : selected_adapters) { + error_msg << selected_adapter.description << "\n"; + } + cv::util::throw_error(std::runtime_error(error_msg.str())); + } + + GAPI_LOG_INFO(NULL, "Selected device: " << selected_adapters.front().description); + auto dml = createDMLInfo(selected_adapters.front().ptr.get()); try { - OrtSessionOptionsAppendExecutionProvider_DML(*session_options, device_id); + OrtSessionOptionsAppendExecutionProviderEx_DML( + *session_options, dml.device.release(), dml.cmd_queue.release()); } catch (const std::exception &e) { std::stringstream ss; ss << "ONNX Backend: Failed to enable DirectML" @@ -28,6 +255,16 @@ void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions *session_optio } } +#else // HAVE_DIRECTML + +static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions*, const std::string&) { + std::stringstream ss; + ss << "ONNX Backend: Failed to add DirectML Execution Provider with adapter name." + << " DirectML support is required."; + cv::util::throw_error(std::runtime_error(ss.str())); +} + +#endif // HAVE_DIRECTML #else // HAVE_ONNX_DML void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions*,