Merge branch 'main' into afpn

Branch: afpn
Glenn Jocher committed 1 year ago via GitHub
commit 7660ca7bfe
Files changed (9):

  1. docs/help/contributing.md (2 changes)
  2. docs/modes/predict.md (42 changes)
  3. docs/quickstart.md (2 changes)
  4. examples/YOLOv8-ONNXRuntime-CPP/CMakeLists.txt (71 changes)
  5. examples/YOLOv8-ONNXRuntime-CPP/README.md (23 changes)
  6. examples/YOLOv8-ONNXRuntime-CPP/inference.cpp (27 changes)
  7. examples/YOLOv8-ONNXRuntime-CPP/inference.h (25 changes)
  8. examples/YOLOv8-ONNXRuntime-CPP/main.cpp (84 changes)
  9. ultralytics/engine/results.py (2 changes)

docs/help/contributing.md

@@ -6,7 +6,7 @@ keywords: Ultralytics, YOLO, open-source, contribute, pull request, bug report,

 # Contributing to Ultralytics Open-Source YOLO Repositories

-First of all, thank you for your interest in contributing to Ultralytics open-source YOLO repositories! Your contributions will help improve the project and benefit the community. This document provides guidelines and best practices for contributing to Ultralytics YOLO repositories.
+First of all, thank you for your interest in contributing to Ultralytics open-source YOLO repositories! Your contributions will help improve the project and benefit the community. This document provides guidelines and best practices to get you started.

 ## Table of Contents

docs/modes/predict.md

@@ -54,21 +54,22 @@ YOLOv8 can process different types of input sources for inference, as shown in t

 Use `stream=True` for processing long videos or large datasets to efficiently manage memory. When `stream=False`, the results for all frames or data points are stored in memory, which can quickly add up and cause out-of-memory errors for large inputs. In contrast, `stream=True` utilizes a generator, which only keeps the results of the current frame or data point in memory, significantly reducing memory consumption and preventing out-of-memory issues.

 | Source | Argument | Type | Notes |
-|-------------|--------------------------------------------|---------------------------------------|----------------------------------------------------------------------------|
+|---------------|--------------------------------------------|-----------------|---------------------------------------------------------------------------------------------|
 | image | `'image.jpg'` | `str` or `Path` | Single image file. |
 | URL | `'https://ultralytics.com/images/bus.jpg'` | `str` | URL to an image. |
 | screenshot | `'screen'` | `str` | Capture a screenshot. |
 | PIL | `Image.open('im.jpg')` | `PIL.Image` | HWC format with RGB channels. |
-| OpenCV | `cv2.imread('im.jpg')` | `np.ndarray` of `uint8 (0-255)` | HWC format with BGR channels. |
-| numpy | `np.zeros((640,1280,3))` | `np.ndarray` of `uint8 (0-255)` | HWC format with BGR channels. |
-| torch | `torch.zeros(16,3,320,640)` | `torch.Tensor` of `float32 (0.0-1.0)` | BCHW format with RGB channels. |
+| OpenCV | `cv2.imread('im.jpg')` | `np.ndarray` | HWC format with BGR channels `uint8 (0-255)`. |
+| numpy | `np.zeros((640,1280,3))` | `np.ndarray` | HWC format with BGR channels `uint8 (0-255)`. |
+| torch | `torch.zeros(16,3,320,640)` | `torch.Tensor` | BCHW format with RGB channels `float32 (0.0-1.0)`. |
 | CSV | `'sources.csv'` | `str` or `Path` | CSV file containing paths to images, videos, or directories. |
 | video ✅ | `'video.mp4'` | `str` or `Path` | Video file in formats like MP4, AVI, etc. |
 | directory ✅ | `'path/'` | `str` or `Path` | Path to a directory containing images or videos. |
 | glob ✅ | `'path/*.jpg'` | `str` | Glob pattern to match multiple files. Use the `*` character as a wildcard. |
 | YouTube ✅ | `'https://youtu.be/Zgi9g1ksQHc'` | `str` | URL to a YouTube video. |
 | stream ✅ | `'rtsp://example.com/media.mp4'` | `str` | URL for streaming protocols such as RTSP, RTMP, or an IP address. |
+| multi-stream ✅ | `'list.streams'` | `str` or `Path` | `*.streams` text file with one stream URL per row, i.e. 8 streams will run at batch-size 8. |

 Below are code examples for using each source type:
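A minimal sketch of the memory trade-off described above, assuming a local `video.mp4` (hypothetical path):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')

# stream=False returns a list, so every frame's Results object stays in memory
results_list = model('video.mp4', stream=False)

# stream=True returns a generator, so only the current frame's Results is held
for result in model('video.mp4', stream=True):
    boxes = result.boxes  # process, then let each frame's Results be freed
```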
@@ -262,16 +263,19 @@ Below are code examples for using each source type:

         results = model(source, stream=True)  # generator of Results objects
     ```

-=== "Stream"
+=== "Streams"

-    Run inference on remote streaming sources using RTSP, RTMP, and IP address protocols.
+    Run inference on remote streaming sources using RTSP, RTMP, and IP address protocols. If multiple streams are provided in a `*.streams` text file then batched inference will run, i.e. 8 streams will run at batch-size 8, otherwise single streams will run at batch-size 1.

     ```python
     from ultralytics import YOLO

     # Load a pretrained YOLOv8n model
     model = YOLO('yolov8n.pt')

-    # Define source as RTSP, RTMP or IP streaming address
-    source = 'rtsp://example.com/media.mp4'
+    # Single stream with batch-size 1 inference
+    source = 'rtsp://example.com/media.mp4'  # RTSP, RTMP or IP streaming address
+
+    # Multiple streams with batched inference (i.e. batch-size 8 for 8 streams)
+    source = 'path/to/list.streams'  # *.streams text file with one streaming address per row

     # Run inference on the source
     results = model(source, stream=True)  # generator of Results objects
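As a sketch of the new multi-stream source: a `*.streams` file is plain text with one streaming address per row. The URLs below are placeholders, not live endpoints:

```python
from ultralytics import YOLO

# Write a *.streams text file, one (placeholder) stream URL per row
urls = ['rtsp://example.com/media1.mp4', 'rtsp://example.com/media2.mp4']
with open('list.streams', 'w') as f:
    f.write('\n'.join(urls))

# Two streams -> batched inference at batch-size 2
model = YOLO('yolov8n.pt')
results = model('list.streams', stream=True)  # generator of Results objects
```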

docs/quickstart.md

@@ -28,7 +28,7 @@ Ultralytics provides various installation methods including pip, conda, and Dock

     ```bash
     # Install the ultralytics package using conda
-    conda install ultralytics
+    conda install -c conda-forge ultralytics
     ```

 === "Git clone"

examples/YOLOv8-ONNXRuntime-CPP/CMakeLists.txt (new file)

@@ -0,0 +1,71 @@
cmake_minimum_required(VERSION 3.5)
set(PROJECT_NAME Yolov8OnnxRuntimeCPPInference)
project(${PROJECT_NAME} VERSION 0.0.1 LANGUAGES CXX)
# -------------- Support C++17 for using filesystem ------------------#
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON)
set(CMAKE_INCLUDE_CURRENT_DIR ON)
# OpenCV
find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})
# ONNXRUNTIME
# Set ONNXRUNTIME_VERSION
set(ONNXRUNTIME_VERSION 1.15.1)
if(WIN32)
# CPU
# set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-win-x64-${ONNXRUNTIME_VERSION}")
# GPU
set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-win-x64-gpu-${ONNXRUNTIME_VERSION}")
elseif(LINUX)
# CPU
# set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}")
# GPU
set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-linux-x64-gpu-${ONNXRUNTIME_VERSION}")
elseif(APPLE)
set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-osx-arm64-${ONNXRUNTIME_VERSION}")
# Apple X64 binary
# set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-osx-x64-${ONNXRUNTIME_VERSION}")
# Apple Universal binary
# set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-osx-universal2-${ONNXRUNTIME_VERSION}")
endif()
include_directories(${PROJECT_NAME} ${ONNXRUNTIME_ROOT}/include)
set(PROJECT_SOURCES
main.cpp
inference.h
inference.cpp
)
add_executable(${PROJECT_NAME} ${PROJECT_SOURCES})
if(WIN32)
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/onnxruntime.lib)
elseif(LINUX)
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.so)
elseif(APPLE)
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.dylib)
endif()
# For windows system, copy onnxruntime.dll to the same folder of the executable file
if(WIN32)
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${ONNXRUNTIME_ROOT}/lib/onnxruntime.dll"
$<TARGET_FILE_DIR:${PROJECT_NAME}>)
endif()
# Download https://raw.githubusercontent.com/ultralytics/ultralytics/main/ultralytics/cfg/datasets/coco.yaml
# and put it in the same folder of the executable file
configure_file(coco.yaml ${CMAKE_CURRENT_BINARY_DIR}/coco.yaml COPYONLY)

examples/YOLOv8-ONNXRuntime-CPP/README.md

@@ -2,8 +2,6 @@

 This example demonstrates how to perform inference using YOLOv8 in C++ with ONNX Runtime and OpenCV's API.

-We recommend using Visual Studio to build the project.

 ## Benefits

 - Friendly for deployment in the industrial sector.

@@ -25,13 +23,20 @@ model = YOLO("yolov8n.pt")

 model.export(format="onnx", opset=12, simplify=True, dynamic=False, imgsz=640)
 ```

+Alternatively, you can use the following command for exporting the model in the terminal
+
+```bash
+yolo export model=yolov8n.pt opset=12 simplify=True dynamic=False format=onnx imgsz=640,640
+```

 ## Dependencies

-| Dependency              | Version  |
-| ----------------------- | -------- |
-| Onnxruntime-win-x64-gpu | >=1.14.1 |
-| OpenCV                  | >=4.0.0  |
-| C++                     | >=17     |
+| Dependency                       | Version  |
+| -------------------------------- | -------- |
+| Onnxruntime(linux,windows,macos) | >=1.14.1 |
+| OpenCV                           | >=4.0.0  |
+| C++                              | >=17     |
+| Cmake                            | >=3.5    |

 Note: The dependency on C++17 is due to the usage of the C++17 filesystem feature.

@@ -39,9 +44,9 @@ Note: The dependency on C++17 is due to the usage of the C++17 filesystem featur

 ```c++
 // CPU inference
-DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {imgsz_w, imgsz_h}, class_num, 0.1, 0.5, false};
+DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {imgsz_w, imgsz_h}, 0.1, 0.5, false};
 // GPU inference
-DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {imgsz_w, imgsz_h}, class_num, 0.1, 0.5, true};
+DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {imgsz_w, imgsz_h}, 0.1, 0.5, true};

 // Load your image
 cv::Mat img = cv::imread(img_path);

examples/YOLOv8-ONNXRuntime-CPP/inference.cpp

@@ -2,7 +2,6 @@
 #include <regex>

 #define benchmark
-#define ELOG

 DCSP_CORE::DCSP_CORE()
 {
@@ -29,7 +28,7 @@ char* BlobFromImage(cv::Mat& iImg, T& iBlob)
     {
         for (int w = 0; w < imgWidth; w++)
         {
-            iBlob[c * imgWidth * imgHeight + h * imgWidth + w] = (std::remove_pointer<T>::type)((iImg.at<cv::Vec3b>(h, w)[c]) / 255.0f);
+            iBlob[c * imgWidth * imgHeight + h * imgWidth + w] = typename std::remove_pointer<T>::type((iImg.at<cv::Vec3b>(h, w)[c]) / 255.0f);
         }
     }
 }
@@ -40,8 +39,8 @@ char* BlobFromImage(cv::Mat& iImg, T& iBlob)
 char* PostProcess(cv::Mat& iImg, std::vector<int> iImgSize, cv::Mat& oImg)
 {
     cv::Mat img = iImg.clone();
     cv::resize(iImg, oImg, cv::Size(iImgSize.at(0), iImgSize.at(1)));
     if (img.channels() == 1)
     {
         cv::cvtColor(oImg, oImg, cv::COLOR_GRAY2BGR);
     }
@@ -75,17 +74,21 @@ char* DCSP_CORE::CreateSession(DCSP_INIT_PARAM &iParams)
         OrtCUDAProviderOptions cudaOption;
         cudaOption.device_id = 0;
         sessionOption.AppendExecutionProvider_CUDA(cudaOption);
+        //OrtOpenVINOProviderOptions ovOption;
+        //sessionOption.AppendExecutionProvider_OpenVINO(ovOption);
     }
     sessionOption.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
     sessionOption.SetIntraOpNumThreads(iParams.IntraOpNumThreads);
     sessionOption.SetLogSeverityLevel(iParams.LogSeverityLevel);
+#ifdef _WIN32
     int ModelPathSize = MultiByteToWideChar(CP_UTF8, 0, iParams.ModelPath.c_str(), static_cast<int>(iParams.ModelPath.length()), nullptr, 0);
     wchar_t* wide_cstr = new wchar_t[ModelPathSize + 1];
     MultiByteToWideChar(CP_UTF8, 0, iParams.ModelPath.c_str(), static_cast<int>(iParams.ModelPath.length()), wide_cstr, ModelPathSize);
     wide_cstr[ModelPathSize] = L'\0';
     const wchar_t* modelPath = wide_cstr;
+#else
+    const char* modelPath = iParams.ModelPath.c_str();
+#endif // _WIN32
     session = new Ort::Session(env, modelPath, sessionOption);
     Ort::AllocatorWithDefaultOptions allocator;
     size_t inputNodesNum = session->GetInputCount();
@@ -96,7 +99,6 @@ char* DCSP_CORE::CreateSession(DCSP_INIT_PARAM &iParams)
         strcpy(temp_buf, input_node_name.get());
         inputNodeNames.push_back(temp_buf);
     }
-
     size_t OutputNodesNum = session->GetOutputCount();
     for (size_t i = 0; i < OutputNodesNum; i++)
     {
@@ -151,7 +153,7 @@ char* DCSP_CORE::RunSession(cv::Mat &iImg, std::vector<DCSP_RESULT>& oResult)

 template<typename N>
 char* DCSP_CORE::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std::vector<int64_t>& inputNodeDims, std::vector<DCSP_RESULT>& oResult)
 {
-    Ort::Value inputTensor = Ort::Value::CreateTensor<std::remove_pointer<N>::type>(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob, 3 * imgSize.at(0) * imgSize.at(1), inputNodeDims.data(), inputNodeDims.size());
+    Ort::Value inputTensor = Ort::Value::CreateTensor<typename std::remove_pointer<N>::type>(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob, 3 * imgSize.at(0) * imgSize.at(1), inputNodeDims.data(), inputNodeDims.size());
 #ifdef benchmark
     clock_t starttime_2 = clock();
 #endif // benchmark
@@ -159,10 +161,11 @@ char* DCSP_CORE::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std
 #ifdef benchmark
     clock_t starttime_3 = clock();
 #endif // benchmark

     Ort::TypeInfo typeInfo = outputTensor.front().GetTypeInfo();
     auto tensor_info = typeInfo.GetTensorTypeAndShapeInfo();
     std::vector<int64_t> outputNodeDims = tensor_info.GetShape();
-    std::remove_pointer<N>::type* output = outputTensor.front().GetTensorMutableData<std::remove_pointer<N>::type>();
+    auto output = outputTensor.front().GetTensorMutableData<typename std::remove_pointer<N>::type>();
     delete blob;
     switch (modelType)
     {
@@ -183,7 +186,7 @@ char* DCSP_CORE::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std
         for (int i = 0; i < strideNum; ++i)
         {
             float* classesScores = data + 4;
-            cv::Mat scores(1, classesNum, CV_32FC1, classesScores);
+            cv::Mat scores(1, this->classes.size(), CV_32FC1, classesScores);
             cv::Point class_id;
             double maxClassScore;
             cv::minMaxLoc(scores, 0, &maxClassScore, 0, &class_id);
@@ -203,13 +206,14 @@ char* DCSP_CORE::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std
                 int width = int(w * x_factor);
                 int height = int(h * y_factor);

-                boxes.push_back(cv::Rect(left, top, width, height));
+                boxes.emplace_back(left, top, width, height);
             }
             data += signalResultNum;
         }

         std::vector<int> nmsResult;
         cv::dnn::NMSBoxes(boxes, confidences, rectConfidenceThreshold, iouThreshold, nmsResult);

         for (int i = 0; i < nmsResult.size(); ++i)
         {
             int idx = nmsResult[i];
@@ -266,6 +270,5 @@ char* DCSP_CORE::WarmUpSession()
             std::cout << "[DCSP_ONNX(CUDA)]: " << "Cuda warm-up cost " << post_process_time << " ms. " << std::endl;
         }
     }
-
     return Ret;
 }

examples/YOLOv8-ONNXRuntime-CPP/inference.h

@@ -1,15 +1,17 @@
 #pragma once

-#define _CRT_SECURE_NO_WARNINGS
 #define RET_OK nullptr

+#ifdef _WIN32
+#include <Windows.h>
+#include <direct.h>
+#include <io.h>
+#endif

 #include <string>
 #include <vector>
-#include <stdio.h>
-#include "io.h"
-#include "direct.h"
-#include "opencv.hpp"
-#include <Windows.h>
+#include <cstdio>
+#include <opencv2/opencv.hpp>
 #include "onnxruntime_cxx_api.h"
@@ -23,13 +25,12 @@ enum MODEL_TYPE
 };

 typedef struct _DCSP_INIT_PARAM
 {
     std::string ModelPath;
     MODEL_TYPE ModelType = YOLO_ORIGIN_V8;
     std::vector<int> imgSize = {640, 640};
-    int classesNum = 80;
     float RectConfidenceThreshold = 0.6;
     float iouThreshold = 0.5;
     bool CudaEnable = false;
@@ -55,16 +56,14 @@ public:

 public:
     char* CreateSession(DCSP_INIT_PARAM &iParams);
     char* RunSession(cv::Mat &iImg, std::vector<DCSP_RESULT>& oResult);
     char* WarmUpSession();
     template<typename N>
     char* TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std::vector<int64_t>& inputNodeDims, std::vector<DCSP_RESULT>& oResult);
+    std::vector<std::string> classes{};

 private:
     Ort::Env env;
@@ -74,9 +73,7 @@ private:
     std::vector<const char*> inputNodeNames;
     std::vector<const char*> outputNodeNames;

-    MODEL_TYPE modelType;
-    int classesNum;
+    MODEL_TYPE modelType;
     std::vector<int> imgSize;
     float rectConfidenceThreshold;
     float iouThreshold;

examples/YOLOv8-ONNXRuntime-CPP/main.cpp

@@ -1,44 +1,94 @@
 #include <iostream>
-#include <stdio.h>
 #include "inference.h"
 #include <filesystem>
+#include <fstream>
+#include <sstream>  // needed for std::stringstream in read_coco_yaml()

 void file_iterator(DCSP_CORE*& p)
 {
-    std::filesystem::path img_path = R"(E:\project\Project_C++\DCPS_ONNX\TEST_ORIGIN)";
-    int k = 0;
-    for (auto& i : std::filesystem::directory_iterator(img_path))
+    std::filesystem::path current_path = std::filesystem::current_path();
+    std::filesystem::path imgs_path = current_path / "images";
+    for (auto& i : std::filesystem::directory_iterator(imgs_path))
     {
-        if (i.path().extension() == ".jpg")
+        if (i.path().extension() == ".jpg" || i.path().extension() == ".png")
         {
             std::string img_path = i.path().string();
-            //std::cout << img_path << std::endl;
             cv::Mat img = cv::imread(img_path);
             std::vector<DCSP_RESULT> res;
-            char* ret = p->RunSession(img, res);
-            for (int i = 0; i < res.size(); i++)
+            p->RunSession(img, res);
+            for (auto& re : res)
             {
-                cv::rectangle(img, res.at(i).box, cv::Scalar(125, 123, 0), 3);
+                cv::rectangle(img, re.box, cv::Scalar(0, 0, 255), 3);
+                std::string label = p->classes[re.classId];
+                cv::putText(
+                    img,
+                    label,
+                    cv::Point(re.box.x, re.box.y - 5),
+                    cv::FONT_HERSHEY_SIMPLEX,
+                    0.75,
+                    cv::Scalar(255, 255, 0),
+                    2
+                );
             }
-            k++;
-            cv::imshow("TEST_ORIGIN", img);
+            cv::imshow("Result", img);
             cv::waitKey(0);
             cv::destroyAllWindows();
-            //cv::imwrite("E:\\output\\" + std::to_string(k) + ".png", img);
         }
     }
 }

+int read_coco_yaml(DCSP_CORE*& p)
+{
+    // Open the YAML file
+    std::ifstream file("coco.yaml");
+    if (!file.is_open()) {
+        std::cerr << "Failed to open file" << std::endl;
+        return 1;
+    }
+
+    // Read the file line by line
+    std::string line;
+    std::vector<std::string> lines;
+    while (std::getline(file, line)) {
+        lines.push_back(line);
+    }
+
+    // Find the start and end of the names section
+    std::size_t start = 0;
+    std::size_t end = 0;
+    for (std::size_t i = 0; i < lines.size(); i++) {
+        if (lines[i].find("names:") != std::string::npos) {
+            start = i + 1;
+        } else if (start > 0 && lines[i].find(':') == std::string::npos) {
+            end = i;
+            break;
+        }
+    }
+
+    // Extract the names
+    std::vector<std::string> names;
+    for (std::size_t i = start; i < end; i++) {
+        std::stringstream ss(lines[i]);
+        std::string name;
+        std::getline(ss, name, ':');  // Extract the number before the delimiter
+        std::getline(ss, name);       // Extract the string after the delimiter
+        names.push_back(name);
+    }
+
+    p->classes = names;
+    return 0;
+}

 int main()
 {
     DCSP_CORE* p1 = new DCSP_CORE;
     std::string model_path = "yolov8n.onnx";
-    DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {640, 640}, 80, 0.1, 0.5, false };
-    char* ret = p1->CreateSession(params);
+    read_coco_yaml(p1);
+    // GPU inference
+    DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {640, 640}, 0.1, 0.5, true };
+    // CPU inference
+    // DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {640, 640}, 0.1, 0.5, false };
+    p1->CreateSession(params);
     file_iterator(p1);
 }
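For readers comparing against the Python side of the repo: `read_coco_yaml` above hand-parses the `names:` block. A rough Python equivalent, assuming PyYAML is installed and `coco.yaml` sits in the working directory:

```python
import yaml

# coco.yaml stores class names as a {class_id: name} mapping under 'names'
with open('coco.yaml') as f:
    names = yaml.safe_load(f)['names']

classes = [names[i] for i in sorted(names)]  # ordered list, index == class id
print(classes[:3])  # ['person', 'bicycle', 'car'] for the COCO dataset
```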

ultralytics/engine/results.py

@@ -215,7 +215,7 @@ class Results(SimpleClass):
         ```
         """
         if img is None and isinstance(self.orig_img, torch.Tensor):
-            img = np.ascontiguousarray(self.orig_img[0].permute(1, 2, 0).cpu().detach().numpy()) * 255
+            img = (self.orig_img[0].detach().permute(1, 2, 0).cpu().contiguous() * 255).to(torch.uint8).numpy()

         # Deprecation warn TODO: remove in 8.2
         if 'show_conf' in kwargs:
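The replacement line above performs the scaling and `uint8` cast in torch before the numpy hand-off, where the old line produced a float array rather than `uint8`. A standalone sketch of the transform, assuming a normalized `float32` BCHW tensor as produced by the predictor:

```python
import torch

orig_img = torch.rand(1, 3, 640, 640)  # BCHW, float32 values in [0.0, 1.0]

# Take batch item 0, detach from the graph, reorder CHW -> HWC, move to CPU,
# scale to 0-255 and cast to uint8 before converting to a numpy array
img = (orig_img[0].detach().permute(1, 2, 0).cpu().contiguous() * 255).to(torch.uint8).numpy()

assert img.shape == (640, 640, 3) and img.dtype.name == 'uint8'  # HWC uint8, ready for OpenCV
```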
