add obb-cpp support

pull/270/head
triplemu 1 month ago
parent c1e76b6dd5
commit 1104e99629
  1. csrc/obb/normal/CMakeLists.txt (84 lines)
  2. csrc/obb/normal/cmake/FindTensorRT.cmake (138 lines)
  3. csrc/obb/normal/cmake/Function.cmake (14 lines)
  4. csrc/obb/normal/include/common.hpp (113 lines)
  5. csrc/obb/normal/include/filesystem.hpp (6075 lines)
  6. csrc/obb/normal/include/yolov8-obb.hpp (339 lines)
  7. csrc/obb/normal/main.cpp (134 lines)
  8. csrc/pose/normal/main.cpp (6 lines)
  9. csrc/segment/normal/main.cpp (6 lines)
  10. csrc/segment/simple/main.cpp (6 lines)
  11. docs/Obb.md (7 lines)

csrc/obb/normal/CMakeLists.txt
@@ -0,0 +1,84 @@
cmake_minimum_required(VERSION 3.1)
set(CMAKE_CUDA_ARCHITECTURES 60 61 62 70 72 75 86 89 90)
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
project(yolov8-obb LANGUAGES CXX CUDA)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -O3")
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_BUILD_TYPE Release)
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
include(Function)
# CUDA
find_package(CUDA REQUIRED)
print_var(CUDA_LIBRARIES)
print_var(CUDA_INCLUDE_DIRS)
get_filename_component(CUDA_LIB_DIR ${CUDA_LIBRARIES} DIRECTORY)
print_var(CUDA_LIB_DIR)
# OpenCV
find_package(OpenCV REQUIRED)
print_var(OpenCV_LIBS)
print_var(OpenCV_LIBRARIES)
print_var(OpenCV_INCLUDE_DIRS)
# TensorRT
find_package(TensorRT REQUIRED)
print_var(TensorRT_LIBRARIES)
print_var(TensorRT_INCLUDE_DIRS)
print_var(TensorRT_LIB_DIR)
if (TensorRT_VERSION_MAJOR GREATER_EQUAL 10)
message(STATUS "Build with -DTRT_10")
add_definitions(-DTRT_10)
endif ()
list(APPEND ALL_INCLUDE_DIRS
${CUDA_INCLUDE_DIRS}
${OpenCV_INCLUDE_DIRS}
${TensorRT_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}/include
)
list(APPEND ALL_LIBS
${CUDA_LIBRARIES}
${OpenCV_LIBRARIES}
${TensorRT_LIBRARIES}
)
list(APPEND ALL_LIB_DIRS
${CUDA_LIB_DIR}
${TensorRT_LIB_DIR}
)
print_var(ALL_INCLUDE_DIRS)
print_var(ALL_LIBS)
print_var(ALL_LIB_DIRS)
add_executable(${PROJECT_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/main.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/yolov8-obb.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/common.hpp
)
target_include_directories(
${PROJECT_NAME}
PUBLIC
${ALL_INCLUDE_DIRS}
)
target_link_directories(
${PROJECT_NAME}
PUBLIC
${ALL_LIB_DIRS}
)
target_link_libraries(
${PROJECT_NAME}
PRIVATE
${ALL_LIBS}
)
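# Note (illustrative only, not part of the original file): one way to configure and
# build this target; the TensorRT path is a placeholder and -DTensorRT_ROOT is only
# needed when TensorRT is not in a default location (see cmake/FindTensorRT.cmake).
#
#   mkdir build && cd build
#   cmake .. -DTensorRT_ROOT=/path/to/TensorRT
#   make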

csrc/obb/normal/cmake/FindTensorRT.cmake
@@ -0,0 +1,138 @@
# This module defines the following variables:
#
# ::
#
# TensorRT_INCLUDE_DIRS
# TensorRT_LIBRARIES
# TensorRT_FOUND
#
# ::
#
# TensorRT_VERSION_STRING - version (x.y.z)
# TensorRT_VERSION_MAJOR - major version (x)
# TensorRT_VERSION_MINOR - minor version (y)
# TensorRT_VERSION_PATCH - patch version (z)
#
# Hints
# ^^^^^
# A user may set ``TensorRT_ROOT`` to an installation root to tell this module where to look.
#
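# Example usage (an illustrative sketch; ``my_app`` is a placeholder target name,
# and the CMakeLists.txt in this directory links via the ``TensorRT_*`` variables instead):
#
#   list(APPEND CMAKE_MODULE_PATH "path/to/this/cmake/dir")
#   find_package(TensorRT REQUIRED)
#   target_link_libraries(my_app PRIVATE TensorRT::TensorRT)  # pulls in nvinfer and, if found, nvinfer_plugin
#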
set(_TensorRT_SEARCHES)
if(TensorRT_ROOT)
set(_TensorRT_SEARCH_ROOT PATHS ${TensorRT_ROOT} NO_DEFAULT_PATH)
list(APPEND _TensorRT_SEARCHES _TensorRT_SEARCH_ROOT)
endif()
# appends some common paths
set(_TensorRT_SEARCH_NORMAL
PATHS "/usr"
)
list(APPEND _TensorRT_SEARCHES _TensorRT_SEARCH_NORMAL)
# Include dir
foreach(search ${_TensorRT_SEARCHES})
find_path(TensorRT_INCLUDE_DIR NAMES NvInfer.h ${${search}} PATH_SUFFIXES include)
endforeach()
if(NOT TensorRT_LIBRARY)
foreach(search ${_TensorRT_SEARCHES})
find_library(TensorRT_LIBRARY NAMES nvinfer ${${search}} PATH_SUFFIXES lib)
if(NOT TensorRT_LIB_DIR)
get_filename_component(TensorRT_LIB_DIR ${TensorRT_LIBRARY} DIRECTORY)
endif ()
endforeach()
endif()
if(NOT TensorRT_nvinfer_plugin_LIBRARY)
foreach(search ${_TensorRT_SEARCHES})
find_library(TensorRT_nvinfer_plugin_LIBRARY NAMES nvinfer_plugin ${${search}} PATH_SUFFIXES lib)
endforeach()
endif()
mark_as_advanced(TensorRT_INCLUDE_DIR)
if(TensorRT_INCLUDE_DIR AND EXISTS "${TensorRT_INCLUDE_DIR}/NvInfer.h")
if(EXISTS "${TensorRT_INCLUDE_DIR}/NvInferVersion.h")
set(_VersionSearchFile "${TensorRT_INCLUDE_DIR}/NvInferVersion.h")
else ()
set(_VersionSearchFile "${TensorRT_INCLUDE_DIR}/NvInfer.h")
endif ()
file(STRINGS "${_VersionSearchFile}" TensorRT_MAJOR REGEX "^#define NV_TENSORRT_MAJOR [0-9]+.*$")
file(STRINGS "${_VersionSearchFile}" TensorRT_MINOR REGEX "^#define NV_TENSORRT_MINOR [0-9]+.*$")
file(STRINGS "${_VersionSearchFile}" TensorRT_PATCH REGEX "^#define NV_TENSORRT_PATCH [0-9]+.*$")
string(REGEX REPLACE "^#define NV_TENSORRT_MAJOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MAJOR "${TensorRT_MAJOR}")
string(REGEX REPLACE "^#define NV_TENSORRT_MINOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MINOR "${TensorRT_MINOR}")
string(REGEX REPLACE "^#define NV_TENSORRT_PATCH ([0-9]+).*$" "\\1" TensorRT_VERSION_PATCH "${TensorRT_PATCH}")
set(TensorRT_VERSION_STRING "${TensorRT_VERSION_MAJOR}.${TensorRT_VERSION_MINOR}.${TensorRT_VERSION_PATCH}")
endif()
include(FindPackageHandleStandardArgs)
FIND_PACKAGE_HANDLE_STANDARD_ARGS(TensorRT REQUIRED_VARS TensorRT_LIBRARY TensorRT_INCLUDE_DIR VERSION_VAR TensorRT_VERSION_STRING)
if(TensorRT_FOUND)
set(TensorRT_INCLUDE_DIRS ${TensorRT_INCLUDE_DIR})
if(NOT TensorRT_LIBRARIES)
set(TensorRT_LIBRARIES ${TensorRT_LIBRARY})
if (TensorRT_nvinfer_plugin_LIBRARY)
list(APPEND TensorRT_LIBRARIES ${TensorRT_nvinfer_plugin_LIBRARY})
endif()
endif()
if(NOT TARGET TensorRT::TensorRT)
add_library(TensorRT INTERFACE IMPORTED)
add_library(TensorRT::TensorRT ALIAS TensorRT)
endif()
if(NOT TARGET TensorRT::nvinfer)
add_library(TensorRT::nvinfer SHARED IMPORTED)
if (WIN32)
foreach(search ${_TensorRT_SEARCHES})
find_file(TensorRT_LIBRARY_DLL
NAMES nvinfer.dll
PATHS ${${search}}
PATH_SUFFIXES bin
)
endforeach()
set_target_properties(TensorRT::nvinfer PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}"
IMPORTED_LOCATION "${TensorRT_LIBRARY_DLL}"
IMPORTED_IMPLIB "${TensorRT_LIBRARY}"
)
else()
set_target_properties(TensorRT::nvinfer PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}"
IMPORTED_LOCATION "${TensorRT_LIBRARY}"
)
endif()
target_link_libraries(TensorRT INTERFACE TensorRT::nvinfer)
endif()
if(NOT TARGET TensorRT::nvinfer_plugin AND TensorRT_nvinfer_plugin_LIBRARY)
add_library(TensorRT::nvinfer_plugin SHARED IMPORTED)
if (WIN32)
foreach(search ${_TensorRT_SEARCHES})
find_file(TensorRT_nvinfer_plugin_LIBRARY_DLL
NAMES nvinfer_plugin.dll
PATHS ${${search}}
PATH_SUFFIXES bin
)
endforeach()
set_target_properties(TensorRT::nvinfer_plugin PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}"
IMPORTED_LOCATION "${TensorRT_nvinfer_plugin_LIBRARY_DLL}"
IMPORTED_IMPLIB "${TensorRT_nvinfer_plugin_LIBRARY}"
)
else()
set_target_properties(TensorRT::nvinfer_plugin PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}"
IMPORTED_LOCATION "${TensorRT_nvinfer_plugin_LIBRARY}"
)
endif()
target_link_libraries(TensorRT INTERFACE TensorRT::nvinfer_plugin)
endif()
endif()

csrc/obb/normal/cmake/Function.cmake
@@ -0,0 +1,14 @@
function(print_var var)
set(value "${${var}}")
string(LENGTH "${value}" value_length)
if(value_length GREATER 0)
math(EXPR last_index "${value_length} - 1")
string(SUBSTRING "${value}" ${last_index} 1 last_char)
endif()
if(NOT "${last_char}" STREQUAL "\n")
set(value "${value}\n")
endif()
message(STATUS "${var}:\n ${value}")
endfunction()

csrc/obb/normal/include/common.hpp
@@ -0,0 +1,113 @@
//
// Created by ubuntu on 4/7/23.
//
#ifndef OBB_NORMAL_COMMON_HPP
#define OBB_NORMAL_COMMON_HPP
#include "NvInfer.h"
#include "filesystem.hpp"
#include "opencv2/opencv.hpp"
#define CHECK(call) \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) { \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
class Logger: public nvinfer1::ILogger {
public:
nvinfer1::ILogger::Severity reportableSeverity;
explicit Logger(nvinfer1::ILogger::Severity severity = nvinfer1::ILogger::Severity::kINFO):
reportableSeverity(severity)
{
}
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity) {
return;
}
switch (severity) {
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
case nvinfer1::ILogger::Severity::kERROR:
std::cerr << "ERROR: ";
break;
case nvinfer1::ILogger::Severity::kWARNING:
std::cerr << "WARNING: ";
break;
case nvinfer1::ILogger::Severity::kINFO:
std::cerr << "INFO: ";
break;
default:
std::cerr << "VERBOSE: ";
break;
}
std::cerr << msg << std::endl;
}
};
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++) {
size *= dims.d[i];
}
return size;
}
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType) {
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
return 2;
case nvinfer1::DataType::kINT32:
return 4;
case nvinfer1::DataType::kINT8:
return 1;
case nvinfer1::DataType::kBOOL:
return 1;
default:
return 4;
}
}
inline static float clamp(float val, float min, float max)
{
return val > min ? (val < max ? val : max) : min;
}
namespace obb {
struct Binding {
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object {
cv::RotatedRect rect;
int label = 0;
float prob = 0.0;
};
struct PreParam {
float ratio = 1.0f;
float dw = 0.0f;
float dh = 0.0f;
float height = 0;
float width = 0;
};
} // namespace obb
#endif // OBB_NORMAL_COMMON_HPP

csrc/obb/normal/include/filesystem.hpp
(File diff suppressed because it is too large.)

csrc/obb/normal/include/yolov8-obb.hpp
@@ -0,0 +1,339 @@
//
// Created by ubuntu on 10/25/24.
//
#ifndef OBB_NORMAL_YOLOv8_obb_HPP
#define OBB_NORMAL_YOLOv8_obb_HPP
#include "NvInferPlugin.h"
#include "common.hpp"
#include <fstream>
using namespace obb;
class YOLOv8_obb {
public:
explicit YOLOv8_obb(const std::string& engine_file_path);
~YOLOv8_obb();
void make_pipe(bool warmup = true);
void copy_from_Mat(const cv::Mat& image);
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size);
void infer();
void postprocess(std::vector<Object>& objs,
float score_thres = 0.25f,
float iou_thres = 0.65f,
int topk = 100,
int num_labels = 15);
static void draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS);
int num_bindings;
int num_inputs = 0;
int num_outputs = 0;
std::vector<Binding> input_bindings;
std::vector<Binding> output_bindings;
std::vector<void*> host_ptrs;
std::vector<void*> device_ptrs;
PreParam pparam;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{nvinfer1::ILogger::Severity::kERROR};
};
YOLOv8_obb::YOLOv8_obb(const std::string& engine_file_path)
{
std::ifstream file(engine_file_path, std::ios::binary);
assert(file.good());
file.seekg(0, std::ios::end);
auto size = file.tellg();
file.seekg(0, std::ios::beg);
char* trtModelStream = new char[size];
assert(trtModelStream);
file.read(trtModelStream, size);
file.close();
initLibNvInferPlugins(&this->gLogger, "");
this->runtime = nvinfer1::createInferRuntime(this->gLogger);
assert(this->runtime != nullptr);
this->engine = this->runtime->deserializeCudaEngine(trtModelStream, size);
assert(this->engine != nullptr);
delete[] trtModelStream;
this->context = this->engine->createExecutionContext();
assert(this->context != nullptr);
cudaStreamCreate(&this->stream);
this->num_bindings = this->engine->getNbBindings();
for (int i = 0; i < this->num_bindings; ++i) {
Binding binding;
nvinfer1::Dims dims;
nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
std::string name = this->engine->getBindingName(i);
binding.name = name;
binding.dsize = type_to_size(dtype);
bool IsInput = engine->bindingIsInput(i);
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
}
else {
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->output_bindings.push_back(binding);
this->num_outputs += 1;
}
}
}
YOLOv8_obb::~YOLOv8_obb()
{
this->context->destroy();
this->engine->destroy();
this->runtime->destroy();
cudaStreamDestroy(this->stream);
for (auto& ptr : this->device_ptrs) {
CHECK(cudaFree(ptr));
}
for (auto& ptr : this->host_ptrs) {
CHECK(cudaFreeHost(ptr));
}
}
void YOLOv8_obb::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings) {
void* d_ptr;
CHECK(cudaMallocAsync(&d_ptr, bindings.size * bindings.dsize, this->stream));
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings) {
void * d_ptr, *h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMallocAsync(&d_ptr, size, this->stream));
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
if (warmup) {
for (int i = 0; i < 10; i++) {
for (auto& bindings : this->input_bindings) {
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream));
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
void YOLOv8_obb::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
{
const float inp_h = size.height;
const float inp_w = size.width;
float height = image.rows;
float width = image.cols;
float r = std::min(inp_h / height, inp_w / width);
int padw = std::round(width * r);
int padh = std::round(height * r);
cv::Mat tmp;
if ((int)width != padw || (int)height != padh) {
cv::resize(image, tmp, cv::Size(padw, padh));
}
else {
tmp = image.clone();
}
float dw = inp_w - padw;
float dh = inp_h - padh;
dw /= 2.0f;
dh /= 2.0f;
int top = int(std::round(dh - 0.1f));
int bottom = int(std::round(dh + 0.1f));
int left = int(std::round(dw - 0.1f));
int right = int(std::round(dw + 0.1f));
cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, {114, 114, 114});
cv::dnn::blobFromImage(tmp, out, 1 / 255.f, cv::Size(), cv::Scalar(0, 0, 0), true, false, CV_32F);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
this->pparam.height = height;
this->pparam.width = width;
}
void YOLOv8_obb::copy_from_Mat(const cv::Mat& image)
{
cv::Mat nchw;
auto& in_binding = this->input_bindings[0];
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
cv::Size size{width, height};
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_obb::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
cv::Mat nchw;
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_obb::infer()
{
this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr);
for (int i = 0; i < this->num_outputs; i++) {
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(
this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream));
}
cudaStreamSynchronize(this->stream);
}
void YOLOv8_obb::postprocess(std::vector<Object>& objs, float score_thres, float iou_thres, int topk, int num_labels)
{
objs.clear();
auto num_channels = this->output_bindings[0].dims.d[1];
auto num_anchors = this->output_bindings[0].dims.d[2];
auto& dw = this->pparam.dw;
auto& dh = this->pparam.dh;
auto& width = this->pparam.width;
auto& height = this->pparam.height;
auto& ratio = this->pparam.ratio;
std::vector<cv::RotatedRect> bboxes;
std::vector<float> scores;
std::vector<int> labels;
std::vector<int> indices;
cv::Mat output = cv::Mat(num_channels, num_anchors, CV_32F, static_cast<float*>(this->host_ptrs[0]));
output = output.t();
for (int i = 0; i < num_anchors; i++) {
auto row_ptr = output.row(i).ptr<float>();
auto bboxes_ptr = row_ptr;
auto scores_ptr = row_ptr + 4;
auto max_s_ptr = std::max_element(scores_ptr, scores_ptr + num_labels);
auto angle_ptr = row_ptr + 4 + num_labels;
float score = *max_s_ptr;
if (score > score_thres) {
float x = (*bboxes_ptr++ - dw) * ratio;
float y = (*bboxes_ptr++ - dh) * ratio;
float w = (*bboxes_ptr++) * ratio;
float h = (*bboxes_ptr) * ratio;
if (w < 1.f || h < 1.f) {
continue;
}
x = clamp(x, 0.f, width);
y = clamp(y, 0.f, height);
w = clamp(w, 0.f, width);
h = clamp(h, 0.f, height);
float angle = *angle_ptr / CV_PI * 180.f;
cv::RotatedRect bbox;
bbox.center.x = x;
bbox.center.y = y;
bbox.size.width = w;
bbox.size.height = h;
bbox.angle = angle;
bboxes.push_back(bbox);
labels.push_back(std::distance(scores_ptr, max_s_ptr));
scores.push_back(score);
}
}
cv::dnn::NMSBoxes(bboxes, scores, score_thres, iou_thres, indices);
int cnt = 0;
for (auto& i : indices) {
if (cnt >= topk) {
break;
}
Object obj;
obj.rect = bboxes[i];
obj.prob = scores[i];
obj.label = labels[i];
objs.push_back(obj);
cnt += 1;
}
}
void YOLOv8_obb::draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS)
{
res = image.clone();
for (auto& obj : objs) {
cv::Mat points;
cv::boxPoints(obj.rect, points);
cv::Scalar color = cv::Scalar(COLORS[obj.label][0], COLORS[obj.label][1], COLORS[obj.label][2]);
points.convertTo(points, CV_32S);
cv::polylines(res, points, true, color, 2);
char text[256];
sprintf(text, "%s %.1f%%", CLASS_NAMES[obj.label].c_str(), obj.prob * 100);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = (int)obj.rect.center.x;
int y = (int)obj.rect.center.y + 1;
if (y > res.rows)
y = res.rows;
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);
}
}
#endif // OBB_NORMAL_YOLOv8_obb_HPP

csrc/obb/normal/main.cpp
@@ -0,0 +1,134 @@
//
// Created by ubuntu on 4/7/23.
//
#include "opencv2/opencv.hpp"
#include "yolov8-obb.hpp"
#include <chrono>
namespace fs = ghc::filesystem;
const std::vector<std::string> CLASS_NAMES = {"plane",
"ship",
"storage tank",
"baseball diamond",
"tennis court",
"basketball court",
"ground track field",
"harbor",
"bridge",
"large vehicle",
"small vehicle",
"helicopter",
"roundabout",
"soccer ball field",
"swimming pool"};
const std::vector<std::vector<unsigned int>> COLORS = {{0, 114, 189},
{217, 83, 25},
{237, 177, 32},
{126, 47, 142},
{119, 172, 48},
{77, 190, 238},
{162, 20, 47},
{76, 76, 76},
{153, 153, 153},
{255, 0, 0},
{255, 128, 0},
{191, 191, 0},
{0, 255, 0},
{0, 0, 255},
{170, 0, 255}};
int main(int argc, char** argv)
{
if (argc != 3) {
fprintf(stderr, "Usage: %s [engine_path] [image_path/image_dir/video_path]\n", argv[0]);
return -1;
}
// cuda:0
cudaSetDevice(0);
const std::string engine_file_path{argv[1]};
const fs::path path{argv[2]};
std::vector<std::string> imagePathList;
bool isVideo{false};
assert(argc == 3);
auto yolov8_obb = new YOLOv8_obb(engine_file_path);
yolov8_obb->make_pipe(true);
if (fs::is_regular_file(path)) {
std::string suffix = path.extension();
if (suffix == ".jpg" || suffix == ".jpeg" || suffix == ".png") {
imagePathList.push_back(path);
}
else if (suffix == ".mp4" || suffix == ".avi" || suffix == ".m4v" || suffix == ".mpeg" || suffix == ".mov"
|| suffix == ".mkv") {
isVideo = true;
}
else {
printf("suffix %s is wrong !!!\n", suffix.c_str());
std::abort();
}
}
else if (fs::is_directory(path)) {
cv::glob(path.string() + "/*.jpg", imagePathList);
}
cv::Mat res, image;
cv::Size size = cv::Size{1024, 1024};
int num_labels = 15;
int topk = 100;
float score_thres = 0.25f;
float iou_thres = 0.65f;
std::vector<Object> objs;
cv::namedWindow("result", cv::WINDOW_AUTOSIZE);
if (isVideo) {
cv::VideoCapture cap(path);
if (!cap.isOpened()) {
printf("can not open %s\n", path.c_str());
return -1;
}
while (cap.read(image)) {
objs.clear();
yolov8_obb->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
yolov8_obb->infer();
auto end = std::chrono::system_clock::now();
yolov8_obb->postprocess(objs, score_thres, iou_thres, topk, num_labels);
yolov8_obb->draw_objects(image, res, objs, CLASS_NAMES, COLORS);
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
if (cv::waitKey(10) == 'q') {
break;
}
}
}
else {
for (auto& p : imagePathList) {
objs.clear();
image = cv::imread(p);
yolov8_obb->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
yolov8_obb->infer();
auto end = std::chrono::system_clock::now();
yolov8_obb->postprocess(objs, score_thres, iou_thres, topk, num_labels);
yolov8_obb->draw_objects(image, res, objs, CLASS_NAMES, COLORS);
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
cv::waitKey(0);
}
}
cv::destroyAllWindows();
delete yolov8_obb;
return 0;
}

csrc/pose/normal/main.cpp
@@ -88,11 +88,11 @@ int main(int argc, char** argv)
if (fs::exists(path)) {
std::string suffix = path.extension();
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
if (suffix == ".jpg" || suffix == ".jpeg" || suffix == ".png") {
imagePathList.push_back(path);
}
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
else if (suffix == ".mp4" || suffix == ".avi" || suffix == ".m4v" || suffix == ".mpeg" || suffix == ".mov"
|| suffix == ".mkv") {
isVideo = true;
}
else {

csrc/segment/normal/main.cpp
@@ -58,11 +58,11 @@ int main(int argc, char** argv)
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
if (suffix == ".jpg" || suffix == ".jpeg" || suffix == ".png") {
imagePathList.push_back(path);
}
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
else if (suffix == ".mp4" || suffix == ".avi" || suffix == ".m4v" || suffix == ".mpeg" || suffix == ".mov"
|| suffix == ".mkv") {
isVideo = true;
}
else {

csrc/segment/simple/main.cpp
@@ -58,11 +58,11 @@ int main(int argc, char** argv)
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
if (suffix == ".jpg" || suffix == ".jpeg" || suffix == ".png") {
imagePathList.push_back(path);
}
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
else if (suffix == ".mp4" || suffix == ".avi" || suffix == ".m4v" || suffix == ".mpeg" || suffix == ".mov"
|| suffix == ".mkv") {
isVideo = true;
}
else {

docs/Obb.md
@@ -87,13 +87,16 @@ python3 infer-obb.py \
- `--out-dir` : Where to save the detection result images. It has no effect when the `--show` flag is used.
- `--device` : The CUDA device to use.
## Inference with c++ (Under Construction)
## Inference with c++
You can run inference with C++ in [`csrc/obb/normal`](../csrc/obb/normal).
### Build:
Please set your own libraries in [`CMakeLists.txt`](../csrc/obb/normal/CMakeLists.txt) and modify `CLASS_NAMES` in [`main.cpp`](../csrc/obb/normal/main.cpp).
Please set your own libraries in [`CMakeLists.txt`](../csrc/obb/normal/CMakeLists.txt) and modify `CLASS_NAMES`
and `COLORS` in [`main.cpp`](../csrc/obb/normal/main.cpp).
You can also tune the postprocess parameters, such as `num_labels`, `score_thres`, `iou_thres`, and `topk`.
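For reference, this minimal sketch mirrors the default values already used in [`main.cpp`](../csrc/obb/normal/main.cpp); they are starting points, not required settings:

```c++
int   num_labels  = 15;     // number of entries in CLASS_NAMES
int   topk        = 100;    // keep at most this many detections
float score_thres = 0.25f;  // confidence threshold
float iou_thres   = 0.65f;  // rotated-box NMS IoU threshold
yolov8_obb->postprocess(objs, score_thres, iou_thres, topk, num_labels);
```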
And build:
