Add YOLOv8 LibTorch C++ inference example (#7090)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/6971/head^2
Myyura 1 year ago committed by GitHub
parent f955978dc4
commit a5735724c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      examples/README.md
  2. 47
      examples/YOLOv8-LibTorch-CPP-Inference/CMakeLists.txt
  3. 35
      examples/YOLOv8-LibTorch-CPP-Inference/README.md
  4. 259
      examples/YOLOv8-LibTorch-CPP-Inference/main.cc

@ -17,6 +17,7 @@ This repository features a collection of real-world applications and walkthrough
| [YOLOv8 SAHI Video Inference](https://github.com/RizwanMunawar/ultralytics/blob/main/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py) | Python | [Muhammad Rizwan Munawar](https://github.com/RizwanMunawar) |
| [YOLOv8 Region Counter](https://github.com/RizwanMunawar/ultralytics/blob/main/examples/YOLOv8-Region-Counter/yolov8_region_counter.py) | Python | [Muhammad Rizwan Munawar](https://github.com/RizwanMunawar) |
| [YOLOv8 Segmentation ONNXRuntime Python](./YOLOv8-Segmentation-ONNXRuntime-Python) | Python/ONNXRuntime | [jamjamjon](https://github.com/jamjamjon) |
| [YOLOv8 LibTorch CPP](./YOLOv8-LibTorch-CPP-Inference) | C++/LibTorch | [Myyura](https://github.com/Myyura) |
### How to Contribute

@ -0,0 +1,47 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(yolov8_libtorch_example)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
# -------------- OpenCV --------------
set(OpenCV_DIR "/path/to/opencv/lib/cmake/opencv4")
find_package(OpenCV REQUIRED)
message(STATUS "OpenCV library status:")
message(STATUS " config: ${OpenCV_DIR}")
message(STATUS " version: ${OpenCV_VERSION}")
message(STATUS " libraries: ${OpenCV_LIBS}")
message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}")
include_directories(${OpenCV_INCLUDE_DIRS})
# -------------- libtorch --------------
list(APPEND CMAKE_PREFIX_PATH "/path/to/libtorch")
set(Torch_DIR "/path/to/libtorch/share/cmake/Torch")
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
message("${TORCH_LIBRARIES}")
message("${TORCH_INCLUDE_DIRS}")
# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
# if (MSVC)
# file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
# add_custom_command(TARGET yolov8_libtorch_example
# POST_BUILD
# COMMAND ${CMAKE_COMMAND} -E copy_if_different
# ${TORCH_DLLS}
# $<TARGET_FILE_DIR:yolov8_libtorch_example>)
# endif (MSVC)
include_directories(${TORCH_INCLUDE_DIRS})
add_executable(yolov8_libtorch_inference "${CMAKE_CURRENT_SOURCE_DIR}/main.cc")
target_link_libraries(yolov8_libtorch_inference ${TORCH_LIBRARIES} ${OpenCV_LIBS})
set_property(TARGET yolov8_libtorch_inference PROPERTY CXX_STANDARD 17)

@ -0,0 +1,35 @@
# YOLOv8 LibTorch Inference C++
This example demonstrates how to perform inference using YOLOv8 models in C++ with LibTorch API.
## Dependencies
| Dependency | Version |
| ------------ | -------- |
| OpenCV | >=4.0.0 |
| C++ Standard | >=17 |
| Cmake | >=3.18 |
| Libtorch | >=1.12.1 |
## Usage
```bash
git clone ultralytics
cd ultralytics
pip install .
cd examples/YOLOv8-LibTorch-CPP-Inference
mkdir build
cd build
cmake ..
make
./yolov8_libtorch_inference
```
## Exporting YOLOv8
To export YOLOv8 models:
```commandline
yolo export model=yolov8s.pt imgsz=640 format=torchscript
```

@ -0,0 +1,259 @@
#include <iostream>
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/imgcodecs.hpp>
#include <torch/torch.h>
#include <torch/script.h>
using torch::indexing::Slice;
using torch::indexing::None;
float generate_scale(cv::Mat& image, const std::vector<int>& target_size) {
int origin_w = image.cols;
int origin_h = image.rows;
int target_h = target_size[0];
int target_w = target_size[1];
float ratio_h = static_cast<float>(target_h) / static_cast<float>(origin_h);
float ratio_w = static_cast<float>(target_w) / static_cast<float>(origin_w);
float resize_scale = std::min(ratio_h, ratio_w);
return resize_scale;
}
float letterbox(cv::Mat &input_image, cv::Mat &output_image, const std::vector<int> &target_size) {
if (input_image.cols == target_size[1] && input_image.rows == target_size[0]) {
if (input_image.data == output_image.data) {
return 1.;
} else {
output_image = input_image.clone();
return 1.;
}
}
float resize_scale = generate_scale(input_image, target_size);
int new_shape_w = std::round(input_image.cols * resize_scale);
int new_shape_h = std::round(input_image.rows * resize_scale);
float padw = (target_size[1] - new_shape_w) / 2.;
float padh = (target_size[0] - new_shape_h) / 2.;
int top = std::round(padh - 0.1);
int bottom = std::round(padh + 0.1);
int left = std::round(padw - 0.1);
int right = std::round(padw + 0.1);
cv::resize(input_image, output_image,
cv::Size(new_shape_w, new_shape_h),
0, 0, cv::INTER_AREA);
cv::copyMakeBorder(output_image, output_image, top, bottom, left, right,
cv::BORDER_CONSTANT, cv::Scalar(114.));
return resize_scale;
}
torch::Tensor xyxy2xywh(const torch::Tensor& x) {
auto y = torch::empty_like(x);
y.index_put_({"...", 0}, (x.index({"...", 0}) + x.index({"...", 2})).div(2));
y.index_put_({"...", 1}, (x.index({"...", 1}) + x.index({"...", 3})).div(2));
y.index_put_({"...", 2}, x.index({"...", 2}) - x.index({"...", 0}));
y.index_put_({"...", 3}, x.index({"...", 3}) - x.index({"...", 1}));
return y;
}
torch::Tensor xywh2xyxy(const torch::Tensor& x) {
auto y = torch::empty_like(x);
auto dw = x.index({"...", 2}).div(2);
auto dh = x.index({"...", 3}).div(2);
y.index_put_({"...", 0}, x.index({"...", 0}) - dw);
y.index_put_({"...", 1}, x.index({"...", 1}) - dh);
y.index_put_({"...", 2}, x.index({"...", 0}) + dw);
y.index_put_({"...", 3}, x.index({"...", 1}) + dh);
return y;
}
// Reference: https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/nms_kernel.cpp
torch::Tensor nms(const torch::Tensor& bboxes, const torch::Tensor& scores, float iou_threshold) {
if (bboxes.numel() == 0)
return torch::empty({0}, bboxes.options().dtype(torch::kLong));
auto x1_t = bboxes.select(1, 0).contiguous();
auto y1_t = bboxes.select(1, 1).contiguous();
auto x2_t = bboxes.select(1, 2).contiguous();
auto y2_t = bboxes.select(1, 3).contiguous();
torch::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto ndets = bboxes.size(0);
torch::Tensor suppressed_t = torch::zeros({ndets}, bboxes.options().dtype(torch::kByte));
torch::Tensor keep_t = torch::zeros({ndets}, bboxes.options().dtype(torch::kLong));
auto suppressed = suppressed_t.data_ptr<uint8_t>();
auto keep = keep_t.data_ptr<int64_t>();
auto order = order_t.data_ptr<int64_t>();
auto x1 = x1_t.data_ptr<float>();
auto y1 = y1_t.data_ptr<float>();
auto x2 = x2_t.data_ptr<float>();
auto y2 = y2_t.data_ptr<float>();
auto areas = areas_t.data_ptr<float>();
int64_t num_to_keep = 0;
for (int64_t _i = 0; _i < ndets; _i++) {
auto i = order[_i];
if (suppressed[i] == 1)
continue;
keep[num_to_keep++] = i;
auto ix1 = x1[i];
auto iy1 = y1[i];
auto ix2 = x2[i];
auto iy2 = y2[i];
auto iarea = areas[i];
for (int64_t _j = _i + 1; _j < ndets; _j++) {
auto j = order[_j];
if (suppressed[j] == 1)
continue;
auto xx1 = std::max(ix1, x1[j]);
auto yy1 = std::max(iy1, y1[j]);
auto xx2 = std::min(ix2, x2[j]);
auto yy2 = std::min(iy2, y2[j]);
auto w = std::max(static_cast<float>(0), xx2 - xx1);
auto h = std::max(static_cast<float>(0), yy2 - yy1);
auto inter = w * h;
auto ovr = inter / (iarea + areas[j] - inter);
if (ovr > iou_threshold)
suppressed[j] = 1;
}
}
return keep_t.narrow(0, 0, num_to_keep);
}
torch::Tensor non_max_supperession(torch::Tensor& prediction, float conf_thres = 0.25, float iou_thres = 0.45, int max_det = 300) {
auto bs = prediction.size(0);
auto nc = prediction.size(1) - 4;
auto nm = prediction.size(1) - nc - 4;
auto mi = 4 + nc;
auto xc = prediction.index({Slice(), Slice(4, mi)}).amax(1) > conf_thres;
prediction = prediction.transpose(-1, -2);
prediction.index_put_({"...", Slice({None, 4})}, xywh2xyxy(prediction.index({"...", Slice(None, 4)})));
std::vector<torch::Tensor> output;
for (int i = 0; i < bs; i++) {
output.push_back(torch::zeros({0, 6 + nm}, prediction.device()));
}
for (int xi = 0; xi < prediction.size(0); xi++) {
auto x = prediction[xi];
x = x.index({xc[xi]});
auto x_split = x.split({4, nc, nm}, 1);
auto box = x_split[0], cls = x_split[1], mask = x_split[2];
auto [conf, j] = cls.max(1, true);
x = torch::cat({box, conf, j.toType(torch::kFloat), mask}, 1);
x = x.index({conf.view(-1) > conf_thres});
int n = x.size(0);
if (!n) { continue; }
// NMS
auto c = x.index({Slice(), Slice{5, 6}}) * 7680;
auto boxes = x.index({Slice(), Slice(None, 4)}) + c;
auto scores = x.index({Slice(), 4});
auto i = nms(boxes, scores, iou_thres);
i = i.index({Slice(None, max_det)});
output[xi] = x.index({i});
}
return torch::stack(output);
}
torch::Tensor clip_boxes(torch::Tensor& boxes, const std::vector<int>& shape) {
boxes.index_put_({"...", 0}, boxes.index({"...", 0}).clamp(0, shape[1]));
boxes.index_put_({"...", 1}, boxes.index({"...", 1}).clamp(0, shape[0]));
boxes.index_put_({"...", 2}, boxes.index({"...", 2}).clamp(0, shape[1]));
boxes.index_put_({"...", 3}, boxes.index({"...", 3}).clamp(0, shape[0]));
return boxes;
}
torch::Tensor scale_boxes(const std::vector<int>& img1_shape, torch::Tensor& boxes, const std::vector<int>& img0_shape) {
auto gain = (std::min)((float)img1_shape[0] / img0_shape[0], (float)img1_shape[1] / img0_shape[1]);
auto pad0 = std::round((float)(img1_shape[1] - img0_shape[1] * gain) / 2. - 0.1);
auto pad1 = std::round((float)(img1_shape[0] - img0_shape[0] * gain) / 2. - 0.1);
boxes.index_put_({"...", 0}, boxes.index({"...", 0}) - pad0);
boxes.index_put_({"...", 2}, boxes.index({"...", 2}) - pad0);
boxes.index_put_({"...", 1}, boxes.index({"...", 1}) - pad1);
boxes.index_put_({"...", 3}, boxes.index({"...", 3}) - pad1);
boxes.index_put_({"...", Slice(None, 4)}, boxes.index({"...", Slice(None, 4)}).div(gain));
return boxes;
}
int main() {
// Device
torch::Device device(torch::cuda::is_available() ? torch::kCUDA :torch::kCPU);
// Note that in this example the classes are hard-coded
std::vector<std::string> classes {"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
"giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
"baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife",
"spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
"couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"};
try {
// Load the model (e.g. yolov8s.torchscript)
std::string model_path = "/path/to/yolov8s.torchscript";
torch::jit::script::Module yolo_model;
yolo_model = torch::jit::load(model_path);
yolo_model.eval();
yolo_model.to(device, torch::kFloat32);
// Load image and preprocess
cv::Mat image = cv::imread("/path/to/bus.jpg");
cv::Mat input_image;
letterbox(image, input_image, {640, 640});
torch::Tensor image_tensor = torch::from_blob(input_image.data, {input_image.rows, input_image.cols, 3}, torch::kByte).to(device);
image_tensor = image_tensor.toType(torch::kFloat32).div(255);
image_tensor = image_tensor.permute({2, 0, 1});
image_tensor = image_tensor.unsqueeze(0);
std::vector<torch::jit::IValue> inputs {image_tensor};
// Inference
torch::Tensor output = yolo_model.forward(inputs).toTensor().cpu();
// NMS
auto keep = non_max_supperession(output)[0];
auto boxes = keep.index({Slice(), Slice(None, 4)});
keep.index_put_({Slice(), Slice(None, 4)}, scale_boxes({input_image.rows, input_image.cols}, boxes, {image.rows, image.cols}));
// Show the results
for (int i = 0; i < keep.size(0); i++) {
int x1 = keep[i][0].item().toFloat();
int y1 = keep[i][1].item().toFloat();
int x2 = keep[i][2].item().toFloat();
int y2 = keep[i][3].item().toFloat();
float conf = keep[i][4].item().toFloat();
int cls = keep[i][5].item().toInt();
std::cout << "Rect: [" << x1 << "," << y1 << "," << x2 << "," << y2 << "] Conf: " << conf << " Class: " << classes[cls] << std::endl;
}
} catch (const c10::Error& e) {
std::cout << e.msg() << std::endl;
}
return 0;
}
Loading…
Cancel
Save