From a5735724c54a9f5bcb239c151fefbd1337d7123d Mon Sep 17 00:00:00 2001
From: Myyura
Date: Thu, 21 Dec 2023 07:47:14 +0800
Subject: [PATCH] Add YOLOv8 LibTorch C++ inference example (#7090)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher
---
 examples/README.md                          |   1 +
 .../CMakeLists.txt                          |  47 ++++
 .../YOLOv8-LibTorch-CPP-Inference/README.md |  35 +++
 .../YOLOv8-LibTorch-CPP-Inference/main.cc   | 259 ++++++++++++++++++
 4 files changed, 342 insertions(+)
 create mode 100644 examples/YOLOv8-LibTorch-CPP-Inference/CMakeLists.txt
 create mode 100644 examples/YOLOv8-LibTorch-CPP-Inference/README.md
 create mode 100644 examples/YOLOv8-LibTorch-CPP-Inference/main.cc

diff --git a/examples/README.md b/examples/README.md
index 90d1415c7..d49bdfe5d 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -17,6 +17,7 @@ This repository features a collection of real-world applications and walkthrough
 | [YOLOv8 SAHI Video Inference](https://github.com/RizwanMunawar/ultralytics/blob/main/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py) | Python | [Muhammad Rizwan Munawar](https://github.com/RizwanMunawar) |
 | [YOLOv8 Region Counter](https://github.com/RizwanMunawar/ultralytics/blob/main/examples/YOLOv8-Region-Counter/yolov8_region_counter.py) | Python | [Muhammad Rizwan Munawar](https://github.com/RizwanMunawar) |
 | [YOLOv8 Segmentation ONNXRuntime Python](./YOLOv8-Segmentation-ONNXRuntime-Python) | Python/ONNXRuntime | [jamjamjon](https://github.com/jamjamjon) |
+| [YOLOv8 LibTorch CPP](./YOLOv8-LibTorch-CPP-Inference) | C++/LibTorch | [Myyura](https://github.com/Myyura) |
 
 ### How to Contribute

diff --git a/examples/YOLOv8-LibTorch-CPP-Inference/CMakeLists.txt b/examples/YOLOv8-LibTorch-CPP-Inference/CMakeLists.txt
new file mode 100644
index 000000000..2cbd796c4
--- /dev/null
+++ b/examples/YOLOv8-LibTorch-CPP-Inference/CMakeLists.txt
@@ -0,0 +1,47 @@
+cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
+
+project(yolov8_libtorch_example)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
+
+
+# -------------- OpenCV --------------
+set(OpenCV_DIR "/path/to/opencv/lib/cmake/opencv4")
+find_package(OpenCV REQUIRED)
+
+message(STATUS "OpenCV library status:")
+message(STATUS "    config: ${OpenCV_DIR}")
+message(STATUS "    version: ${OpenCV_VERSION}")
+message(STATUS "    libraries: ${OpenCV_LIBS}")
+message(STATUS "    include path: ${OpenCV_INCLUDE_DIRS}")
+
+include_directories(${OpenCV_INCLUDE_DIRS})
+
+# -------------- libtorch --------------
+list(APPEND CMAKE_PREFIX_PATH "/path/to/libtorch")
+set(Torch_DIR "/path/to/libtorch/share/cmake/Torch")
+
+find_package(Torch REQUIRED)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
+message("${TORCH_LIBRARIES}")
+message("${TORCH_INCLUDE_DIRS}")
+
+# The following block is recommended on Windows.
+# According to https://github.com/pytorch/pytorch/issues/25457,
+# the LibTorch DLLs need to be copied next to the executable to avoid memory errors.
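+# Note: add_custom_command(TARGET ...) requires an existing target, so if you
+# uncomment this block, place it after the add_executable() call below.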
+# if (MSVC)
+#   file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
+#   add_custom_command(TARGET yolov8_libtorch_inference
+#                      POST_BUILD
+#                      COMMAND ${CMAKE_COMMAND} -E copy_if_different
+#                      ${TORCH_DLLS}
+#                      $<TARGET_FILE_DIR:yolov8_libtorch_inference>)
+# endif (MSVC)
+
+include_directories(${TORCH_INCLUDE_DIRS})
+
+add_executable(yolov8_libtorch_inference "${CMAKE_CURRENT_SOURCE_DIR}/main.cc")
+target_link_libraries(yolov8_libtorch_inference ${TORCH_LIBRARIES} ${OpenCV_LIBS})
+set_property(TARGET yolov8_libtorch_inference PROPERTY CXX_STANDARD 17)

diff --git a/examples/YOLOv8-LibTorch-CPP-Inference/README.md b/examples/YOLOv8-LibTorch-CPP-Inference/README.md
new file mode 100644
index 000000000..930c3cd22
--- /dev/null
+++ b/examples/YOLOv8-LibTorch-CPP-Inference/README.md
@@ -0,0 +1,35 @@
+# YOLOv8 LibTorch Inference C++
+
+This example demonstrates how to perform inference with YOLOv8 models in C++ using the LibTorch API.
+
+## Dependencies
+
+| Dependency   | Version  |
+| ------------ | -------- |
+| OpenCV       | >=4.0.0  |
+| C++ Standard | >=17     |
+| CMake        | >=3.18   |
+| LibTorch     | >=1.12.1 |
+
+## Usage
+
+```bash
+git clone https://github.com/ultralytics/ultralytics
+cd ultralytics
+pip install .
+cd examples/YOLOv8-LibTorch-CPP-Inference
+
+mkdir build
+cd build
+cmake ..
+make
+./yolov8_libtorch_inference
+```
+
+## Exporting YOLOv8
+
+To export a YOLOv8 model (e.g. `yolov8s.pt`) to TorchScript:
+
+```commandline
+yolo export model=yolov8s.pt imgsz=640 format=torchscript
+```
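+
+## Verifying the Export
+
+As a quick sanity check, the exported model can be loaded and run on a dummy input before wiring up the full OpenCV pipeline. The snippet below is a minimal sketch: the model path and the `[1, 84, 8400]` output shape are assumptions that hold for a default 80-class `yolov8s` detect export at 640x640.
+
+```cpp
+#include <torch/script.h>
+
+#include <iostream>
+#include <vector>
+
+int main() {
+    // Load the exported TorchScript module (path is an assumption, adjust to yours).
+    torch::jit::script::Module module = torch::jit::load("yolov8s.torchscript");
+    module.eval();
+    // YOLOv8 expects a float32 NCHW tensor with values in [0, 1].
+    std::vector<torch::jit::IValue> inputs {torch::rand({1, 3, 640, 640})};
+    torch::Tensor out = module.forward(inputs).toTensor();
+    std::cout << out.sizes() << std::endl;  // expect [1, 84, 8400] for an 80-class model
+    return 0;
+}
+```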

diff --git a/examples/YOLOv8-LibTorch-CPP-Inference/main.cc b/examples/YOLOv8-LibTorch-CPP-Inference/main.cc
new file mode 100644
index 000000000..ebb1a7589
--- /dev/null
+++ b/examples/YOLOv8-LibTorch-CPP-Inference/main.cc
@@ -0,0 +1,259 @@
+#include <iostream>
+
+#include <opencv2/core.hpp>
+#include <opencv2/imgproc.hpp>
+#include <opencv2/imgcodecs.hpp>
+#include <torch/torch.h>
+#include <torch/script.h>
+
+using torch::indexing::Slice;
+using torch::indexing::None;
+
+
+// Returns the scale that fits the image inside target_size while preserving aspect ratio.
+float generate_scale(cv::Mat& image, const std::vector<int>& target_size) {
+    int origin_w = image.cols;
+    int origin_h = image.rows;
+
+    int target_h = target_size[0];
+    int target_w = target_size[1];
+
+    float ratio_h = static_cast<float>(target_h) / static_cast<float>(origin_h);
+    float ratio_w = static_cast<float>(target_w) / static_cast<float>(origin_w);
+    float resize_scale = std::min(ratio_h, ratio_w);
+    return resize_scale;
+}
+
+
+// Resizes with unchanged aspect ratio and pads the borders with gray to target_size.
+float letterbox(cv::Mat &input_image, cv::Mat &output_image, const std::vector<int> &target_size) {
+    if (input_image.cols == target_size[1] && input_image.rows == target_size[0]) {
+        if (input_image.data == output_image.data) {
+            return 1.;
+        } else {
+            output_image = input_image.clone();
+            return 1.;
+        }
+    }
+
+    float resize_scale = generate_scale(input_image, target_size);
+    int new_shape_w = std::round(input_image.cols * resize_scale);
+    int new_shape_h = std::round(input_image.rows * resize_scale);
+    float padw = (target_size[1] - new_shape_w) / 2.;
+    float padh = (target_size[0] - new_shape_h) / 2.;
+
+    int top = std::round(padh - 0.1);
+    int bottom = std::round(padh + 0.1);
+    int left = std::round(padw - 0.1);
+    int right = std::round(padw + 0.1);
+
+    cv::resize(input_image, output_image,
+               cv::Size(new_shape_w, new_shape_h),
+               0, 0, cv::INTER_AREA);
+
+    cv::copyMakeBorder(output_image, output_image, top, bottom, left, right,
+                       cv::BORDER_CONSTANT, cv::Scalar(114, 114, 114));
+    return resize_scale;
+}
+
+
+// Converts boxes from (x1, y1, x2, y2) corners to (cx, cy, w, h) center format.
+torch::Tensor xyxy2xywh(const torch::Tensor& x) {
+    auto y = torch::empty_like(x);
+    y.index_put_({"...", 0}, (x.index({"...", 0}) + x.index({"...", 2})).div(2));
+    y.index_put_({"...", 1}, (x.index({"...", 1}) + x.index({"...", 3})).div(2));
+    y.index_put_({"...", 2}, x.index({"...", 2}) - x.index({"...", 0}));
+    y.index_put_({"...", 3}, x.index({"...", 3}) - x.index({"...", 1}));
+    return y;
+}
+
+
+// Converts boxes from (cx, cy, w, h) center format to (x1, y1, x2, y2) corners.
+torch::Tensor xywh2xyxy(const torch::Tensor& x) {
+    auto y = torch::empty_like(x);
+    auto dw = x.index({"...", 2}).div(2);
+    auto dh = x.index({"...", 3}).div(2);
+    y.index_put_({"...", 0}, x.index({"...", 0}) - dw);
+    y.index_put_({"...", 1}, x.index({"...", 1}) - dh);
+    y.index_put_({"...", 2}, x.index({"...", 0}) + dw);
+    y.index_put_({"...", 3}, x.index({"...", 1}) + dh);
+    return y;
+}
+
+
+// Reference: https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/nms_kernel.cpp
+torch::Tensor nms(const torch::Tensor& bboxes, const torch::Tensor& scores, float iou_threshold) {
+    if (bboxes.numel() == 0)
+        return torch::empty({0}, bboxes.options().dtype(torch::kLong));
+
+    auto x1_t = bboxes.select(1, 0).contiguous();
+    auto y1_t = bboxes.select(1, 1).contiguous();
+    auto x2_t = bboxes.select(1, 2).contiguous();
+    auto y2_t = bboxes.select(1, 3).contiguous();
+
+    torch::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
+
+    auto order_t = std::get<1>(
+        scores.sort(/*stable=*/true, /*dim=*/0, /*descending=*/true));
+
+    auto ndets = bboxes.size(0);
+    torch::Tensor suppressed_t = torch::zeros({ndets}, bboxes.options().dtype(torch::kByte));
+    torch::Tensor keep_t = torch::zeros({ndets}, bboxes.options().dtype(torch::kLong));
+
+    auto suppressed = suppressed_t.data_ptr<uint8_t>();
+    auto keep = keep_t.data_ptr<int64_t>();
+    auto order = order_t.data_ptr<int64_t>();
+    auto x1 = x1_t.data_ptr<float>();
+    auto y1 = y1_t.data_ptr<float>();
+    auto x2 = x2_t.data_ptr<float>();
+    auto y2 = y2_t.data_ptr<float>();
+    auto areas = areas_t.data_ptr<float>();
+
+    int64_t num_to_keep = 0;
+
+    // Greedy suppression: visit boxes in descending score order and drop any
+    // remaining box whose IoU with a kept box exceeds the threshold.
+    for (int64_t _i = 0; _i < ndets; _i++) {
+        auto i = order[_i];
+        if (suppressed[i] == 1)
+            continue;
+        keep[num_to_keep++] = i;
+        auto ix1 = x1[i];
+        auto iy1 = y1[i];
+        auto ix2 = x2[i];
+        auto iy2 = y2[i];
+        auto iarea = areas[i];
+
+        for (int64_t _j = _i + 1; _j < ndets; _j++) {
+            auto j = order[_j];
+            if (suppressed[j] == 1)
+                continue;
+            auto xx1 = std::max(ix1, x1[j]);
+            auto yy1 = std::max(iy1, y1[j]);
+            auto xx2 = std::min(ix2, x2[j]);
+            auto yy2 = std::min(iy2, y2[j]);
+
+            auto w = std::max(static_cast<float>(0), xx2 - xx1);
+            auto h = std::max(static_cast<float>(0), yy2 - yy1);
+            auto inter = w * h;
+            auto ovr = inter / (iarea + areas[j] - inter);
+            if (ovr > iou_threshold)
+                suppressed[j] = 1;
+        }
+    }
+    return keep_t.narrow(0, 0, num_to_keep);
+}
+
+
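+// Decodes the raw model output: keeps candidates above conf_thres, then runs NMS.
+// Each box is offset by class_id * 7680 before NMS so that boxes of different
+// classes never overlap, which makes the single NMS pass effectively per-class
+// (the same max-box-size offset trick used in Ultralytics' Python NMS).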
+torch::Tensor non_max_suppression(torch::Tensor& prediction, float conf_thres = 0.25, float iou_thres = 0.45, int max_det = 300) {
+    auto bs = prediction.size(0);  // batch size
+    auto nc = prediction.size(1) - 4;  // number of classes
+    auto nm = prediction.size(1) - nc - 4;  // number of extra (mask) channels, 0 for detection
+    auto mi = 4 + nc;  // end index of class scores
+    auto xc = prediction.index({Slice(), Slice(4, mi)}).amax(1) > conf_thres;  // candidates
+
+    prediction = prediction.transpose(-1, -2);
+    prediction.index_put_({"...", Slice(None, 4)}, xywh2xyxy(prediction.index({"...", Slice(None, 4)})));
+
+    std::vector<torch::Tensor> output;
+    for (int i = 0; i < bs; i++) {
+        output.push_back(torch::zeros({0, 6 + nm}, prediction.device()));
+    }
+
+    for (int xi = 0; xi < prediction.size(0); xi++) {
+        auto x = prediction[xi];
+        x = x.index({xc[xi]});
+        auto x_split = x.split({4, nc, nm}, 1);
+        auto box = x_split[0], cls = x_split[1], mask = x_split[2];
+        auto [conf, j] = cls.max(1, true);  // best class score and index per box
+        x = torch::cat({box, conf, j.toType(torch::kFloat), mask}, 1);
+        x = x.index({conf.view(-1) > conf_thres});
+        int n = x.size(0);
+        if (!n) { continue; }
+
+        // NMS with class offset, keeping at most max_det detections
+        auto c = x.index({Slice(), Slice(5, 6)}) * 7680;
+        auto boxes = x.index({Slice(), Slice(None, 4)}) + c;
+        auto scores = x.index({Slice(), 4});
+        auto i = nms(boxes, scores, iou_thres);
+        i = i.index({Slice(None, max_det)});
+        output[xi] = x.index({i});
+    }
+
+    return torch::stack(output);
+}
+
+
+// Clamps box coordinates to the image boundary given as (height, width).
+torch::Tensor clip_boxes(torch::Tensor& boxes, const std::vector<int>& shape) {
+    boxes.index_put_({"...", 0}, boxes.index({"...", 0}).clamp(0, shape[1]));
+    boxes.index_put_({"...", 1}, boxes.index({"...", 1}).clamp(0, shape[0]));
+    boxes.index_put_({"...", 2}, boxes.index({"...", 2}).clamp(0, shape[1]));
+    boxes.index_put_({"...", 3}, boxes.index({"...", 3}).clamp(0, shape[0]));
+    return boxes;
+}
+
+
+// Rescales boxes from the letterboxed img1_shape back to the original img0_shape.
+torch::Tensor scale_boxes(const std::vector<int>& img1_shape, torch::Tensor& boxes, const std::vector<int>& img0_shape) {
+    auto gain = (std::min)((float)img1_shape[0] / img0_shape[0], (float)img1_shape[1] / img0_shape[1]);
+    auto pad0 = std::round((float)(img1_shape[1] - img0_shape[1] * gain) / 2. - 0.1);
+    auto pad1 = std::round((float)(img1_shape[0] - img0_shape[0] * gain) / 2. - 0.1);
+
+    boxes.index_put_({"...", 0}, boxes.index({"...", 0}) - pad0);
+    boxes.index_put_({"...", 2}, boxes.index({"...", 2}) - pad0);
+    boxes.index_put_({"...", 1}, boxes.index({"...", 1}) - pad1);
+    boxes.index_put_({"...", 3}, boxes.index({"...", 3}) - pad1);
+    boxes.index_put_({"...", Slice(None, 4)}, boxes.index({"...", Slice(None, 4)}).div(gain));
+    return boxes;
+}
+
+
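+// The exported detect model outputs a tensor of shape (batch, 4 + num_classes, num_anchors),
+// e.g. (1, 84, 8400) for the 80 COCO classes at 640x640 input. main() runs the full
+// pipeline: letterbox -> float tensor -> forward -> non_max_suppression -> scale_boxes.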
+int main() {
+    // Select device: CUDA if available, otherwise CPU
+    torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
+
+    // Note that in this example the class names are hard-coded (the 80 COCO classes)
+    std::vector<std::string> classes {"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant",
+                                      "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
+                                      "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
+                                      "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife",
+                                      "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
+                                      "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
+                                      "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"};
+
+    try {
+        // Load the model (e.g. yolov8s.torchscript)
+        std::string model_path = "/path/to/yolov8s.torchscript";
+        torch::jit::script::Module yolo_model;
+        yolo_model = torch::jit::load(model_path);
+        yolo_model.eval();
+        yolo_model.to(device, torch::kFloat32);
+
+        // Load image and preprocess: letterbox to 640x640, HWC uint8 -> NCHW float32 in [0, 1]
+        cv::Mat image = cv::imread("/path/to/bus.jpg");
+        cv::Mat input_image;
+        letterbox(image, input_image, {640, 640});
+
+        torch::Tensor image_tensor = torch::from_blob(input_image.data, {input_image.rows, input_image.cols, 3}, torch::kByte).to(device);
+        image_tensor = image_tensor.toType(torch::kFloat32).div(255);
+        image_tensor = image_tensor.permute({2, 0, 1});
+        image_tensor = image_tensor.unsqueeze(0);
+        std::vector<torch::jit::IValue> inputs {image_tensor};
+
+        // Inference
+        torch::Tensor output = yolo_model.forward(inputs).toTensor().cpu();
+
+        // NMS, then rescale the surviving boxes to the original image size
+        auto keep = non_max_suppression(output)[0];
+        auto boxes = keep.index({Slice(), Slice(None, 4)});
+        keep.index_put_({Slice(), Slice(None, 4)}, scale_boxes({input_image.rows, input_image.cols}, boxes, {image.rows, image.cols}));
+
+        // Show the results
+        for (int i = 0; i < keep.size(0); i++) {
+            int x1 = keep[i][0].item().toFloat();
+            int y1 = keep[i][1].item().toFloat();
+            int x2 = keep[i][2].item().toFloat();
+            int y2 = keep[i][3].item().toFloat();
+            float conf = keep[i][4].item().toFloat();
+            int cls = keep[i][5].item().toInt();
+            std::cout << "Rect: [" << x1 << "," << y1 << "," << x2 << "," << y2 << "] Conf: " << conf << " Class: " << classes[cls] << std::endl;
+        }
+    } catch (const c10::Error& e) {
+        std::cout << e.msg() << std::endl;
+    }
+
+    return 0;
+}