support yolov8-det normal

triplemu/cpp-refine
triplemu 6 months ago
parent 58a60434b3
commit c4378eafeb
  1. csrc/detect/normal/CMakeLists.txt (79)
  2. csrc/detect/normal/cmake/FindTensorRT.cmake (138)
  3. csrc/detect/normal/cmake/Function.cmake (14)
  4. csrc/detect/normal/include/common.hpp (30)
  5. csrc/detect/normal/include/filesystem.hpp (6075)
  6. csrc/detect/normal/include/yolov8.hpp (115)
  7. csrc/detect/normal/main.cpp (44)

csrc/detect/normal/CMakeLists.txt
@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.1)
cmake_minimum_required(VERSION 3.12)
set(CMAKE_CUDA_ARCHITECTURES 60 61 62 70 72 75 86 89 90)
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
@@ -10,50 +10,79 @@ set(CMAKE_CXX_STANDARD 14)
set(CMAKE_BUILD_TYPE Release)
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
include(Function)
# CUDA
find_package(CUDA REQUIRED)
message(STATUS "CUDA Libs: \n${CUDA_LIBRARIES}\n")
print_var(CUDA_LIBRARIES)
print_var(CUDA_INCLUDE_DIRS)
get_filename_component(CUDA_LIB_DIR ${CUDA_LIBRARIES} DIRECTORY)
message(STATUS "CUDA Headers: \n${CUDA_INCLUDE_DIRS}\n")
print_var(CUDA_LIB_DIR)
# OpenCV
find_package(OpenCV REQUIRED)
message(STATUS "OpenCV Libs: \n${OpenCV_LIBS}\n")
message(STATUS "OpenCV Libraries: \n${OpenCV_LIBRARIES}\n")
message(STATUS "OpenCV Headers: \n${OpenCV_INCLUDE_DIRS}\n")
print_var(OpenCV_LIBS)
print_var(OpenCV_LIBRARIES)
print_var(OpenCV_INCLUDE_DIRS)
# TensorRT
set(TensorRT_INCLUDE_DIRS /usr/include/x86_64-linux-gnu)
set(TensorRT_LIBRARIES /usr/lib/x86_64-linux-gnu)
message(STATUS "TensorRT Libs: \n${TensorRT_LIBRARIES}\n")
message(STATUS "TensorRT Headers: \n${TensorRT_INCLUDE_DIRS}\n")
find_package(TensorRT REQUIRED)
print_var(TensorRT_LIBRARIES)
print_var(TensorRT_INCLUDE_DIRS)
print_var(TensorRT_LIB_DIR)
if (TensorRT_VERSION_MAJOR GREATER_EQUAL 10)
message(STATUS "Build with -DTRT_10")
add_definitions(-DTRT_10)
endif ()
list(APPEND INCLUDE_DIRS
list(APPEND ALL_INCLUDE_DIRS
${CUDA_INCLUDE_DIRS}
${OpenCV_INCLUDE_DIRS}
${TensorRT_INCLUDE_DIRS}
include
)
${CMAKE_CURRENT_SOURCE_DIR}/include
)
list(APPEND ALL_LIBS
${CUDA_LIBRARIES}
${CUDA_LIB_DIR}
${OpenCV_LIBRARIES}
${TensorRT_LIBRARIES}
)
)
list(APPEND ALL_LIB_DIRS
${CUDA_LIB_DIR}
${TensorRT_LIB_DIR}
)
print_var(ALL_INCLUDE_DIRS)
print_var(ALL_LIBS)
print_var(ALL_LIB_DIRS)
add_executable(
${PROJECT_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/main.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/yolov8.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/common.hpp
)
include_directories(${INCLUDE_DIRS})
target_include_directories(
${PROJECT_NAME}
PUBLIC
${ALL_INCLUDE_DIRS}
)
add_executable(${PROJECT_NAME}
main.cpp
include/yolov8.hpp
include/common.hpp
)
target_link_directories(
${PROJECT_NAME}
PUBLIC
${ALL_LIB_DIRS}
)
target_link_directories(${PROJECT_NAME} PUBLIC ${ALL_LIBS})
target_link_libraries(${PROJECT_NAME} PRIVATE nvinfer nvinfer_plugin cudart ${OpenCV_LIBS})
target_link_libraries(
${PROJECT_NAME}
PRIVATE
${ALL_LIBS}
)
if (${OpenCV_VERSION} VERSION_GREATER_EQUAL 4.7.0)
message(STATUS "Build with -DBATCHED_NMS")

csrc/detect/normal/cmake/FindTensorRT.cmake
@@ -0,0 +1,138 @@
# This module defines the following variables:
#
# ::
#
# TensorRT_INCLUDE_DIRS
# TensorRT_LIBRARIES
# TensorRT_FOUND
#
# ::
#
# TensorRT_VERSION_STRING - version (x.y.z)
# TensorRT_VERSION_MAJOR - major version (x)
# TensorRT_VERSION_MINOR - minor version (y)
# TensorRT_VERSION_PATCH - patch version (z)
#
# Hints
# ^^^^^
# A user may set ``TensorRT_ROOT`` to an installation root to tell this module where to look.
#
set(_TensorRT_SEARCHES)
if(TensorRT_ROOT)
set(_TensorRT_SEARCH_ROOT PATHS ${TensorRT_ROOT} NO_DEFAULT_PATH)
list(APPEND _TensorRT_SEARCHES _TensorRT_SEARCH_ROOT)
endif()
# appends some common paths
set(_TensorRT_SEARCH_NORMAL
PATHS "/usr"
)
list(APPEND _TensorRT_SEARCHES _TensorRT_SEARCH_NORMAL)
# Include dir
foreach(search ${_TensorRT_SEARCHES})
find_path(TensorRT_INCLUDE_DIR NAMES NvInfer.h ${${search}} PATH_SUFFIXES include)
endforeach()
if(NOT TensorRT_LIBRARY)
foreach(search ${_TensorRT_SEARCHES})
find_library(TensorRT_LIBRARY NAMES nvinfer ${${search}} PATH_SUFFIXES lib)
if(NOT TensorRT_LIB_DIR)
get_filename_component(TensorRT_LIB_DIR ${TensorRT_LIBRARY} DIRECTORY)
endif ()
endforeach()
endif()
if(NOT TensorRT_nvinfer_plugin_LIBRARY)
foreach(search ${_TensorRT_SEARCHES})
find_library(TensorRT_nvinfer_plugin_LIBRARY NAMES nvinfer_plugin ${${search}} PATH_SUFFIXES lib)
endforeach()
endif()
mark_as_advanced(TensorRT_INCLUDE_DIR)
if(TensorRT_INCLUDE_DIR AND EXISTS "${TensorRT_INCLUDE_DIR}/NvInfer.h")
if(EXISTS "${TensorRT_INCLUDE_DIR}/NvInferVersion.h")
set(_VersionSearchFile "${TensorRT_INCLUDE_DIR}/NvInferVersion.h")
else ()
set(_VersionSearchFile "${TensorRT_INCLUDE_DIR}/NvInfer.h")
endif ()
file(STRINGS "${_VersionSearchFile}" TensorRT_MAJOR REGEX "^#define NV_TENSORRT_MAJOR [0-9]+.*$")
file(STRINGS "${_VersionSearchFile}" TensorRT_MINOR REGEX "^#define NV_TENSORRT_MINOR [0-9]+.*$")
file(STRINGS "${_VersionSearchFile}" TensorRT_PATCH REGEX "^#define NV_TENSORRT_PATCH [0-9]+.*$")
string(REGEX REPLACE "^#define NV_TENSORRT_MAJOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MAJOR "${TensorRT_MAJOR}")
string(REGEX REPLACE "^#define NV_TENSORRT_MINOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MINOR "${TensorRT_MINOR}")
string(REGEX REPLACE "^#define NV_TENSORRT_PATCH ([0-9]+).*$" "\\1" TensorRT_VERSION_PATCH "${TensorRT_PATCH}")
set(TensorRT_VERSION_STRING "${TensorRT_VERSION_MAJOR}.${TensorRT_VERSION_MINOR}.${TensorRT_VERSION_PATCH}")
endif()
include(FindPackageHandleStandardArgs)
FIND_PACKAGE_HANDLE_STANDARD_ARGS(TensorRT REQUIRED_VARS TensorRT_LIBRARY TensorRT_INCLUDE_DIR VERSION_VAR TensorRT_VERSION_STRING)
if(TensorRT_FOUND)
set(TensorRT_INCLUDE_DIRS ${TensorRT_INCLUDE_DIR})
if(NOT TensorRT_LIBRARIES)
set(TensorRT_LIBRARIES ${TensorRT_LIBRARY})
if (TensorRT_nvinfer_plugin_LIBRARY)
list(APPEND TensorRT_LIBRARIES ${TensorRT_nvinfer_plugin_LIBRARY})
endif()
endif()
if(NOT TARGET TensorRT::TensorRT)
add_library(TensorRT INTERFACE IMPORTED)
add_library(TensorRT::TensorRT ALIAS TensorRT)
endif()
if(NOT TARGET TensorRT::nvinfer)
add_library(TensorRT::nvinfer SHARED IMPORTED)
if (WIN32)
foreach(search ${_TensorRT_SEARCHES})
find_file(TensorRT_LIBRARY_DLL
NAMES nvinfer.dll
PATHS ${${search}}
PATH_SUFFIXES bin
)
endforeach()
set_target_properties(TensorRT::nvinfer PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}"
IMPORTED_LOCATION "${TensorRT_LIBRARY_DLL}"
IMPORTED_IMPLIB "${TensorRT_LIBRARY}"
)
else()
set_target_properties(TensorRT::nvinfer PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}"
IMPORTED_LOCATION "${TensorRT_LIBRARY}"
)
endif()
target_link_libraries(TensorRT INTERFACE TensorRT::nvinfer)
endif()
if(NOT TARGET TensorRT::nvinfer_plugin AND TensorRT_nvinfer_plugin_LIBRARY)
add_library(TensorRT::nvinfer_plugin SHARED IMPORTED)
if (WIN32)
foreach(search ${_TensorRT_SEARCHES})
find_file(TensorRT_nvinfer_plugin_LIBRARY_DLL
NAMES nvinfer_plugin.dll
PATHS ${${search}}
PATH_SUFFIXES bin
)
endforeach()
set_target_properties(TensorRT::nvinfer_plugin PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}"
IMPORTED_LOCATION "${TensorRT_nvinfer_plugin_LIBRARY_DLL}"
IMPORTED_IMPLIB "${TensorRT_nvinfer_plugin_LIBRARY}"
)
else()
set_target_properties(TensorRT::nvinfer_plugin PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}"
IMPORTED_LOCATION "${TensorRT_nvinfer_plugin_LIBRARY}"
)
endif()
target_link_libraries(TensorRT INTERFACE TensorRT::nvinfer_plugin)
endif()
endif()
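For context, a minimal sketch of how a project could consume this find module; the project name, source file, and TensorRT_ROOT path below are illustrative assumptions, not part of the commit:

cmake_minimum_required(VERSION 3.12)
project(trt_demo LANGUAGES CXX)
# Make the bundled module discoverable; optionally hint a non-default install root.
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
# set(TensorRT_ROOT /opt/TensorRT)  # optional hint, as documented in the module header
find_package(TensorRT REQUIRED)
message(STATUS "TensorRT ${TensorRT_VERSION_STRING} found in ${TensorRT_INCLUDE_DIRS}")
add_executable(trt_demo main.cpp)
# Link the imported target the module creates (pulls in nvinfer and, if found, nvinfer_plugin) ...
target_link_libraries(trt_demo PRIVATE TensorRT::TensorRT)
# ... or fall back to the plain variables it also sets, as this repository's CMakeLists.txt does:
# target_include_directories(trt_demo PUBLIC ${TensorRT_INCLUDE_DIRS})
# target_link_directories(trt_demo PUBLIC ${TensorRT_LIB_DIR})
# target_link_libraries(trt_demo PRIVATE ${TensorRT_LIBRARIES})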

csrc/detect/normal/cmake/Function.cmake
@@ -0,0 +1,14 @@
function(print_var var)
set(value "${${var}}")
string(LENGTH "${value}" value_length)
if(value_length GREATER 0)
math(EXPR last_index "${value_length} - 1")
string(SUBSTRING "${value}" ${last_index} ${last_index} last_char)
endif()
if(NOT "${last_char}" STREQUAL "\n")
set(value "${value}\n")
endif()
message(STATUS "${var}:\n ${value}")
endfunction()
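A quick usage illustration for print_var; the variable names and printed values are only examples:

# After a find_package() call, dump what was discovered:
find_package(OpenCV REQUIRED)
print_var(OpenCV_VERSION)      # prints e.g. "-- OpenCV_VERSION:" followed by " 4.8.0"
print_var(OpenCV_INCLUDE_DIRS) # lists are printed as a single semicolon-joined value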

csrc/detect/normal/include/common.hpp
@@ -5,9 +5,8 @@
#ifndef DETECT_NORMAL_COMMON_HPP
#define DETECT_NORMAL_COMMON_HPP
#include "NvInfer.h"
#include "filesystem.hpp"
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#define CHECK(call) \
do { \
@@ -89,33 +88,6 @@ inline static float clamp(float val, float min, float max)
return val > min ? (val < max ? val : max) : min;
}
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK) {
return true;
}
return false;
}
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path)) {
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode));
}
inline bool IsFolder(const std::string& path)
{
if (!IsPathExist(path)) {
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
namespace det {
struct Binding {
size_t size = 1;

csrc/detect/normal/include/filesystem.hpp
File diff suppressed because it is too large.

csrc/detect/normal/include/yolov8.hpp
@@ -5,7 +5,8 @@
#define DETECT_NORMAL_YOLOV8_HPP
#include "NvInferPlugin.h"
#include "common.hpp"
#include "fstream"
#include <fstream>
using namespace det;
class YOLOv8 {
@@ -68,28 +69,52 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
assert(this->context != nullptr);
cudaStreamCreate(&this->stream);
this->num_bindings = this->engine->getNbBindings();
#ifdef TRT_10
this->num_bindings = this->engine->getNbIOTensors();
#else
this->num_bindings = this->engine->getNbBindings();
#endif
for (int i = 0; i < this->num_bindings; ++i) {
Binding binding;
nvinfer1::Dims dims;
Binding binding;
nvinfer1::Dims dims;
#ifdef TRT_10
std::string name = this->engine->getIOTensorName(i);
nvinfer1::DataType dtype = this->engine->getTensorDataType(name.c_str());
#else
nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
std::string name = this->engine->getBindingName(i);
binding.name = name;
binding.dsize = type_to_size(dtype);
#endif
binding.name = name;
binding.dsize = type_to_size(dtype);
#ifdef TRT_10
bool IsInput = engine->getTensorIOMode(name.c_str()) == nvinfer1::TensorIOMode::kINPUT;
#else
bool IsInput = engine->bindingIsInput(i);
#endif
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
#ifdef TRT_10
dims = this->engine->getProfileShape(name.c_str(), 0, nvinfer1::OptProfileSelector::kMAX);
// set max opt shape
this->context->setInputShape(name.c_str(), dims);
#else
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
// set max opt shape
this->context->setBindingDimensions(i, dims);
#endif
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
}
else {
dims = this->context->getBindingDimensions(i);
#ifdef TRT_10
dims = this->context->getTensorShape(name.c_str());
#else
dims = this->context->getBindingDimensions(i);
#endif
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->output_bindings.push_back(binding);
@@ -100,9 +125,15 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
YOLOv8::~YOLOv8()
{
#ifdef TRT_10
delete this->context;
delete this->engine;
delete this->runtime;
#else
this->context->destroy();
this->engine->destroy();
this->runtime->destroy();
#endif
cudaStreamDestroy(this->stream);
for (auto& ptr : this->device_ptrs) {
CHECK(cudaFree(ptr));
@@ -119,6 +150,12 @@ void YOLOv8::make_pipe(bool warmup)
void* d_ptr;
CHECK(cudaMallocAsync(&d_ptr, bindings.size * bindings.dsize, this->stream));
this->device_ptrs.push_back(d_ptr);
#ifdef TRT_10
auto name = bindings.name.c_str();
this->context->setInputShape(name, bindings.dims);
this->context->setTensorAddress(name, d_ptr);
#endif
}
for (auto& bindings : this->output_bindings) {
@@ -128,6 +165,11 @@ void YOLOv8::make_pipe(bool warmup)
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
#ifdef TRT_10
auto name = bindings.name.c_str();
this->context->setTensorAddress(name, d_ptr);
#endif
}
if (warmup) {
@@ -176,7 +218,19 @@ void YOLOv8::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, {114, 114, 114});
cv::dnn::blobFromImage(tmp, out, 1 / 255.f, cv::Size(), cv::Scalar(0, 0, 0), true, false, CV_32F);
out.create({1, 3, (int)inp_h, (int)inp_w}, CV_32F);
std::vector<cv::Mat> channels;
cv::split(tmp, channels);
cv::Mat c0((int)inp_h, (int)inp_w, CV_32F, (float*)out.data);
cv::Mat c1((int)inp_h, (int)inp_w, CV_32F, (float*)out.data + (int)inp_h * (int)inp_w);
cv::Mat c2((int)inp_h, (int)inp_w, CV_32F, (float*)out.data + (int)inp_h * (int)inp_w * 2);
channels[0].convertTo(c2, CV_32F, 1 / 255.f);
channels[1].convertTo(c1, CV_32F, 1 / 255.f);
channels[2].convertTo(c0, CV_32F, 1 / 255.f);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
@@ -189,30 +243,47 @@ void YOLOv8::copy_from_Mat(const cv::Mat& image)
{
cv::Mat nchw;
auto& in_binding = this->input_bindings[0];
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
int width = in_binding.dims.d[3];
int height = in_binding.dims.d[2];
cv::Size size{width, height};
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
#ifdef TRT_10
auto name = this->input_bindings[0].name.c_str();
this->context->setInputShape(name, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
this->context->setTensorAddress(name, this->device_ptrs[0]);
#else
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
#endif
}
void YOLOv8::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
cv::Mat nchw;
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
#ifdef TRT_10
auto name = this->input_bindings[0].name.c_str();
this->context->setInputShape(name, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
this->context->setTensorAddress(name, this->device_ptrs[0]);
#else
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
#endif
}
void YOLOv8::infer()
{
#ifdef TRT_10
this->context->enqueueV3(this->stream);
#else
this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr);
#endif
for (int i = 0; i < this->num_outputs; i++) {
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(
@@ -224,8 +295,8 @@ void YOLOv8::infer()
void YOLOv8::postprocess(std::vector<Object>& objs, float score_thres, float iou_thres, int topk, int num_labels)
{
objs.clear();
auto num_channels = this->output_bindings[0].dims.d[1];
auto num_anchors = this->output_bindings[0].dims.d[2];
int num_channels = this->output_bindings[0].dims.d[1];
int num_anchors = this->output_bindings[0].dims.d[2];
auto& dw = this->pparam.dw;
auto& dh = this->pparam.dh;
@@ -310,9 +381,9 @@ void YOLOv8::draw_objects(const cv::Mat& image,
int x = (int)obj.rect.x;
int y = (int)obj.rect.y + 1;
if (y > res.rows)
if (y > res.rows) {
y = res.rows;
}
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);

csrc/detect/normal/main.cpp
@@ -1,9 +1,11 @@
//
// Created by ubuntu on 1/20/23.
//
#include "chrono"
#include "opencv2/opencv.hpp"
#include "yolov8.hpp"
#include <chrono>
namespace fs = ghc::filesystem;
const std::vector<std::string> CLASS_NAMES = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
@@ -37,27 +39,30 @@ const std::vector<std::vector<unsigned int>> COLORS = {
int main(int argc, char** argv)
{
if (argc != 3) {
fprintf(stderr, "Usage: %s [engine_path] [image_path/image_dir/video_path]\n", argv[0]);
return -1;
}
// cuda:0
cudaSetDevice(0);
const std::string engine_file_path{argv[1]};
const std::string path{argv[2]};
const fs::path path{argv[2]};
std::vector<std::string> imagePathList;
bool isVideo{false};
assert(argc == 3);
auto yolov8 = new YOLOv8(engine_file_path);
yolov8->make_pipe(true);
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
if (fs::exists(path)) {
std::string suffix = path.extension();
if (suffix == ".jpg" || suffix == ".jpeg" || suffix == ".png") {
imagePathList.push_back(path);
}
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
else if (suffix == ".mp4" || suffix == ".avi" || suffix == ".m4v" || suffix == ".mpeg" || suffix == ".mov"
|| suffix == ".mkv") {
isVideo = true;
}
else {
@@ -65,17 +70,12 @@ int main(int argc, char** argv)
std::abort();
}
}
else if (IsFolder(path)) {
cv::glob(path + "/*.jpg", imagePathList);
else if (fs::is_directory(path)) {
cv::glob(path.string() + "/*.jpg", imagePathList);
}
cv::Mat res, image;
cv::Size size = cv::Size{640, 640};
int num_labels = 80;
int topk = 100;
float score_thres = 0.25f;
float iou_thres = 0.65f;
cv::Mat res, image;
cv::Size size = cv::Size{640, 640};
std::vector<Object> objs;
cv::namedWindow("result", cv::WINDOW_AUTOSIZE);
@@ -93,7 +93,7 @@ int main(int argc, char** argv)
auto start = std::chrono::system_clock::now();
yolov8->infer();
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, num_labels);
yolov8->postprocess(objs);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS);
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
@@ -104,14 +104,14 @@ int main(int argc, char** argv)
}
}
else {
for (auto& path : imagePathList) {
for (auto& p : imagePathList) {
objs.clear();
image = cv::imread(path);
image = cv::imread(p);
yolov8->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
yolov8->infer();
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, num_labels);
yolov8->postprocess(objs);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS);
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
