Merge pull request #126 from triple-Mu/dev

Support YOLOv8 pose model inference with Python
dev
triple Mu 1 year ago committed by GitHub
commit 86f44fb997
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 15
      config.py
  2. 26
      csrc/deepstream/custom_bbox_parser/nvdsparsebbox_yoloV8.cpp
  3. 49
      csrc/detect/end2end/include/common.hpp
  4. 247
      csrc/detect/end2end/include/yolov8.hpp
  5. 123
      csrc/detect/end2end/main.cpp
  6. 49
      csrc/detect/normal/include/common.hpp
  7. 300
      csrc/detect/normal/include/yolov8.hpp
  8. 123
      csrc/detect/normal/main.cpp
  9. 49
      csrc/jetson/detect/include/common.hpp
  10. 245
      csrc/jetson/detect/include/yolov8.hpp
  11. 123
      csrc/jetson/detect/main.cpp
  12. 49
      csrc/jetson/pose/include/common.hpp
  13. 263
      csrc/jetson/pose/include/yolov8-pose.hpp
  14. 41
      csrc/jetson/pose/main.cpp
  15. 49
      csrc/jetson/segment/include/common.hpp
  16. 331
      csrc/jetson/segment/include/yolov8-seg.hpp
  17. 134
      csrc/jetson/segment/main.cpp
  18. 49
      csrc/pose/normal/include/common.hpp
  19. 264
      csrc/pose/normal/include/yolov8-pose.hpp
  20. 41
      csrc/pose/normal/main.cpp
  21. 49
      csrc/segment/normal/include/common.hpp
  22. 344
      csrc/segment/normal/include/yolov8-seg.hpp
  23. 134
      csrc/segment/normal/main.cpp
  24. 49
      csrc/segment/simple/include/common.hpp
  25. 333
      csrc/segment/simple/include/yolov8-seg.hpp
  26. 134
      csrc/segment/simple/main.cpp
  27. 18
      docs/Jetson.md
  28. 48
      docs/Normal.md
  29. 68
      docs/Pose.md
  30. 53
      docs/Segment.md
  31. 4
      infer-det-without-torch.py
  32. 4
      infer-det.py
  33. 116
      infer-pose-without-torch.py
  34. 112
      infer-pose.py
  35. 4
      infer-seg-without-torch.py
  36. 7
      infer-seg.py
  37. 38
      models/torch_utils.py
  38. 36
      models/utils.py

@ -36,5 +36,20 @@ MASK_COLORS = np.array([(255, 56, 56), (255, 157, 151), (255, 112, 31),
(255, 149, 200), (255, 55, 199)],
dtype=np.float32) / 255.
# Per-keypoint colour table, expanded from (rgb, repeat) runs: 5 + 6 + 6 = 17
# entries. Each entry is a freshly created list, so no two entries share the
# same object.
# NOTE(review): presumably one colour per pose keypoint -- confirm with callers.
KPS_COLORS = [
    list(rgb)
    for rgb, repeat in (
        ((0, 255, 0), 5),
        ((255, 128, 0), 6),
        ((51, 153, 255), 6),
    )
    for _ in range(repeat)
]

# Pairs of keypoint numbers (19 pairs). The values span 1..17, so they look
# 1-based relative to the 17 entries of KPS_COLORS -- verify against the code
# that consumes this table before indexing with it.
SKELETON = [
    list(pair)
    for pair in (
        (16, 14), (14, 12), (17, 15), (15, 13), (12, 13),
        (6, 12), (7, 13), (6, 7), (6, 8), (7, 9),
        (8, 10), (9, 11), (2, 3), (1, 2), (1, 3),
        (2, 4), (3, 5), (4, 6), (5, 7),
    )
]

# Colour table with 4 + 3 + 5 + 7 = 19 entries, the same length as SKELETON.
# NOTE(review): presumably LIMB_COLORS[i] colours the pair SKELETON[i] -- confirm.
LIMB_COLORS = [
    list(rgb)
    for rgb, repeat in (
        ((51, 153, 255), 4),
        ((255, 51, 255), 3),
        ((255, 128, 0), 5),
        ((0, 255, 0), 7),
    )
    for _ in range(repeat)
]
# alpha for segment masks
ALPHA = 0.5

@ -18,8 +18,7 @@
*/
// This is just the function prototype. The definition is written at the end of the file.
extern "C" bool NvDsInferParseCustomYoloV8(
std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
extern "C" bool NvDsInferParseCustomYoloV8(std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList);
@ -30,19 +29,16 @@ static __inline__ float bbox_clip(const float& val, const float& minVal = 0.f, c
return std::max(std::min(val, (maxVal - 1)), minVal);
}
static std::vector<NvDsInferParseObjectInfo> decodeYoloV8Tensor(
const int* num_dets,
static std::vector<NvDsInferParseObjectInfo> decodeYoloV8Tensor(const int* num_dets,
const float* bboxes,
const float* scores,
const int* labels,
const unsigned int& img_w,
const unsigned int& img_h
)
const unsigned int& img_h)
{
std::vector<NvDsInferParseObjectInfo> bboxInfo;
size_t nums = num_dets[0];
for (size_t i = 0; i < nums; i++)
{
for (size_t i = 0; i < nums; i++) {
float x0 = (bboxes[i * 4]);
float y0 = (bboxes[i * 4 + 1]);
float x1 = (bboxes[i * 4 + 2]);
@ -65,8 +61,7 @@ static std::vector<NvDsInferParseObjectInfo> decodeYoloV8Tensor(
}
/* C-linkage to prevent name-mangling */
extern "C" bool NvDsInferParseCustomYoloV8(
std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
extern "C" bool NvDsInferParseCustomYoloV8(std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams,
std::vector<NvDsInferParseObjectInfo>& objectList)
@ -74,8 +69,7 @@ extern "C" bool NvDsInferParseCustomYoloV8(
// Some assertions and error checking.
if (outputLayersInfo.empty() || outputLayersInfo.size() != 4)
{
if (outputLayersInfo.empty() || outputLayersInfo.size() != 4) {
std::cerr << "Could not find output layer in bbox parsing" << std::endl;
return false;
}
@ -92,17 +86,13 @@ extern "C" bool NvDsInferParseCustomYoloV8(
assert(scores.dims.numDims == 2);
assert(labels.dims.numDims == 2);
// Decoding the output tensor of YOLOv8 to the NvDsInferParseObjectInfo format.
std::vector<NvDsInferParseObjectInfo> objects =
decodeYoloV8Tensor(
(const int*)(num_dets.buffer),
std::vector<NvDsInferParseObjectInfo> objects = decodeYoloV8Tensor((const int*)(num_dets.buffer),
(const float*)(bboxes.buffer),
(const float*)(scores.buffer),
(const int*)(labels.buffer),
networkInfo.width,
networkInfo.height
);
networkInfo.height);
objectList.clear();
objectList = objects;

@ -4,29 +4,25 @@
#ifndef DETECT_END2END_COMMON_HPP
#define DETECT_END2END_COMMON_HPP
#include "NvInfer.h"
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#include "NvInfer.h"
#define CHECK(call) \
do \
{ \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
if (error_code != cudaSuccess) { \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", \
cudaGetErrorString(error_code)); \
printf(" Error text: %s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
class Logger : public nvinfer1::ILogger
{
class Logger: public nvinfer1::ILogger {
public:
nvinfer1::ILogger::Severity reportableSeverity;
@ -37,12 +33,10 @@ public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity)
{
if (severity > reportableSeverity) {
return;
}
switch (severity)
{
switch (severity) {
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
@ -66,8 +60,7 @@ public:
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++)
{
for (int i = 0; i < dims.nbDims; i++) {
size *= dims.d[i];
}
return size;
@ -75,8 +68,7 @@ inline int get_size_by_dims(const nvinfer1::Dims& dims)
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType)
{
switch (dataType) {
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
@ -99,8 +91,7 @@ inline static float clamp(float val, float min, float max)
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK)
{
if (access(path.c_str(), 0) == F_OK) {
return true;
}
return false;
@ -108,8 +99,7 @@ inline bool IsPathExist(const std::string& path)
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
@ -119,38 +109,33 @@ inline bool IsFile(const std::string& path)
inline bool IsFolder(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
namespace det
{
struct Binding
{
namespace det {
struct Binding {
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object
{
struct Object {
cv::Rect_<float> rect;
int label = 0;
float prob = 0.0;
};
struct PreParam
{
struct PreParam {
float ratio = 1.0f;
float dw = 0.0f;
float dh = 0.0f;
float height = 0;
float width = 0;
};
}
} // namespace det
#endif // DETECT_END2END_COMMON_HPP

@ -3,13 +3,12 @@
//
#ifndef DETECT_END2END_YOLOV8_HPP
#define DETECT_END2END_YOLOV8_HPP
#include "fstream"
#include "common.hpp"
#include "NvInferPlugin.h"
#include "common.hpp"
#include "fstream"
using namespace det;
class YOLOv8
{
class YOLOv8 {
public:
explicit YOLOv8(const std::string& engine_file_path);
~YOLOv8();
@ -17,20 +16,14 @@ public:
void make_pipe(bool warmup = true);
void copy_from_Mat(const cv::Mat& image);
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
);
void letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size);
void infer();
void postprocess(std::vector<Object>& objs);
static void draw_objects(
const cv::Mat& image,
static void draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS
);
const std::vector<std::vector<unsigned int>>& COLORS);
int num_bindings;
int num_inputs = 0;
int num_outputs = 0;
@ -40,13 +33,13 @@ public:
std::vector<void*> device_ptrs;
PreParam pparam;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{nvinfer1::ILogger::Severity::kERROR};
};
YOLOv8::YOLOv8(const std::string& engine_file_path)
@ -73,8 +66,7 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
cudaStreamCreate(&this->stream);
this->num_bindings = this->engine->getNbBindings();
for (int i = 0; i < this->num_bindings; ++i)
{
for (int i = 0; i < this->num_bindings; ++i) {
Binding binding;
nvinfer1::Dims dims;
nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
@ -83,22 +75,16 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
binding.dsize = type_to_size(dtype);
bool IsInput = engine->bindingIsInput(i);
if (IsInput)
{
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(
i,
0,
nvinfer1::OptProfileSelector::kMAX);
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
}
else
{
else {
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
@ -106,7 +92,6 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
this->num_outputs += 1;
}
}
}
YOLOv8::~YOLOv8()
@ -115,71 +100,44 @@ YOLOv8::~YOLOv8()
this->engine->destroy();
this->runtime->destroy();
cudaStreamDestroy(this->stream);
for (auto& ptr : this->device_ptrs)
{
for (auto& ptr : this->device_ptrs) {
CHECK(cudaFree(ptr));
}
for (auto& ptr : this->host_ptrs)
{
for (auto& ptr : this->host_ptrs) {
CHECK(cudaFreeHost(ptr));
}
}
void YOLOv8::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings)
{
for (auto& bindings : this->input_bindings) {
void* d_ptr;
CHECK(cudaMallocAsync(
&d_ptr,
bindings.size * bindings.dsize,
this->stream)
);
CHECK(cudaMallocAsync(&d_ptr, bindings.size * bindings.dsize, this->stream));
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings)
{
for (auto& bindings : this->output_bindings) {
void * d_ptr, *h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMallocAsync(
&d_ptr,
size,
this->stream)
);
CHECK(cudaHostAlloc(
&h_ptr,
size,
0)
);
CHECK(cudaMallocAsync(&d_ptr, size, this->stream));
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
if (warmup)
{
for (int i = 0; i < 10; i++)
{
for (auto& bindings : this->input_bindings)
{
if (warmup) {
for (int i = 0; i < 10; i++) {
for (auto& bindings : this->input_bindings) {
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
h_ptr,
size,
cudaMemcpyHostToDevice,
this->stream)
);
CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream));
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
@ -195,16 +153,10 @@ void YOLOv8::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
int padh = std::round(height * r);
cv::Mat tmp;
if ((int)width != padw || (int)height != padh)
{
cv::resize(
image,
tmp,
cv::Size(padw, padh)
);
if ((int)width != padw || (int)height != padh) {
cv::resize(image, tmp, cv::Size(padw, padh));
}
else
{
else {
tmp = image.clone();
}
@ -218,31 +170,15 @@ void YOLOv8::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
int left = int(std::round(dw - 0.1f));
int right = int(std::round(dw + 0.1f));
cv::copyMakeBorder(
tmp,
tmp,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
{ 114, 114, 114 }
);
cv::dnn::blobFromImage(tmp,
out,
1 / 255.f,
cv::Size(),
cv::Scalar(0, 0, 0),
true,
false,
CV_32F
);
cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, {114, 114, 114});
cv::dnn::blobFromImage(tmp, out, 1 / 255.f, cv::Size(), cv::Scalar(0, 0, 0), true, false, CV_32F);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
this->pparam.height = height;
this->pparam.width = width;;
this->pparam.width = width;
;
}
void YOLOv8::copy_from_Mat(const cv::Mat& image)
@ -252,75 +188,33 @@ void YOLOv8::copy_from_Mat(const cv::Mat& image)
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
cv::Size size{width, height};
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{
4,
{ 1, 3, height, width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
cv::Mat nchw;
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{ 4,
{ 1, 3, size.height, size.width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8::infer()
{
this->context->enqueueV2(
this->device_ptrs.data(),
this->stream,
nullptr
);
for (int i = 0; i < this->num_outputs; i++)
{
this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr);
for (int i = 0; i < this->num_outputs; i++) {
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(this->host_ptrs[i],
this->device_ptrs[i + this->num_inputs],
osize,
cudaMemcpyDeviceToHost,
this->stream)
);
CHECK(cudaMemcpyAsync(
this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream));
}
cudaStreamSynchronize(this->stream);
}
void YOLOv8::postprocess(std::vector<Object>& objs)
@ -335,8 +229,7 @@ void YOLOv8::postprocess(std::vector<Object>& objs)
auto& width = this->pparam.width;
auto& height = this->pparam.height;
auto& ratio = this->pparam.ratio;
for (int i = 0; i < num_dets[0]; i++)
{
for (int i = 0; i < num_dets[0]; i++) {
float* ptr = boxes + i * 4;
float x0 = *ptr++ - dw;
@ -359,45 +252,22 @@ void YOLOv8::postprocess(std::vector<Object>& objs)
}
}
void YOLOv8::draw_objects(
const cv::Mat& image,
void YOLOv8::draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS
)
const std::vector<std::vector<unsigned int>>& COLORS)
{
res = image.clone();
for (auto& obj : objs)
{
cv::Scalar color = cv::Scalar(
COLORS[obj.label][0],
COLORS[obj.label][1],
COLORS[obj.label][2]
);
cv::rectangle(
res,
obj.rect,
color,
2
);
for (auto& obj : objs) {
cv::Scalar color = cv::Scalar(COLORS[obj.label][0], COLORS[obj.label][1], COLORS[obj.label][2]);
cv::rectangle(res, obj.rect, color, 2);
char text[256];
sprintf(
text,
"%s %.1f%%",
CLASS_NAMES[obj.label].c_str(),
obj.prob * 100
);
sprintf(text, "%s %.1f%%", CLASS_NAMES[obj.label].c_str(), obj.prob * 100);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(
text,
cv::FONT_HERSHEY_SIMPLEX,
0.4,
1,
&baseLine
);
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = (int)obj.rect.x;
int y = (int)obj.rect.y + 1;
@ -405,22 +275,9 @@ void YOLOv8::draw_objects(
if (y > res.rows)
y = res.rows;
cv::rectangle(
res,
cv::Rect(x, y, label_size.width, label_size.height + baseLine),
{ 0, 0, 255 },
-1
);
cv::putText(
res,
text,
cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX,
0.4,
{ 255, 255, 255 },
1
);
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);
}
}
#endif // DETECT_END2END_YOLOV8_HPP

@ -2,56 +2,38 @@
// Created by ubuntu on 1/20/23.
//
#include "chrono"
#include "yolov8.hpp"
#include "opencv2/opencv.hpp"
#include "yolov8.hpp"
const std::vector<std::string> CLASS_NAMES = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus",
"train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella",
"handbag", "tie", "suitcase", "frisbee", "skis",
"snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass",
"cup", "fork", "knife", "spoon", "bowl",
"banana", "apple", "sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote",
"keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase",
"scissors", "teddy bear", "hair drier", "toothbrush" };
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
"truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
"baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
"fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
"orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv",
"laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
"teddy bear", "hair drier", "toothbrush"};
const std::vector<std::vector<unsigned int>> COLORS = {
{ 0, 114, 189 }, { 217, 83, 25 }, { 237, 177, 32 },
{ 126, 47, 142 }, { 119, 172, 48 }, { 77, 190, 238 },
{ 162, 20, 47 }, { 76, 76, 76 }, { 153, 153, 153 },
{ 255, 0, 0 }, { 255, 128, 0 }, { 191, 191, 0 },
{ 0, 255, 0 }, { 0, 0, 255 }, { 170, 0, 255 },
{ 85, 85, 0 }, { 85, 170, 0 }, { 85, 255, 0 },
{ 170, 85, 0 }, { 170, 170, 0 }, { 170, 255, 0 },
{ 255, 85, 0 }, { 255, 170, 0 }, { 255, 255, 0 },
{ 0, 85, 128 }, { 0, 170, 128 }, { 0, 255, 128 },
{ 85, 0, 128 }, { 85, 85, 128 }, { 85, 170, 128 },
{ 85, 255, 128 }, { 170, 0, 128 }, { 170, 85, 128 },
{ 170, 170, 128 }, { 170, 255, 128 }, { 255, 0, 128 },
{ 255, 85, 128 }, { 255, 170, 128 }, { 255, 255, 128 },
{ 0, 85, 255 }, { 0, 170, 255 }, { 0, 255, 255 },
{ 85, 0, 255 }, { 85, 85, 255 }, { 85, 170, 255 },
{ 85, 255, 255 }, { 170, 0, 255 }, { 170, 85, 255 },
{ 170, 170, 255 }, { 170, 255, 255 }, { 255, 0, 255 },
{ 255, 85, 255 }, { 255, 170, 255 }, { 85, 0, 0 },
{ 128, 0, 0 }, { 170, 0, 0 }, { 212, 0, 0 },
{ 255, 0, 0 }, { 0, 43, 0 }, { 0, 85, 0 },
{ 0, 128, 0 }, { 0, 170, 0 }, { 0, 212, 0 },
{ 0, 255, 0 }, { 0, 0, 43 }, { 0, 0, 85 },
{ 0, 0, 128 }, { 0, 0, 170 }, { 0, 0, 212 },
{ 0, 0, 255 }, { 0, 0, 0 }, { 36, 36, 36 },
{ 73, 73, 73 }, { 109, 109, 109 }, { 146, 146, 146 },
{ 182, 182, 182 }, { 219, 219, 219 }, { 0, 114, 189 },
{ 80, 183, 189 }, { 128, 128, 0 }
};
{0, 114, 189}, {217, 83, 25}, {237, 177, 32}, {126, 47, 142}, {119, 172, 48}, {77, 190, 238},
{162, 20, 47}, {76, 76, 76}, {153, 153, 153}, {255, 0, 0}, {255, 128, 0}, {191, 191, 0},
{0, 255, 0}, {0, 0, 255}, {170, 0, 255}, {85, 85, 0}, {85, 170, 0}, {85, 255, 0},
{170, 85, 0}, {170, 170, 0}, {170, 255, 0}, {255, 85, 0}, {255, 170, 0}, {255, 255, 0},
{0, 85, 128}, {0, 170, 128}, {0, 255, 128}, {85, 0, 128}, {85, 85, 128}, {85, 170, 128},
{85, 255, 128}, {170, 0, 128}, {170, 85, 128}, {170, 170, 128}, {170, 255, 128}, {255, 0, 128},
{255, 85, 128}, {255, 170, 128}, {255, 255, 128}, {0, 85, 255}, {0, 170, 255}, {0, 255, 255},
{85, 0, 255}, {85, 85, 255}, {85, 170, 255}, {85, 255, 255}, {170, 0, 255}, {170, 85, 255},
{170, 170, 255}, {170, 255, 255}, {255, 0, 255}, {255, 85, 255}, {255, 170, 255}, {85, 0, 0},
{128, 0, 0}, {170, 0, 0}, {212, 0, 0}, {255, 0, 0}, {0, 43, 0}, {0, 85, 0},
{0, 128, 0}, {0, 170, 0}, {0, 212, 0}, {0, 255, 0}, {0, 0, 43}, {0, 0, 85},
{0, 0, 128}, {0, 0, 170}, {0, 0, 212}, {0, 0, 255}, {0, 0, 0}, {36, 36, 36},
{73, 73, 73}, {109, 109, 109}, {146, 146, 146}, {182, 182, 182}, {219, 219, 219}, {0, 114, 189},
{80, 183, 189}, {128, 128, 0}};
int main(int argc, char** argv)
{
@ -69,36 +51,21 @@ int main(int argc, char** argv)
auto yolov8 = new YOLOv8(engine_file_path);
yolov8->make_pipe(true);
if (IsFile(path))
{
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (
suffix == "jpg" ||
suffix == "jpeg" ||
suffix == "png"
)
{
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
imagePathList.push_back(path);
}
else if (
suffix == "mp4" ||
suffix == "avi" ||
suffix == "m4v" ||
suffix == "mpeg" ||
suffix == "mov" ||
suffix == "mkv"
)
{
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
isVideo = true;
}
else
{
else {
printf("suffix %s is wrong !!!\n", suffix.c_str());
std::abort();
}
}
else if (IsFolder(path))
{
else if (IsFolder(path)) {
cv::glob(path + "/*.jpg", imagePathList);
}
@ -108,17 +75,14 @@ int main(int argc, char** argv)
cv::namedWindow("result", cv::WINDOW_AUTOSIZE);
if (isVideo)
{
if (isVideo) {
cv::VideoCapture cap(path);
if (!cap.isOpened())
{
if (!cap.isOpened()) {
printf("can not open %s\n", path.c_str());
return -1;
}
while (cap.read(image))
{
while (cap.read(image)) {
objs.clear();
yolov8->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
@ -126,20 +90,16 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
if (cv::waitKey(10) == 'q')
{
if (cv::waitKey(10) == 'q') {
break;
}
}
}
else
{
for (auto& path : imagePathList)
{
else {
for (auto& path : imagePathList) {
objs.clear();
image = cv::imread(path);
yolov8->copy_from_Mat(image, size);
@ -148,8 +108,7 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
cv::waitKey(0);

@ -4,29 +4,25 @@
#ifndef DETECT_NORMAL_COMMON_HPP
#define DETECT_NORMAL_COMMON_HPP
#include "NvInfer.h"
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#include "NvInfer.h"
#define CHECK(call) \
do \
{ \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
if (error_code != cudaSuccess) { \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", \
cudaGetErrorString(error_code)); \
printf(" Error text: %s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
class Logger : public nvinfer1::ILogger
{
class Logger: public nvinfer1::ILogger {
public:
nvinfer1::ILogger::Severity reportableSeverity;
@ -37,12 +33,10 @@ public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity)
{
if (severity > reportableSeverity) {
return;
}
switch (severity)
{
switch (severity) {
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
@ -66,8 +60,7 @@ public:
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++)
{
for (int i = 0; i < dims.nbDims; i++) {
size *= dims.d[i];
}
return size;
@ -75,8 +68,7 @@ inline int get_size_by_dims(const nvinfer1::Dims& dims)
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType)
{
switch (dataType) {
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
@ -99,8 +91,7 @@ inline static float clamp(float val, float min, float max)
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK)
{
if (access(path.c_str(), 0) == F_OK) {
return true;
}
return false;
@ -108,8 +99,7 @@ inline bool IsPathExist(const std::string& path)
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
@ -119,38 +109,33 @@ inline bool IsFile(const std::string& path)
inline bool IsFolder(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
namespace det
{
struct Binding
{
namespace det {
struct Binding {
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object
{
struct Object {
cv::Rect_<float> rect;
int label = 0;
float prob = 0.0;
};
struct PreParam
{
struct PreParam {
float ratio = 1.0f;
float dw = 0.0f;
float dh = 0.0f;
float height = 0;
float width = 0;
};
}
} // namespace det
#endif // DETECT_NORMAL_COMMON_HPP

@ -3,13 +3,12 @@
//
#ifndef DETECT_NORMAL_YOLOV8_HPP
#define DETECT_NORMAL_YOLOV8_HPP
#include "fstream"
#include "common.hpp"
#include "NvInferPlugin.h"
#include "common.hpp"
#include "fstream"
using namespace det;
class YOLOv8
{
class YOLOv8 {
public:
explicit YOLOv8(const std::string& engine_file_path);
~YOLOv8();
@ -17,26 +16,18 @@ public:
void make_pipe(bool warmup = true);
void copy_from_Mat(const cv::Mat& image);
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
);
void letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size);
void infer();
void postprocess(
std::vector<Object>& objs,
void postprocess(std::vector<Object>& objs,
float score_thres = 0.25f,
float iou_thres = 0.65f,
int topk = 100,
int num_labels = 80
);
static void draw_objects(
const cv::Mat& image,
int num_labels = 80);
static void draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS
);
const std::vector<std::vector<unsigned int>>& COLORS);
int num_bindings;
int num_inputs = 0;
int num_outputs = 0;
@ -46,13 +37,13 @@ public:
std::vector<void*> device_ptrs;
PreParam pparam;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{nvinfer1::ILogger::Severity::kERROR};
};
YOLOv8::YOLOv8(const std::string& engine_file_path)
@ -79,8 +70,7 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
cudaStreamCreate(&this->stream);
this->num_bindings = this->engine->getNbBindings();
for (int i = 0; i < this->num_bindings; ++i)
{
for (int i = 0; i < this->num_bindings; ++i) {
Binding binding;
nvinfer1::Dims dims;
nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
@ -89,22 +79,16 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
binding.dsize = type_to_size(dtype);
bool IsInput = engine->bindingIsInput(i);
if (IsInput)
{
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(
i,
0,
nvinfer1::OptProfileSelector::kMAX);
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
}
else
{
else {
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
@ -112,7 +96,6 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
this->num_outputs += 1;
}
}
}
YOLOv8::~YOLOv8()
@ -121,79 +104,48 @@ YOLOv8::~YOLOv8()
this->engine->destroy();
this->runtime->destroy();
cudaStreamDestroy(this->stream);
for (auto& ptr : this->device_ptrs)
{
for (auto& ptr : this->device_ptrs) {
CHECK(cudaFree(ptr));
}
for (auto& ptr : this->host_ptrs)
{
for (auto& ptr : this->host_ptrs) {
CHECK(cudaFreeHost(ptr));
}
}
void YOLOv8::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings)
{
for (auto& bindings : this->input_bindings) {
void* d_ptr;
CHECK(cudaMallocAsync(
&d_ptr,
bindings.size * bindings.dsize,
this->stream)
);
CHECK(cudaMallocAsync(&d_ptr, bindings.size * bindings.dsize, this->stream));
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings)
{
for (auto& bindings : this->output_bindings) {
void * d_ptr, *h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMallocAsync(
&d_ptr,
size,
this->stream)
);
CHECK(cudaHostAlloc(
&h_ptr,
size,
0)
);
CHECK(cudaMallocAsync(&d_ptr, size, this->stream));
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
if (warmup)
{
for (int i = 0; i < 10; i++)
{
for (auto& bindings : this->input_bindings)
{
if (warmup) {
for (int i = 0; i < 10; i++) {
for (auto& bindings : this->input_bindings) {
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
h_ptr,
size,
cudaMemcpyHostToDevice,
this->stream)
);
CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream));
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
void YOLOv8::letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
)
void YOLOv8::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
{
const float inp_h = size.height;
const float inp_w = size.width;
@ -205,16 +157,10 @@ void YOLOv8::letterbox(
int padh = std::round(height * r);
cv::Mat tmp;
if ((int)width != padw || (int)height != padh)
{
cv::resize(
image,
tmp,
cv::Size(padw, padh)
);
if ((int)width != padw || (int)height != padh) {
cv::resize(image, tmp, cv::Size(padw, padh));
}
else
{
else {
tmp = image.clone();
}
@ -228,31 +174,15 @@ void YOLOv8::letterbox(
int left = int(std::round(dw - 0.1f));
int right = int(std::round(dw + 0.1f));
cv::copyMakeBorder(
tmp,
tmp,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
{ 114, 114, 114 }
);
cv::dnn::blobFromImage(tmp,
out,
1 / 255.f,
cv::Size(),
cv::Scalar(0, 0, 0),
true,
false,
CV_32F
);
cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, {114, 114, 114});
cv::dnn::blobFromImage(tmp, out, 1 / 255.f, cv::Size(), cv::Scalar(0, 0, 0), true, false, CV_32F);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
this->pparam.height = height;
this->pparam.width = width;;
this->pparam.width = width;
;
}
void YOLOv8::copy_from_Mat(const cv::Mat& image)
@ -262,84 +192,36 @@ void YOLOv8::copy_from_Mat(const cv::Mat& image)
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
cv::Size size{width, height};
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{
4,
{ 1, 3, height, width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
cv::Mat nchw;
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{ 4,
{ 1, 3, size.height, size.width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8::infer()
{
this->context->enqueueV2(
this->device_ptrs.data(),
this->stream,
nullptr
);
for (int i = 0; i < this->num_outputs; i++)
{
this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr);
for (int i = 0; i < this->num_outputs; i++) {
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(this->host_ptrs[i],
this->device_ptrs[i + this->num_inputs],
osize,
cudaMemcpyDeviceToHost,
this->stream)
);
CHECK(cudaMemcpyAsync(
this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream));
}
cudaStreamSynchronize(this->stream);
}
void YOLOv8::postprocess(
std::vector<Object>& objs,
float score_thres,
float iou_thres,
int topk,
int num_labels
)
void YOLOv8::postprocess(std::vector<Object>& objs, float score_thres, float iou_thres, int topk, int num_labels)
{
objs.clear();
auto num_channels = this->output_bindings[0].dims.d[1];
@ -356,22 +238,15 @@ void YOLOv8::postprocess(
std::vector<int> labels;
std::vector<int> indices;
cv::Mat output = cv::Mat(
num_channels,
num_anchors,
CV_32F,
static_cast<float*>(this->host_ptrs[0])
);
cv::Mat output = cv::Mat(num_channels, num_anchors, CV_32F, static_cast<float*>(this->host_ptrs[0]));
output = output.t();
for (int i = 0; i < num_anchors; i++)
{
for (int i = 0; i < num_anchors; i++) {
auto row_ptr = output.row(i).ptr<float>();
auto bboxes_ptr = row_ptr;
auto scores_ptr = row_ptr + 4;
auto max_s_ptr = std::max_element(scores_ptr, scores_ptr + num_labels);
float score = *max_s_ptr;
if (score > score_thres)
{
if (score > score_thres) {
float x = *bboxes_ptr++ - dw;
float y = *bboxes_ptr++ - dh;
float w = *bboxes_ptr++;
@ -396,29 +271,14 @@ void YOLOv8::postprocess(
}
#ifdef BATCHED_NMS
cv::dnn::NMSBoxesBatched(
bboxes,
scores,
labels,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxesBatched(bboxes, scores, labels, score_thres, iou_thres, indices);
#else
cv::dnn::NMSBoxes(
bboxes,
scores,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxes(bboxes, scores, score_thres, iou_thres, indices);
#endif
int cnt = 0;
for (auto& i : indices)
{
if (cnt >= topk)
{
for (auto& i : indices) {
if (cnt >= topk) {
break;
}
Object obj;
@ -430,45 +290,22 @@ void YOLOv8::postprocess(
}
}
void YOLOv8::draw_objects(
const cv::Mat& image,
void YOLOv8::draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS
)
const std::vector<std::vector<unsigned int>>& COLORS)
{
res = image.clone();
for (auto& obj : objs)
{
cv::Scalar color = cv::Scalar(
COLORS[obj.label][0],
COLORS[obj.label][1],
COLORS[obj.label][2]
);
cv::rectangle(
res,
obj.rect,
color,
2
);
for (auto& obj : objs) {
cv::Scalar color = cv::Scalar(COLORS[obj.label][0], COLORS[obj.label][1], COLORS[obj.label][2]);
cv::rectangle(res, obj.rect, color, 2);
char text[256];
sprintf(
text,
"%s %.1f%%",
CLASS_NAMES[obj.label].c_str(),
obj.prob * 100
);
sprintf(text, "%s %.1f%%", CLASS_NAMES[obj.label].c_str(), obj.prob * 100);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(
text,
cv::FONT_HERSHEY_SIMPLEX,
0.4,
1,
&baseLine
);
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = (int)obj.rect.x;
int y = (int)obj.rect.y + 1;
@ -476,22 +313,9 @@ void YOLOv8::draw_objects(
if (y > res.rows)
y = res.rows;
cv::rectangle(
res,
cv::Rect(x, y, label_size.width, label_size.height + baseLine),
{ 0, 0, 255 },
-1
);
cv::putText(
res,
text,
cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX,
0.4,
{ 255, 255, 255 },
1
);
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);
}
}
#endif // DETECT_NORMAL_YOLOV8_HPP

@ -2,56 +2,38 @@
// Created by ubuntu on 1/20/23.
//
#include "chrono"
#include "yolov8.hpp"
#include "opencv2/opencv.hpp"
#include "yolov8.hpp"
const std::vector<std::string> CLASS_NAMES = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus",
"train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella",
"handbag", "tie", "suitcase", "frisbee", "skis",
"snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass",
"cup", "fork", "knife", "spoon", "bowl",
"banana", "apple", "sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote",
"keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase",
"scissors", "teddy bear", "hair drier", "toothbrush" };
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
"truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
"baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
"fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
"orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv",
"laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
"teddy bear", "hair drier", "toothbrush"};
const std::vector<std::vector<unsigned int>> COLORS = {
{ 0, 114, 189 }, { 217, 83, 25 }, { 237, 177, 32 },
{ 126, 47, 142 }, { 119, 172, 48 }, { 77, 190, 238 },
{ 162, 20, 47 }, { 76, 76, 76 }, { 153, 153, 153 },
{ 255, 0, 0 }, { 255, 128, 0 }, { 191, 191, 0 },
{ 0, 255, 0 }, { 0, 0, 255 }, { 170, 0, 255 },
{ 85, 85, 0 }, { 85, 170, 0 }, { 85, 255, 0 },
{ 170, 85, 0 }, { 170, 170, 0 }, { 170, 255, 0 },
{ 255, 85, 0 }, { 255, 170, 0 }, { 255, 255, 0 },
{ 0, 85, 128 }, { 0, 170, 128 }, { 0, 255, 128 },
{ 85, 0, 128 }, { 85, 85, 128 }, { 85, 170, 128 },
{ 85, 255, 128 }, { 170, 0, 128 }, { 170, 85, 128 },
{ 170, 170, 128 }, { 170, 255, 128 }, { 255, 0, 128 },
{ 255, 85, 128 }, { 255, 170, 128 }, { 255, 255, 128 },
{ 0, 85, 255 }, { 0, 170, 255 }, { 0, 255, 255 },
{ 85, 0, 255 }, { 85, 85, 255 }, { 85, 170, 255 },
{ 85, 255, 255 }, { 170, 0, 255 }, { 170, 85, 255 },
{ 170, 170, 255 }, { 170, 255, 255 }, { 255, 0, 255 },
{ 255, 85, 255 }, { 255, 170, 255 }, { 85, 0, 0 },
{ 128, 0, 0 }, { 170, 0, 0 }, { 212, 0, 0 },
{ 255, 0, 0 }, { 0, 43, 0 }, { 0, 85, 0 },
{ 0, 128, 0 }, { 0, 170, 0 }, { 0, 212, 0 },
{ 0, 255, 0 }, { 0, 0, 43 }, { 0, 0, 85 },
{ 0, 0, 128 }, { 0, 0, 170 }, { 0, 0, 212 },
{ 0, 0, 255 }, { 0, 0, 0 }, { 36, 36, 36 },
{ 73, 73, 73 }, { 109, 109, 109 }, { 146, 146, 146 },
{ 182, 182, 182 }, { 219, 219, 219 }, { 0, 114, 189 },
{ 80, 183, 189 }, { 128, 128, 0 }
};
{0, 114, 189}, {217, 83, 25}, {237, 177, 32}, {126, 47, 142}, {119, 172, 48}, {77, 190, 238},
{162, 20, 47}, {76, 76, 76}, {153, 153, 153}, {255, 0, 0}, {255, 128, 0}, {191, 191, 0},
{0, 255, 0}, {0, 0, 255}, {170, 0, 255}, {85, 85, 0}, {85, 170, 0}, {85, 255, 0},
{170, 85, 0}, {170, 170, 0}, {170, 255, 0}, {255, 85, 0}, {255, 170, 0}, {255, 255, 0},
{0, 85, 128}, {0, 170, 128}, {0, 255, 128}, {85, 0, 128}, {85, 85, 128}, {85, 170, 128},
{85, 255, 128}, {170, 0, 128}, {170, 85, 128}, {170, 170, 128}, {170, 255, 128}, {255, 0, 128},
{255, 85, 128}, {255, 170, 128}, {255, 255, 128}, {0, 85, 255}, {0, 170, 255}, {0, 255, 255},
{85, 0, 255}, {85, 85, 255}, {85, 170, 255}, {85, 255, 255}, {170, 0, 255}, {170, 85, 255},
{170, 170, 255}, {170, 255, 255}, {255, 0, 255}, {255, 85, 255}, {255, 170, 255}, {85, 0, 0},
{128, 0, 0}, {170, 0, 0}, {212, 0, 0}, {255, 0, 0}, {0, 43, 0}, {0, 85, 0},
{0, 128, 0}, {0, 170, 0}, {0, 212, 0}, {0, 255, 0}, {0, 0, 43}, {0, 0, 85},
{0, 0, 128}, {0, 0, 170}, {0, 0, 212}, {0, 0, 255}, {0, 0, 0}, {36, 36, 36},
{73, 73, 73}, {109, 109, 109}, {146, 146, 146}, {182, 182, 182}, {219, 219, 219}, {0, 114, 189},
{80, 183, 189}, {128, 128, 0}};
int main(int argc, char** argv)
{
@ -69,36 +51,21 @@ int main(int argc, char** argv)
auto yolov8 = new YOLOv8(engine_file_path);
yolov8->make_pipe(true);
if (IsFile(path))
{
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (
suffix == "jpg" ||
suffix == "jpeg" ||
suffix == "png"
)
{
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
imagePathList.push_back(path);
}
else if (
suffix == "mp4" ||
suffix == "avi" ||
suffix == "m4v" ||
suffix == "mpeg" ||
suffix == "mov" ||
suffix == "mkv"
)
{
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
isVideo = true;
}
else
{
else {
printf("suffix %s is wrong !!!\n", suffix.c_str());
std::abort();
}
}
else if (IsFolder(path))
{
else if (IsFolder(path)) {
cv::glob(path + "/*.jpg", imagePathList);
}
@ -113,17 +80,14 @@ int main(int argc, char** argv)
cv::namedWindow("result", cv::WINDOW_AUTOSIZE);
if (isVideo)
{
if (isVideo) {
cv::VideoCapture cap(path);
if (!cap.isOpened())
{
if (!cap.isOpened()) {
printf("can not open %s\n", path.c_str());
return -1;
}
while (cap.read(image))
{
while (cap.read(image)) {
objs.clear();
yolov8->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
@ -131,20 +95,16 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, num_labels);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
if (cv::waitKey(10) == 'q')
{
if (cv::waitKey(10) == 'q') {
break;
}
}
}
else
{
for (auto& path : imagePathList)
{
else {
for (auto& path : imagePathList) {
objs.clear();
image = cv::imread(path);
yolov8->copy_from_Mat(image, size);
@ -153,8 +113,7 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, num_labels);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
cv::waitKey(0);

@ -4,29 +4,25 @@
#ifndef JETSON_DETECT_COMMON_HPP
#define JETSON_DETECT_COMMON_HPP
#include "NvInfer.h"
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#include "NvInfer.h"
#define CHECK(call) \
do \
{ \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
if (error_code != cudaSuccess) { \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", \
cudaGetErrorString(error_code)); \
printf(" Error text: %s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
class Logger : public nvinfer1::ILogger
{
class Logger: public nvinfer1::ILogger {
public:
nvinfer1::ILogger::Severity reportableSeverity;
@ -37,12 +33,10 @@ public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity)
{
if (severity > reportableSeverity) {
return;
}
switch (severity)
{
switch (severity) {
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
@ -66,8 +60,7 @@ public:
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++)
{
for (int i = 0; i < dims.nbDims; i++) {
size *= dims.d[i];
}
return size;
@ -75,8 +68,7 @@ inline int get_size_by_dims(const nvinfer1::Dims& dims)
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType)
{
switch (dataType) {
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
@ -99,8 +91,7 @@ inline static float clamp(float val, float min, float max)
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK)
{
if (access(path.c_str(), 0) == F_OK) {
return true;
}
return false;
@ -108,8 +99,7 @@ inline bool IsPathExist(const std::string& path)
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
@ -119,38 +109,33 @@ inline bool IsFile(const std::string& path)
inline bool IsFolder(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
namespace det
{
struct Binding
{
namespace det {
struct Binding {
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object
{
struct Object {
cv::Rect_<float> rect;
int label = 0;
float prob = 0.0;
};
struct PreParam
{
struct PreParam {
float ratio = 1.0f;
float dw = 0.0f;
float dh = 0.0f;
float height = 0;
float width = 0;
};
}
} // namespace det
#endif // JETSON_DETECT_COMMON_HPP

@ -3,13 +3,12 @@
//
#ifndef JETSON_DETECT_YOLOV8_HPP
#define JETSON_DETECT_YOLOV8_HPP
#include "fstream"
#include "common.hpp"
#include "NvInferPlugin.h"
#include "common.hpp"
#include "fstream"
using namespace det;
class YOLOv8
{
class YOLOv8 {
public:
explicit YOLOv8(const std::string& engine_file_path);
~YOLOv8();
@ -17,20 +16,14 @@ public:
void make_pipe(bool warmup = true);
void copy_from_Mat(const cv::Mat& image);
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
);
void letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size);
void infer();
void postprocess(std::vector<Object>& objs);
static void draw_objects(
const cv::Mat& image,
static void draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS
);
const std::vector<std::vector<unsigned int>>& COLORS);
int num_bindings;
int num_inputs = 0;
int num_outputs = 0;
@ -40,13 +33,13 @@ public:
std::vector<void*> device_ptrs;
PreParam pparam;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{nvinfer1::ILogger::Severity::kERROR};
};
YOLOv8::YOLOv8(const std::string& engine_file_path)
@ -73,8 +66,7 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
cudaStreamCreate(&this->stream);
this->num_bindings = this->engine->getNbBindings();
for (int i = 0; i < this->num_bindings; ++i)
{
for (int i = 0; i < this->num_bindings; ++i) {
Binding binding;
nvinfer1::Dims dims;
nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
@ -83,22 +75,16 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
binding.dsize = type_to_size(dtype);
bool IsInput = engine->bindingIsInput(i);
if (IsInput)
{
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(
i,
0,
nvinfer1::OptProfileSelector::kMAX);
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
}
else
{
else {
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
@ -106,7 +92,6 @@ YOLOv8::YOLOv8(const std::string& engine_file_path)
this->num_outputs += 1;
}
}
}
YOLOv8::~YOLOv8()
@ -115,69 +100,44 @@ YOLOv8::~YOLOv8()
this->engine->destroy();
this->runtime->destroy();
cudaStreamDestroy(this->stream);
for (auto& ptr : this->device_ptrs)
{
for (auto& ptr : this->device_ptrs) {
CHECK(cudaFree(ptr));
}
for (auto& ptr : this->host_ptrs)
{
for (auto& ptr : this->host_ptrs) {
CHECK(cudaFreeHost(ptr));
}
}
void YOLOv8::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings)
{
for (auto& bindings : this->input_bindings) {
void* d_ptr;
CHECK(cudaMalloc(
&d_ptr,
bindings.size * bindings.dsize)
);
CHECK(cudaMalloc(&d_ptr, bindings.size * bindings.dsize));
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings)
{
for (auto& bindings : this->output_bindings) {
void * d_ptr, *h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMalloc(
&d_ptr,
size)
);
CHECK(cudaHostAlloc(
&h_ptr,
size,
0)
);
CHECK(cudaMalloc(&d_ptr, size));
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
if (warmup)
{
for (int i = 0; i < 10; i++)
{
for (auto& bindings : this->input_bindings)
{
if (warmup) {
for (int i = 0; i < 10; i++) {
for (auto& bindings : this->input_bindings) {
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
h_ptr,
size,
cudaMemcpyHostToDevice,
this->stream)
);
CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream));
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
@ -193,16 +153,10 @@ void YOLOv8::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
int padh = std::round(height * r);
cv::Mat tmp;
if ((int)width != padw || (int)height != padh)
{
cv::resize(
image,
tmp,
cv::Size(padw, padh)
);
if ((int)width != padw || (int)height != padh) {
cv::resize(image, tmp, cv::Size(padw, padh));
}
else
{
else {
tmp = image.clone();
}
@ -216,31 +170,15 @@ void YOLOv8::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
int left = int(std::round(dw - 0.1f));
int right = int(std::round(dw + 0.1f));
cv::copyMakeBorder(
tmp,
tmp,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
{ 114, 114, 114 }
);
cv::dnn::blobFromImage(tmp,
out,
1 / 255.f,
cv::Size(),
cv::Scalar(0, 0, 0),
true,
false,
CV_32F
);
cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, {114, 114, 114});
cv::dnn::blobFromImage(tmp, out, 1 / 255.f, cv::Size(), cv::Scalar(0, 0, 0), true, false, CV_32F);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
this->pparam.height = height;
this->pparam.width = width;;
this->pparam.width = width;
;
}
void YOLOv8::copy_from_Mat(const cv::Mat& image)
@ -250,75 +188,33 @@ void YOLOv8::copy_from_Mat(const cv::Mat& image)
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
cv::Size size{width, height};
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{
4,
{ 1, 3, height, width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
cv::Mat nchw;
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{ 4,
{ 1, 3, size.height, size.width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8::infer()
{
this->context->enqueueV2(
this->device_ptrs.data(),
this->stream,
nullptr
);
for (int i = 0; i < this->num_outputs; i++)
{
this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr);
for (int i = 0; i < this->num_outputs; i++) {
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(this->host_ptrs[i],
this->device_ptrs[i + this->num_inputs],
osize,
cudaMemcpyDeviceToHost,
this->stream)
);
CHECK(cudaMemcpyAsync(
this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream));
}
cudaStreamSynchronize(this->stream);
}
void YOLOv8::postprocess(std::vector<Object>& objs)
@ -333,8 +229,7 @@ void YOLOv8::postprocess(std::vector<Object>& objs)
auto& width = this->pparam.width;
auto& height = this->pparam.height;
auto& ratio = this->pparam.ratio;
for (int i = 0; i < num_dets[0]; i++)
{
for (int i = 0; i < num_dets[0]; i++) {
float* ptr = boxes + i * 4;
float x0 = *ptr++ - dw;
@ -357,45 +252,22 @@ void YOLOv8::postprocess(std::vector<Object>& objs)
}
}
void YOLOv8::draw_objects(
const cv::Mat& image,
void YOLOv8::draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS
)
const std::vector<std::vector<unsigned int>>& COLORS)
{
res = image.clone();
for (auto& obj : objs)
{
cv::Scalar color = cv::Scalar(
COLORS[obj.label][0],
COLORS[obj.label][1],
COLORS[obj.label][2]
);
cv::rectangle(
res,
obj.rect,
color,
2
);
for (auto& obj : objs) {
cv::Scalar color = cv::Scalar(COLORS[obj.label][0], COLORS[obj.label][1], COLORS[obj.label][2]);
cv::rectangle(res, obj.rect, color, 2);
char text[256];
sprintf(
text,
"%s %.1f%%",
CLASS_NAMES[obj.label].c_str(),
obj.prob * 100
);
sprintf(text, "%s %.1f%%", CLASS_NAMES[obj.label].c_str(), obj.prob * 100);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(
text,
cv::FONT_HERSHEY_SIMPLEX,
0.4,
1,
&baseLine
);
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = (int)obj.rect.x;
int y = (int)obj.rect.y + 1;
@ -403,22 +275,9 @@ void YOLOv8::draw_objects(
if (y > res.rows)
y = res.rows;
cv::rectangle(
res,
cv::Rect(x, y, label_size.width, label_size.height + baseLine),
{ 0, 0, 255 },
-1
);
cv::putText(
res,
text,
cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX,
0.4,
{ 255, 255, 255 },
1
);
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);
}
}
#endif // JETSON_DETECT_YOLOV8_HPP

@ -2,56 +2,38 @@
// Created by ubuntu on 3/16/23.
//
#include "chrono"
#include "yolov8.hpp"
#include "opencv2/opencv.hpp"
#include "yolov8.hpp"
const std::vector<std::string> CLASS_NAMES = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus",
"train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella",
"handbag", "tie", "suitcase", "frisbee", "skis",
"snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass",
"cup", "fork", "knife", "spoon", "bowl",
"banana", "apple", "sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote",
"keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase",
"scissors", "teddy bear", "hair drier", "toothbrush" };
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
"truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
"baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
"fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
"orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv",
"laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
"teddy bear", "hair drier", "toothbrush"};
const std::vector<std::vector<unsigned int>> COLORS = {
{ 0, 114, 189 }, { 217, 83, 25 }, { 237, 177, 32 },
{ 126, 47, 142 }, { 119, 172, 48 }, { 77, 190, 238 },
{ 162, 20, 47 }, { 76, 76, 76 }, { 153, 153, 153 },
{ 255, 0, 0 }, { 255, 128, 0 }, { 191, 191, 0 },
{ 0, 255, 0 }, { 0, 0, 255 }, { 170, 0, 255 },
{ 85, 85, 0 }, { 85, 170, 0 }, { 85, 255, 0 },
{ 170, 85, 0 }, { 170, 170, 0 }, { 170, 255, 0 },
{ 255, 85, 0 }, { 255, 170, 0 }, { 255, 255, 0 },
{ 0, 85, 128 }, { 0, 170, 128 }, { 0, 255, 128 },
{ 85, 0, 128 }, { 85, 85, 128 }, { 85, 170, 128 },
{ 85, 255, 128 }, { 170, 0, 128 }, { 170, 85, 128 },
{ 170, 170, 128 }, { 170, 255, 128 }, { 255, 0, 128 },
{ 255, 85, 128 }, { 255, 170, 128 }, { 255, 255, 128 },
{ 0, 85, 255 }, { 0, 170, 255 }, { 0, 255, 255 },
{ 85, 0, 255 }, { 85, 85, 255 }, { 85, 170, 255 },
{ 85, 255, 255 }, { 170, 0, 255 }, { 170, 85, 255 },
{ 170, 170, 255 }, { 170, 255, 255 }, { 255, 0, 255 },
{ 255, 85, 255 }, { 255, 170, 255 }, { 85, 0, 0 },
{ 128, 0, 0 }, { 170, 0, 0 }, { 212, 0, 0 },
{ 255, 0, 0 }, { 0, 43, 0 }, { 0, 85, 0 },
{ 0, 128, 0 }, { 0, 170, 0 }, { 0, 212, 0 },
{ 0, 255, 0 }, { 0, 0, 43 }, { 0, 0, 85 },
{ 0, 0, 128 }, { 0, 0, 170 }, { 0, 0, 212 },
{ 0, 0, 255 }, { 0, 0, 0 }, { 36, 36, 36 },
{ 73, 73, 73 }, { 109, 109, 109 }, { 146, 146, 146 },
{ 182, 182, 182 }, { 219, 219, 219 }, { 0, 114, 189 },
{ 80, 183, 189 }, { 128, 128, 0 }
};
{0, 114, 189}, {217, 83, 25}, {237, 177, 32}, {126, 47, 142}, {119, 172, 48}, {77, 190, 238},
{162, 20, 47}, {76, 76, 76}, {153, 153, 153}, {255, 0, 0}, {255, 128, 0}, {191, 191, 0},
{0, 255, 0}, {0, 0, 255}, {170, 0, 255}, {85, 85, 0}, {85, 170, 0}, {85, 255, 0},
{170, 85, 0}, {170, 170, 0}, {170, 255, 0}, {255, 85, 0}, {255, 170, 0}, {255, 255, 0},
{0, 85, 128}, {0, 170, 128}, {0, 255, 128}, {85, 0, 128}, {85, 85, 128}, {85, 170, 128},
{85, 255, 128}, {170, 0, 128}, {170, 85, 128}, {170, 170, 128}, {170, 255, 128}, {255, 0, 128},
{255, 85, 128}, {255, 170, 128}, {255, 255, 128}, {0, 85, 255}, {0, 170, 255}, {0, 255, 255},
{85, 0, 255}, {85, 85, 255}, {85, 170, 255}, {85, 255, 255}, {170, 0, 255}, {170, 85, 255},
{170, 170, 255}, {170, 255, 255}, {255, 0, 255}, {255, 85, 255}, {255, 170, 255}, {85, 0, 0},
{128, 0, 0}, {170, 0, 0}, {212, 0, 0}, {255, 0, 0}, {0, 43, 0}, {0, 85, 0},
{0, 128, 0}, {0, 170, 0}, {0, 212, 0}, {0, 255, 0}, {0, 0, 43}, {0, 0, 85},
{0, 0, 128}, {0, 0, 170}, {0, 0, 212}, {0, 0, 255}, {0, 0, 0}, {36, 36, 36},
{73, 73, 73}, {109, 109, 109}, {146, 146, 146}, {182, 182, 182}, {219, 219, 219}, {0, 114, 189},
{80, 183, 189}, {128, 128, 0}};
int main(int argc, char** argv)
{
@ -66,36 +48,21 @@ int main(int argc, char** argv)
auto yolov8 = new YOLOv8(engine_file_path);
yolov8->make_pipe(true);
if (IsFile(path))
{
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (
suffix == "jpg" ||
suffix == "jpeg" ||
suffix == "png"
)
{
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
imagePathList.push_back(path);
}
else if (
suffix == "mp4" ||
suffix == "avi" ||
suffix == "m4v" ||
suffix == "mpeg" ||
suffix == "mov" ||
suffix == "mkv"
)
{
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
isVideo = true;
}
else
{
else {
printf("suffix %s is wrong !!!\n", suffix.c_str());
std::abort();
}
}
else if (IsFolder(path))
{
else if (IsFolder(path)) {
cv::glob(path + "/*.jpg", imagePathList);
}
@ -105,17 +72,14 @@ int main(int argc, char** argv)
cv::namedWindow("result", cv::WINDOW_AUTOSIZE);
if (isVideo)
{
if (isVideo) {
cv::VideoCapture cap(path);
if (!cap.isOpened())
{
if (!cap.isOpened()) {
printf("can not open %s\n", path.c_str());
return -1;
}
while (cap.read(image))
{
while (cap.read(image)) {
objs.clear();
yolov8->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
@ -123,20 +87,16 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
if (cv::waitKey(10) == 'q')
{
if (cv::waitKey(10) == 'q') {
break;
}
}
}
else
{
for (auto& path : imagePathList)
{
else {
for (auto& path : imagePathList) {
objs.clear();
image = cv::imread(path);
yolov8->copy_from_Mat(image, size);
@ -145,8 +105,7 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
cv::waitKey(0);

@ -4,29 +4,25 @@
#ifndef JETSON_POSE_COMMON_HPP
#define JETSON_POSE_COMMON_HPP
#include "NvInfer.h"
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#include "NvInfer.h"
#define CHECK(call) \
do \
{ \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
if (error_code != cudaSuccess) { \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", \
cudaGetErrorString(error_code)); \
printf(" Error text: %s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
class Logger : public nvinfer1::ILogger
{
class Logger: public nvinfer1::ILogger {
public:
nvinfer1::ILogger::Severity reportableSeverity;
@ -37,12 +33,10 @@ public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity)
{
if (severity > reportableSeverity) {
return;
}
switch (severity)
{
switch (severity) {
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
@ -66,8 +60,7 @@ public:
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++)
{
for (int i = 0; i < dims.nbDims; i++) {
size *= dims.d[i];
}
return size;
@ -75,8 +68,7 @@ inline int get_size_by_dims(const nvinfer1::Dims& dims)
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType)
{
switch (dataType) {
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
@ -99,8 +91,7 @@ inline static float clamp(float val, float min, float max)
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK)
{
if (access(path.c_str(), 0) == F_OK) {
return true;
}
return false;
@ -108,8 +99,7 @@ inline bool IsPathExist(const std::string& path)
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
@ -119,39 +109,34 @@ inline bool IsFile(const std::string& path)
inline bool IsFolder(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
namespace pose
{
struct Binding
{
namespace pose {
struct Binding {
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object
{
struct Object {
cv::Rect_<float> rect;
int label = 0;
float prob = 0.0;
std::vector<float> kps;
};
struct PreParam
{
struct PreParam {
float ratio = 1.0f;
float dw = 0.0f;
float dh = 0.0f;
float height = 0;
float width = 0;
};
}
} // namespace pose
#endif // JETSON_POSE_COMMON_HPP

@ -4,9 +4,9 @@
#ifndef JETSON_POSE_YOLOV8_POSE_HPP
#define JETSON_POSE_YOLOV8_POSE_HPP
#include "fstream"
#include "common.hpp"
#include "NvInferPlugin.h"
#include "common.hpp"
#include "fstream"
using namespace pose;
@ -22,29 +22,18 @@ public:
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void letterbox(
const cv::Mat &image,
cv::Mat &out,
cv::Size &size
);
void letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size);
void infer();
void postprocess(
std::vector<Object> &objs,
float score_thres = 0.25f,
float iou_thres = 0.65f,
int topk = 100
);
void postprocess(std::vector<Object>& objs, float score_thres = 0.25f, float iou_thres = 0.65f, int topk = 100);
static void draw_objects(
const cv::Mat &image,
static void draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::vector<unsigned int>>& SKELETON,
const std::vector<std::vector<unsigned int>>& KPS_COLORS,
const std::vector<std::vector<unsigned int>> &LIMB_COLORS
);
const std::vector<std::vector<unsigned int>>& LIMB_COLORS);
int num_bindings;
int num_inputs = 0;
@ -55,16 +44,17 @@ public:
std::vector<void*> device_ptrs;
PreParam pparam;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{nvinfer1::ILogger::Severity::kERROR};
};
YOLOv8_pose::YOLOv8_pose(const std::string &engine_file_path) {
YOLOv8_pose::YOLOv8_pose(const std::string& engine_file_path)
{
std::ifstream file(engine_file_path, std::ios::binary);
assert(file.good());
file.seekg(0, std::ios::end);
@ -98,17 +88,14 @@ YOLOv8_pose::YOLOv8_pose(const std::string &engine_file_path) {
bool IsInput = engine->bindingIsInput(i);
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(
i,
0,
nvinfer1::OptProfileSelector::kMAX);
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
} else {
}
else {
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
@ -116,10 +103,10 @@ YOLOv8_pose::YOLOv8_pose(const std::string &engine_file_path) {
this->num_outputs += 1;
}
}
}
YOLOv8_pose::~YOLOv8_pose() {
YOLOv8_pose::~YOLOv8_pose()
{
this->context->destroy();
this->engine->destroy();
this->runtime->destroy();
@ -131,33 +118,22 @@ YOLOv8_pose::~YOLOv8_pose() {
for (auto& ptr : this->host_ptrs) {
CHECK(cudaFreeHost(ptr));
}
}
void YOLOv8_pose::make_pipe(bool warmup) {
void YOLOv8_pose::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings) {
void* d_ptr;
CHECK(cudaMalloc(
&d_ptr,
bindings.size * bindings.dsize
)
);
CHECK(cudaMalloc(&d_ptr, bindings.size * bindings.dsize));
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings) {
void * d_ptr, *h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMalloc(
&d_ptr,
size)
);
CHECK(cudaHostAlloc(
&h_ptr,
size,
0)
);
CHECK(cudaMalloc(&d_ptr, size));
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
@ -168,27 +144,17 @@ void YOLOv8_pose::make_pipe(bool warmup) {
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
h_ptr,
size,
cudaMemcpyHostToDevice,
this->stream)
);
CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream));
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
void YOLOv8_pose::letterbox(
const cv::Mat &image,
cv::Mat &out,
cv::Size &size
) {
void YOLOv8_pose::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
{
const float inp_h = size.height;
const float inp_w = size.width;
float height = image.rows;
@ -200,12 +166,9 @@ void YOLOv8_pose::letterbox(
cv::Mat tmp;
if ((int)width != padw || (int)height != padh) {
cv::resize(
image,
tmp,
cv::Size(padw, padh)
);
} else {
cv::resize(image, tmp, cv::Size(padw, padh));
}
else {
tmp = image.clone();
}
@ -219,113 +182,55 @@ void YOLOv8_pose::letterbox(
int left = int(std::round(dw - 0.1f));
int right = int(std::round(dw + 0.1f));
cv::copyMakeBorder(
tmp,
tmp,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
{114, 114, 114}
);
cv::dnn::blobFromImage(tmp,
out,
1 / 255.f,
cv::Size(),
cv::Scalar(0, 0, 0),
true,
false,
CV_32F
);
cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, {114, 114, 114});
cv::dnn::blobFromImage(tmp, out, 1 / 255.f, cv::Size(), cv::Scalar(0, 0, 0), true, false, CV_32F);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
this->pparam.height = height;
this->pparam.width = width;;
this->pparam.width = width;
;
}
void YOLOv8_pose::copy_from_Mat(const cv::Mat &image) {
void YOLOv8_pose::copy_from_Mat(const cv::Mat& image)
{
cv::Mat nchw;
auto& in_binding = this->input_bindings[0];
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
cv::Size size{width, height};
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{
4,
{1, 3, height, width}
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_pose::copy_from_Mat(const cv::Mat &image, cv::Size &size) {
void YOLOv8_pose::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
cv::Mat nchw;
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{4,
{1, 3, size.height, size.width}
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_pose::infer() {
void YOLOv8_pose::infer()
{
this->context->enqueueV2(
this->device_ptrs.data(),
this->stream,
nullptr
);
this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr);
for (int i = 0; i < this->num_outputs; i++) {
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(this->host_ptrs[i],
this->device_ptrs[i + this->num_inputs],
osize,
cudaMemcpyDeviceToHost,
this->stream)
);
CHECK(cudaMemcpyAsync(
this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream));
}
cudaStreamSynchronize(this->stream);
}
void YOLOv8_pose::postprocess(
std::vector<Object> &objs,
float score_thres,
float iou_thres,
int topk
) {
void YOLOv8_pose::postprocess(std::vector<Object>& objs, float score_thres, float iou_thres, int topk)
{
objs.clear();
auto num_channels = this->output_bindings[0].dims.d[1];
auto num_anchors = this->output_bindings[0].dims.d[2];
@ -342,12 +247,7 @@ void YOLOv8_pose::postprocess(
std::vector<int> indices;
std::vector<std::vector<float>> kpss;
cv::Mat output = cv::Mat(
num_channels,
num_anchors,
CV_32F,
static_cast<float *>(this->host_ptrs[0])
);
cv::Mat output = cv::Mat(num_channels, num_anchors, CV_32F, static_cast<float*>(this->host_ptrs[0]));
output = output.t();
for (int i = 0; i < num_anchors; i++) {
auto row_ptr = output.row(i).ptr<float>();
@ -392,22 +292,9 @@ void YOLOv8_pose::postprocess(
}
#ifdef BATCHED_NMS
cv::dnn::NMSBoxesBatched(
bboxes,
scores,
labels,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxesBatched(bboxes, scores, labels, score_thres, iou_thres, indices);
#else
cv::dnn::NMSBoxes(
bboxes,
scores,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxes(bboxes, scores, score_thres, iou_thres, indices);
#endif
int cnt = 0;
@ -425,39 +312,23 @@ void YOLOv8_pose::postprocess(
}
}
void YOLOv8_pose::draw_objects(
const cv::Mat &image,
void YOLOv8_pose::draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::vector<unsigned int>>& SKELETON,
const std::vector<std::vector<unsigned int>>& KPS_COLORS,
const std::vector<std::vector<unsigned int>> &LIMB_COLORS
) {
const std::vector<std::vector<unsigned int>>& LIMB_COLORS)
{
res = image.clone();
const int num_point = 17;
for (auto& obj : objs) {
cv::rectangle(
res,
obj.rect,
{0, 0, 255},
2
);
cv::rectangle(res, obj.rect, {0, 0, 255}, 2);
char text[256];
sprintf(
text,
"person %.1f%%",
obj.prob * 100
);
sprintf(text, "person %.1f%%", obj.prob * 100);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(
text,
cv::FONT_HERSHEY_SIMPLEX,
0.4,
1,
&baseLine
);
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = (int)obj.rect.x;
int y = (int)obj.rect.y + 1;
@ -465,22 +336,9 @@ void YOLOv8_pose::draw_objects(
if (y > res.rows)
y = res.rows;
cv::rectangle(
res,
cv::Rect(x, y, label_size.width, label_size.height + baseLine),
{0, 0, 255},
-1
);
cv::putText(
res,
text,
cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX,
0.4,
{255, 255, 255},
1
);
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);
auto& kps = obj.kps;
for (int k = 0; k < num_point + 2; k++) {
@ -503,7 +361,6 @@ void YOLOv8_pose::draw_objects(
float pos1_s = kps[(ske[0] - 1) * 3 + 2];
float pos2_s = kps[(ske[1] - 1) * 3 + 2];
if (pos1_s > 0.5f && pos2_s > 0.5f) {
cv::Scalar limb_color = cv::Scalar(LIMB_COLORS[k][0], LIMB_COLORS[k][1], LIMB_COLORS[k][2]);
cv::line(res, {pos1_x, pos1_y}, {pos2_x, pos2_y}, limb_color, 2);

@ -2,12 +2,10 @@
// Created by ubuntu on 4/7/23.
//
#include "chrono"
#include "yolov8-pose.hpp"
#include "opencv2/opencv.hpp"
#include "yolov8-pose.hpp"
const std::vector<std::vector<unsigned int>> KPS_COLORS =
{{0, 255, 0},
const std::vector<std::vector<unsigned int>> KPS_COLORS = {{0, 255, 0},
{0, 255, 0},
{0, 255, 0},
{0, 255, 0},
@ -65,7 +63,8 @@ const std::vector<std::vector<unsigned int>> LIMB_COLORS = {{51, 153, 255},
{0, 255, 0},
{0, 255, 0}};
int main(int argc, char **argv) {
int main(int argc, char** argv)
{
// cuda:0
cudaSetDevice(0);
@ -82,26 +81,19 @@ int main(int argc, char **argv) {
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (
suffix == "jpg" ||
suffix == "jpeg" ||
suffix == "png"
) {
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
imagePathList.push_back(path);
} else if (
suffix == "mp4" ||
suffix == "avi" ||
suffix == "m4v" ||
suffix == "mpeg" ||
suffix == "mov" ||
suffix == "mkv"
) {
}
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
isVideo = true;
} else {
}
else {
printf("suffix %s is wrong !!!\n", suffix.c_str());
std::abort();
}
} else if (IsFolder(path)) {
}
else if (IsFolder(path)) {
cv::glob(path + "/*.jpg", imagePathList);
}
@ -130,15 +122,15 @@ int main(int argc, char **argv) {
auto end = std::chrono::system_clock::now();
yolov8_pose->postprocess(objs, score_thres, iou_thres, topk);
yolov8_pose->draw_objects(image, res, objs, SKELETON, KPS_COLORS, LIMB_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
if (cv::waitKey(10) == 'q') {
break;
}
}
} else {
}
else {
for (auto& path : imagePathList) {
objs.clear();
image = cv::imread(path);
@ -148,8 +140,7 @@ int main(int argc, char **argv) {
auto end = std::chrono::system_clock::now();
yolov8_pose->postprocess(objs, score_thres, iou_thres, topk);
yolov8_pose->draw_objects(image, res, objs, SKELETON, KPS_COLORS, LIMB_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
cv::waitKey(0);

@ -4,29 +4,25 @@
#ifndef JETSON_SEGMENT_COMMON_HPP
#define JETSON_SEGMENT_COMMON_HPP
#include "NvInfer.h"
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#include "NvInfer.h"
#define CHECK(call) \
do \
{ \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
if (error_code != cudaSuccess) { \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", \
cudaGetErrorString(error_code)); \
printf(" Error text: %s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
class Logger : public nvinfer1::ILogger
{
class Logger: public nvinfer1::ILogger {
public:
nvinfer1::ILogger::Severity reportableSeverity;
@ -37,12 +33,10 @@ public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity)
{
if (severity > reportableSeverity) {
return;
}
switch (severity)
{
switch (severity) {
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
@ -66,8 +60,7 @@ public:
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++)
{
for (int i = 0; i < dims.nbDims; i++) {
size *= dims.d[i];
}
return size;
@ -75,8 +68,7 @@ inline int get_size_by_dims(const nvinfer1::Dims& dims)
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType)
{
switch (dataType) {
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
@ -99,8 +91,7 @@ inline static float clamp(float val, float min, float max)
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK)
{
if (access(path.c_str(), 0) == F_OK) {
return true;
}
return false;
@ -108,8 +99,7 @@ inline bool IsPathExist(const std::string& path)
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
@ -119,39 +109,34 @@ inline bool IsFile(const std::string& path)
inline bool IsFolder(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
namespace seg
{
struct Binding
{
namespace seg {
struct Binding {
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object
{
struct Object {
cv::Rect_<float> rect;
int label = 0;
float prob = 0.0;
cv::Mat boxMask;
};
struct PreParam
{
struct PreParam {
float ratio = 1.0f;
float dw = 0.0f;
float dh = 0.0f;
float height = 0;
float width = 0;
};
}
} // namespace seg
#endif // JETSON_SEGMENT_COMMON_HPP

@ -3,14 +3,13 @@
//
#ifndef JETSON_SEGMENT_YOLOV8_SEG_HPP
#define JETSON_SEGMENT_YOLOV8_SEG_HPP
#include <fstream>
#include "common.hpp"
#include "NvInferPlugin.h"
#include "common.hpp"
#include <fstream>
using namespace seg;
class YOLOv8_seg
{
class YOLOv8_seg {
public:
explicit YOLOv8_seg(const std::string& engine_file_path);
~YOLOv8_seg();
@ -18,29 +17,21 @@ public:
void make_pipe(bool warmup = true);
void copy_from_Mat(const cv::Mat& image);
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
);
void letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size);
void infer();
void postprocess(
std::vector<Object>& objs,
void postprocess(std::vector<Object>& objs,
float score_thres = 0.25f,
float iou_thres = 0.65f,
int topk = 100,
int seg_channels = 32,
int seg_h = 160,
int seg_w = 160
);
static void draw_objects(
const cv::Mat& image,
int seg_w = 160);
static void draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS,
const std::vector<std::vector<unsigned int>>& MASK_COLORS
);
const std::vector<std::vector<unsigned int>>& MASK_COLORS);
int num_bindings;
int num_inputs = 0;
int num_outputs = 0;
@ -50,13 +41,13 @@ public:
std::vector<void*> device_ptrs;
PreParam pparam;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{nvinfer1::ILogger::Severity::kERROR};
};
YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
@ -83,8 +74,7 @@ YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
cudaStreamCreate(&this->stream);
this->num_bindings = this->engine->getNbBindings();
for (int i = 0; i < this->num_bindings; ++i)
{
for (int i = 0; i < this->num_bindings; ++i) {
Binding binding;
nvinfer1::Dims dims;
nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
@ -93,22 +83,16 @@ YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
binding.dsize = type_to_size(dtype);
bool IsInput = engine->bindingIsInput(i);
if (IsInput)
{
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(
i,
0,
nvinfer1::OptProfileSelector::kMAX);
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
}
else
{
else {
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
@ -116,7 +100,6 @@ YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
this->num_outputs += 1;
}
}
}
YOLOv8_seg::~YOLOv8_seg()
@ -125,13 +108,11 @@ YOLOv8_seg::~YOLOv8_seg()
this->engine->destroy();
this->runtime->destroy();
cudaStreamDestroy(this->stream);
for (auto& ptr : this->device_ptrs)
{
for (auto& ptr : this->device_ptrs) {
CHECK(cudaFree(ptr));
}
for (auto& ptr : this->host_ptrs)
{
for (auto& ptr : this->host_ptrs) {
CHECK(cudaFreeHost(ptr));
}
}
@ -139,63 +120,37 @@ YOLOv8_seg::~YOLOv8_seg()
void YOLOv8_seg::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings)
{
for (auto& bindings : this->input_bindings) {
void* d_ptr;
CHECK(cudaMalloc(
&d_ptr,
bindings.size * bindings.dsize)
);
CHECK(cudaMalloc(&d_ptr, bindings.size * bindings.dsize));
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings)
{
for (auto& bindings : this->output_bindings) {
void * d_ptr, *h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMalloc(
&d_ptr,
size)
);
CHECK(cudaHostAlloc(
&h_ptr,
size,
0)
);
CHECK(cudaMalloc(&d_ptr, size));
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
if (warmup)
{
for (int i = 0; i < 10; i++)
{
for (auto& bindings : this->input_bindings)
{
if (warmup) {
for (int i = 0; i < 10; i++) {
for (auto& bindings : this->input_bindings) {
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
h_ptr,
size,
cudaMemcpyHostToDevice,
this->stream)
);
CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream));
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
void YOLOv8_seg::letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
)
void YOLOv8_seg::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
{
const float inp_h = size.height;
const float inp_w = size.width;
@ -207,16 +162,10 @@ void YOLOv8_seg::letterbox(
int padh = std::round(height * r);
cv::Mat tmp;
if ((int)width != padw || (int)height != padh)
{
cv::resize(
image,
tmp,
cv::Size(padw, padh)
);
if ((int)width != padw || (int)height != padh) {
cv::resize(image, tmp, cv::Size(padw, padh));
}
else
{
else {
tmp = image.clone();
}
@ -230,31 +179,15 @@ void YOLOv8_seg::letterbox(
int left = int(std::round(dw - 0.1f));
int right = int(std::round(dw + 0.1f));
cv::copyMakeBorder(
tmp,
tmp,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
{ 114, 114, 114 }
);
cv::dnn::blobFromImage(tmp,
out,
1 / 255.f,
cv::Size(),
cv::Scalar(0, 0, 0),
true,
false,
CV_32F
);
cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, {114, 114, 114});
cv::dnn::blobFromImage(tmp, out, 1 / 255.f, cv::Size(), cv::Scalar(0, 0, 0), true, false, CV_32F);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
this->pparam.height = height;
this->pparam.width = width;;
this->pparam.width = width;
;
}
void YOLOv8_seg::copy_from_Mat(const cv::Mat& image)
@ -264,85 +197,37 @@ void YOLOv8_seg::copy_from_Mat(const cv::Mat& image)
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
cv::Size size{width, height};
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{
4,
{ 1, 3, height, width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_seg::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
cv::Mat nchw;
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{ 4,
{ 1, 3, size.height, size.width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_seg::infer()
{
this->context->enqueueV2(
this->device_ptrs.data(),
this->stream,
nullptr
);
for (int i = 0; i < this->num_outputs; i++)
{
this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr);
for (int i = 0; i < this->num_outputs; i++) {
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(this->host_ptrs[i],
this->device_ptrs[i + this->num_inputs],
osize,
cudaMemcpyDeviceToHost,
this->stream)
);
CHECK(cudaMemcpyAsync(
this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream));
}
cudaStreamSynchronize(this->stream);
}
void YOLOv8_seg::postprocess(std::vector<Object>& objs,
float score_thres,
float iou_thres,
int topk,
int seg_channels,
int seg_h,
int seg_w
)
void YOLOv8_seg::postprocess(
std::vector<Object>& objs, float score_thres, float iou_thres, int topk, int seg_channels, int seg_h, int seg_w)
{
objs.clear();
auto input_h = this->input_bindings[0].dims.d[2];
@ -357,8 +242,7 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
auto& ratio = this->pparam.ratio;
auto* output = static_cast<float*>(this->host_ptrs[0]);
cv::Mat protos = cv::Mat(seg_channels, seg_h * seg_w, CV_32F,
static_cast<float*>(this->host_ptrs[1]));
cv::Mat protos = cv::Mat(seg_channels, seg_h * seg_w, CV_32F, static_cast<float*>(this->host_ptrs[1]));
std::vector<int> labels;
std::vector<float> scores;
@ -366,12 +250,10 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
std::vector<cv::Mat> mask_confs;
std::vector<int> indices;
for (int i = 0; i < num_anchors; i++)
{
for (int i = 0; i < num_anchors; i++) {
float* ptr = output + i * num_channels;
float score = *(ptr + 4);
if (score > score_thres)
{
if (score > score_thres) {
float x0 = *ptr++ - dw;
float y0 = *ptr++ - dh;
float x1 = *ptr++ - dw;
@ -388,35 +270,19 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
labels.push_back(label);
scores.push_back(score);
bboxes.push_back(cv::Rect_<float>(x0, y0, x1 - x0, y1 - y0));
}
}
#if defined(BATCHED_NMS)
cv::dnn::NMSBoxesBatched(
bboxes,
scores,
labels,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxesBatched(bboxes, scores, labels, score_thres, iou_thres, indices);
#else
cv::dnn::NMSBoxes(
bboxes,
scores,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxes(bboxes, scores, score_thres, iou_thres, indices);
#endif
cv::Mat masks;
int cnt = 0;
for (auto& i : indices)
{
if (cnt >= topk)
{
for (auto& i : indices) {
if (cnt >= topk) {
break;
}
cv::Rect tmp = bboxes[i];
@ -428,12 +294,10 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
objs.push_back(obj);
cnt += 1;
}
if(masks.empty())
{
if (masks.empty()) {
// masks is empty
}
else
{
else {
cv::Mat matmulRes = (masks * protos).t();
cv::Mat maskMat = matmulRes.reshape(indices.size(), {seg_w, seg_h});
@ -442,24 +306,14 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
int scale_dw = dw / input_w * seg_w;
int scale_dh = dh / input_h * seg_h;
cv::Rect roi(
scale_dw,
scale_dh,
seg_w - 2 * scale_dw,
seg_h - 2 * scale_dh);
cv::Rect roi(scale_dw, scale_dh, seg_w - 2 * scale_dw, seg_h - 2 * scale_dh);
for (int i = 0; i < indices.size(); i++)
{
for (int i = 0; i < indices.size(); i++) {
cv::Mat dest, mask;
cv::exp(-maskChannels[i], dest);
dest = 1.0 / (1.0 + dest);
dest = dest(roi);
cv::resize(
dest,
mask,
cv::Size((int)width, (int)height),
cv::INTER_LINEAR
);
cv::resize(dest, mask, cv::Size((int)width, (int)height), cv::INTER_LINEAR);
objs[i].boxMask = mask(objs[i].rect) > 0.5f;
}
}
@ -470,48 +324,23 @@ void YOLOv8_seg::draw_objects(const cv::Mat& image,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS,
const std::vector<std::vector<unsigned int>>& MASK_COLORS
)
const std::vector<std::vector<unsigned int>>& MASK_COLORS)
{
res = image.clone();
cv::Mat mask = image.clone();
for (auto& obj : objs)
{
for (auto& obj : objs) {
int idx = obj.label;
cv::Scalar color = cv::Scalar(
COLORS[idx][0],
COLORS[idx][1],
COLORS[idx][2]
);
cv::Scalar mask_color = cv::Scalar(
MASK_COLORS[idx % 20][0],
MASK_COLORS[idx % 20][1],
MASK_COLORS[idx % 20][2]
);
cv::rectangle(
res,
obj.rect,
color,
2
);
cv::Scalar color = cv::Scalar(COLORS[idx][0], COLORS[idx][1], COLORS[idx][2]);
cv::Scalar mask_color =
cv::Scalar(MASK_COLORS[idx % 20][0], MASK_COLORS[idx % 20][1], MASK_COLORS[idx % 20][2]);
cv::rectangle(res, obj.rect, color, 2);
char text[256];
sprintf(
text,
"%s %.1f%%",
CLASS_NAMES[idx].c_str(),
obj.prob * 100
);
sprintf(text, "%s %.1f%%", CLASS_NAMES[idx].c_str(), obj.prob * 100);
mask(obj.rect).setTo(mask_color, obj.boxMask);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(
text,
cv::FONT_HERSHEY_SIMPLEX,
0.4,
1,
&baseLine
);
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = (int)obj.rect.x;
int y = (int)obj.rect.y + 1;
@ -519,30 +348,10 @@ void YOLOv8_seg::draw_objects(const cv::Mat& image,
if (y > res.rows)
y = res.rows;
cv::rectangle(
res,
cv::Rect(x, y, label_size.width, label_size.height + baseLine),
{ 0, 0, 255 },
-1
);
cv::putText(
res,
text,
cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX,
0.4,
{ 255, 255, 255 },
1
);
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);
}
cv::addWeighted(
res,
0.5,
mask,
0.8,
1,
res
);
cv::addWeighted(res, 0.5, mask, 0.8, 1, res);
}
#endif // JETSON_SEGMENT_YOLOV8_SEG_HPP

@ -2,66 +2,43 @@
// Created by ubuntu on 3/16/23.
//
#include "chrono"
#include "yolov8-seg.hpp"
#include "opencv2/opencv.hpp"
#include "yolov8-seg.hpp"
const std::vector<std::string> CLASS_NAMES = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus",
"train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella",
"handbag", "tie", "suitcase", "frisbee", "skis",
"snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass",
"cup", "fork", "knife", "spoon", "bowl",
"banana", "apple", "sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote",
"keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase",
"scissors", "teddy bear", "hair drier", "toothbrush" };
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
"truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
"baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
"fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
"orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv",
"laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
"teddy bear", "hair drier", "toothbrush"};
// Colour triples handed to YOLOv8_seg::draw_objects alongside CLASS_NAMES
// (80 entries, three components each).
// NOTE(review): channel order (RGB vs BGR) is decided inside draw_objects, which is
// not visible here - confirm before editing values.
const std::vector<std::vector<unsigned int>> COLORS = {
    {0, 114, 189},   {217, 83, 25},   {237, 177, 32},  {126, 47, 142},  {119, 172, 48},  {77, 190, 238},
    {162, 20, 47},   {76, 76, 76},    {153, 153, 153}, {255, 0, 0},     {255, 128, 0},   {191, 191, 0},
    {0, 255, 0},     {0, 0, 255},     {170, 0, 255},   {85, 85, 0},     {85, 170, 0},    {85, 255, 0},
    {170, 85, 0},    {170, 170, 0},   {170, 255, 0},   {255, 85, 0},    {255, 170, 0},   {255, 255, 0},
    {0, 85, 128},    {0, 170, 128},   {0, 255, 128},   {85, 0, 128},    {85, 85, 128},   {85, 170, 128},
    {85, 255, 128},  {170, 0, 128},   {170, 85, 128},  {170, 170, 128}, {170, 255, 128}, {255, 0, 128},
    {255, 85, 128},  {255, 170, 128}, {255, 255, 128}, {0, 85, 255},    {0, 170, 255},   {0, 255, 255},
    {85, 0, 255},    {85, 85, 255},   {85, 170, 255},  {85, 255, 255},  {170, 0, 255},   {170, 85, 255},
    {170, 170, 255}, {170, 255, 255}, {255, 0, 255},   {255, 85, 255},  {255, 170, 255}, {85, 0, 0},
    {128, 0, 0},     {170, 0, 0},     {212, 0, 0},     {255, 0, 0},     {0, 43, 0},      {0, 85, 0},
    {0, 128, 0},     {0, 170, 0},     {0, 212, 0},     {0, 255, 0},     {0, 0, 43},      {0, 0, 85},
    {0, 0, 128},     {0, 0, 170},     {0, 0, 212},     {0, 0, 255},     {0, 0, 0},       {36, 36, 36},
    {73, 73, 73},    {109, 109, 109}, {146, 146, 146}, {182, 182, 182}, {219, 219, 219}, {0, 114, 189},
    {80, 183, 189},  {128, 128, 0}};
// Colour triples handed to YOLOv8_seg::draw_objects for the mask overlay
// (20 entries, three components each).
const std::vector<std::vector<unsigned int>> MASK_COLORS = {
    {255, 56, 56},  {255, 157, 151}, {255, 112, 31}, {255, 178, 29}, {207, 210, 49},  {72, 249, 10}, {146, 204, 23},
    {61, 219, 134}, {26, 147, 52},   {0, 212, 187},  {44, 153, 168}, {0, 194, 255},   {52, 69, 147}, {100, 115, 255},
    {0, 24, 236},   {132, 56, 255},  {82, 0, 133},   {203, 56, 255}, {255, 149, 200}, {255, 55, 199}};
int main(int argc, char** argv)
{
@ -79,36 +56,21 @@ int main(int argc, char** argv)
auto yolov8 = new YOLOv8_seg(engine_file_path);
yolov8->make_pipe(true);
if (IsFile(path))
{
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (
suffix == "jpg" ||
suffix == "jpeg" ||
suffix == "png"
)
{
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
imagePathList.push_back(path);
}
else if (
suffix == "mp4" ||
suffix == "avi" ||
suffix == "m4v" ||
suffix == "mpeg" ||
suffix == "mov" ||
suffix == "mkv"
)
{
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
isVideo = true;
}
else
{
else {
printf("suffix %s is wrong !!!\n", suffix.c_str());
std::abort();
}
}
else if (IsFolder(path))
{
else if (IsFolder(path)) {
cv::glob(path + "/*.jpg", imagePathList);
}
@ -125,17 +87,14 @@ int main(int argc, char** argv)
cv::namedWindow("result", cv::WINDOW_AUTOSIZE);
if (isVideo)
{
if (isVideo) {
cv::VideoCapture cap(path);
if (!cap.isOpened())
{
if (!cap.isOpened()) {
printf("can not open %s\n", path.c_str());
return -1;
}
while (cap.read(image))
{
while (cap.read(image)) {
objs.clear();
yolov8->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
@ -143,20 +102,16 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
if (cv::waitKey(10) == 'q')
{
if (cv::waitKey(10) == 'q') {
break;
}
}
}
else
{
for (auto& path : imagePathList)
{
else {
for (auto& path : imagePathList) {
objs.clear();
image = cv::imread(path);
yolov8->copy_from_Mat(image, size);
@ -165,8 +120,7 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
cv::waitKey(0);

@ -4,29 +4,25 @@
#ifndef POSE_NORMAL_COMMON_HPP
#define POSE_NORMAL_COMMON_HPP
#include "NvInfer.h"
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#include "NvInfer.h"
// Evaluate a CUDA runtime call exactly once. If the returned status is not
// cudaSuccess, print the source location, numeric code and error text, then
// terminate the process with exit(1).
// The argument is parenthesised so any expression can be passed safely, and the
// do { ... } while (0) wrapper makes the macro usable as a single statement.
#define CHECK(call)                                                                                                    \
    do {                                                                                                               \
        const cudaError_t error_code = (call);                                                                         \
        if (error_code != cudaSuccess) {                                                                               \
            printf("CUDA Error:\n");                                                                                   \
            printf(" File: %s\n", __FILE__);                                                                           \
            printf(" Line: %d\n", __LINE__);                                                                           \
            printf(" Error code: %d\n", error_code);                                                                   \
            printf(" Error text: %s\n", cudaGetErrorString(error_code));                                               \
            exit(1);                                                                                                   \
        }                                                                                                              \
    } while (0)
class Logger : public nvinfer1::ILogger
{
class Logger: public nvinfer1::ILogger {
public:
nvinfer1::ILogger::Severity reportableSeverity;
@ -37,12 +33,10 @@ public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity)
{
if (severity > reportableSeverity) {
return;
}
switch (severity)
{
switch (severity) {
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
@ -66,8 +60,7 @@ public:
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++)
{
for (int i = 0; i < dims.nbDims; i++) {
size *= dims.d[i];
}
return size;
@ -75,8 +68,7 @@ inline int get_size_by_dims(const nvinfer1::Dims& dims)
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType)
{
switch (dataType) {
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
@ -99,8 +91,7 @@ inline static float clamp(float val, float min, float max)
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK)
{
if (access(path.c_str(), 0) == F_OK) {
return true;
}
return false;
@ -108,8 +99,7 @@ inline bool IsPathExist(const std::string& path)
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
@ -119,39 +109,34 @@ inline bool IsFile(const std::string& path)
// Return true only when `path` names an existing directory.
// stat() already fails for a missing path, so no separate existence probe is
// needed before it; any stat() failure is reported as "not a folder".
inline bool IsFolder(const std::string& path)
{
    struct stat buffer;
    return stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode);
}
namespace pose
{
struct Binding
{
namespace pose {
struct Binding {
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object
{
struct Object {
cv::Rect_<float> rect;
int label = 0;
float prob = 0.0;
std::vector<float> kps;
};
struct PreParam
{
struct PreParam {
float ratio = 1.0f;
float dw = 0.0f;
float dh = 0.0f;
float height = 0;
float width = 0;
};
}
} // namespace pose
#endif // POSE_NORMAL_COMMON_HPP

@ -4,9 +4,9 @@
#ifndef POSE_NORMAL_YOLOv8_pose_HPP
#define POSE_NORMAL_YOLOv8_pose_HPP
#include "fstream"
#include "common.hpp"
#include "NvInferPlugin.h"
#include "common.hpp"
#include "fstream"
using namespace pose;
@ -22,29 +22,18 @@ public:
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void letterbox(
const cv::Mat &image,
cv::Mat &out,
cv::Size &size
);
void letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size);
void infer();
void postprocess(
std::vector<Object> &objs,
float score_thres = 0.25f,
float iou_thres = 0.65f,
int topk = 100
);
void postprocess(std::vector<Object>& objs, float score_thres = 0.25f, float iou_thres = 0.65f, int topk = 100);
static void draw_objects(
const cv::Mat &image,
static void draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::vector<unsigned int>>& SKELETON,
const std::vector<std::vector<unsigned int>>& KPS_COLORS,
const std::vector<std::vector<unsigned int>> &LIMB_COLORS
);
const std::vector<std::vector<unsigned int>>& LIMB_COLORS);
int num_bindings;
int num_inputs = 0;
@ -55,16 +44,17 @@ public:
std::vector<void*> device_ptrs;
PreParam pparam;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{nvinfer1::ILogger::Severity::kERROR};
};
YOLOv8_pose::YOLOv8_pose(const std::string &engine_file_path) {
YOLOv8_pose::YOLOv8_pose(const std::string& engine_file_path)
{
std::ifstream file(engine_file_path, std::ios::binary);
assert(file.good());
file.seekg(0, std::ios::end);
@ -98,17 +88,14 @@ YOLOv8_pose::YOLOv8_pose(const std::string &engine_file_path) {
bool IsInput = engine->bindingIsInput(i);
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(
i,
0,
nvinfer1::OptProfileSelector::kMAX);
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
} else {
}
else {
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
@ -116,10 +103,10 @@ YOLOv8_pose::YOLOv8_pose(const std::string &engine_file_path) {
this->num_outputs += 1;
}
}
}
YOLOv8_pose::~YOLOv8_pose() {
YOLOv8_pose::~YOLOv8_pose()
{
this->context->destroy();
this->engine->destroy();
this->runtime->destroy();
@ -131,34 +118,22 @@ YOLOv8_pose::~YOLOv8_pose() {
for (auto& ptr : this->host_ptrs) {
CHECK(cudaFreeHost(ptr));
}
}
void YOLOv8_pose::make_pipe(bool warmup) {
void YOLOv8_pose::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings) {
void* d_ptr;
CHECK(cudaMallocAsync(
&d_ptr,
bindings.size * bindings.dsize,
this->stream)
);
CHECK(cudaMallocAsync(&d_ptr, bindings.size * bindings.dsize, this->stream));
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings) {
void * d_ptr, *h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMallocAsync(
&d_ptr,
size,
this->stream)
);
CHECK(cudaHostAlloc(
&h_ptr,
size,
0)
);
CHECK(cudaMallocAsync(&d_ptr, size, this->stream));
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
@ -169,27 +144,17 @@ void YOLOv8_pose::make_pipe(bool warmup) {
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
h_ptr,
size,
cudaMemcpyHostToDevice,
this->stream)
);
CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream));
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
void YOLOv8_pose::letterbox(
const cv::Mat &image,
cv::Mat &out,
cv::Size &size
) {
void YOLOv8_pose::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
{
const float inp_h = size.height;
const float inp_w = size.width;
float height = image.rows;
@ -201,12 +166,9 @@ void YOLOv8_pose::letterbox(
cv::Mat tmp;
if ((int)width != padw || (int)height != padh) {
cv::resize(
image,
tmp,
cv::Size(padw, padh)
);
} else {
cv::resize(image, tmp, cv::Size(padw, padh));
}
else {
tmp = image.clone();
}
@ -220,113 +182,55 @@ void YOLOv8_pose::letterbox(
int left = int(std::round(dw - 0.1f));
int right = int(std::round(dw + 0.1f));
cv::copyMakeBorder(
tmp,
tmp,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
{114, 114, 114}
);
cv::dnn::blobFromImage(tmp,
out,
1 / 255.f,
cv::Size(),
cv::Scalar(0, 0, 0),
true,
false,
CV_32F
);
cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, {114, 114, 114});
cv::dnn::blobFromImage(tmp, out, 1 / 255.f, cv::Size(), cv::Scalar(0, 0, 0), true, false, CV_32F);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
this->pparam.height = height;
this->pparam.width = width;;
this->pparam.width = width;
;
}
// Preprocess `image` and upload it to the first input buffer, taking the target
// size from the dimensions stored for input binding 0 (d[2] is used as height,
// d[3] as width).
// The work is done by the two-argument overload, so both entry points share one
// letterbox / setBindingDimensions / upload sequence.
void YOLOv8_pose::copy_from_Mat(const cv::Mat& image)
{
    const auto& in_binding = this->input_bindings[0];
    auto        width      = in_binding.dims.d[3];
    auto        height     = in_binding.dims.d[2];
    cv::Size    size{width, height};
    this->copy_from_Mat(image, size);
}
// Letterbox `image` to `size`, tell the execution context that input 0 has shape
// {1, 3, size.height, size.width}, and enqueue the host-to-device upload of the
// resulting float blob on this object's stream.
void YOLOv8_pose::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
    cv::Mat nchw;
    this->letterbox(image, nchw, size);

    this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});

    // NOTE(review): `nchw` is destroyed when this function returns while the copy is
    // only enqueued. This relies on cudaMemcpyAsync staging pageable host memory
    // before returning - confirm against the CUDA runtime documentation.
    const size_t num_bytes = nchw.total() * nchw.elemSize();
    CHECK(cudaMemcpyAsync(this->device_ptrs[0], nchw.ptr<float>(), num_bytes, cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_pose::infer() {
void YOLOv8_pose::infer()
{
this->context->enqueueV2(
this->device_ptrs.data(),
this->stream,
nullptr
);
this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr);
for (int i = 0; i < this->num_outputs; i++) {
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(this->host_ptrs[i],
this->device_ptrs[i + this->num_inputs],
osize,
cudaMemcpyDeviceToHost,
this->stream)
);
CHECK(cudaMemcpyAsync(
this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream));
}
cudaStreamSynchronize(this->stream);
}
void YOLOv8_pose::postprocess(
std::vector<Object> &objs,
float score_thres,
float iou_thres,
int topk
) {
void YOLOv8_pose::postprocess(std::vector<Object>& objs, float score_thres, float iou_thres, int topk)
{
objs.clear();
auto num_channels = this->output_bindings[0].dims.d[1];
auto num_anchors = this->output_bindings[0].dims.d[2];
@ -343,12 +247,7 @@ void YOLOv8_pose::postprocess(
std::vector<int> indices;
std::vector<std::vector<float>> kpss;
cv::Mat output = cv::Mat(
num_channels,
num_anchors,
CV_32F,
static_cast<float *>(this->host_ptrs[0])
);
cv::Mat output = cv::Mat(num_channels, num_anchors, CV_32F, static_cast<float*>(this->host_ptrs[0]));
output = output.t();
for (int i = 0; i < num_anchors; i++) {
auto row_ptr = output.row(i).ptr<float>();
@ -393,22 +292,9 @@ void YOLOv8_pose::postprocess(
}
#ifdef BATCHED_NMS
cv::dnn::NMSBoxesBatched(
bboxes,
scores,
labels,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxesBatched(bboxes, scores, labels, score_thres, iou_thres, indices);
#else
cv::dnn::NMSBoxes(
bboxes,
scores,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxes(bboxes, scores, score_thres, iou_thres, indices);
#endif
int cnt = 0;
@ -426,39 +312,23 @@ void YOLOv8_pose::postprocess(
}
}
void YOLOv8_pose::draw_objects(
const cv::Mat &image,
void YOLOv8_pose::draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::vector<unsigned int>>& SKELETON,
const std::vector<std::vector<unsigned int>>& KPS_COLORS,
const std::vector<std::vector<unsigned int>> &LIMB_COLORS
) {
const std::vector<std::vector<unsigned int>>& LIMB_COLORS)
{
res = image.clone();
const int num_point = 17;
for (auto& obj : objs) {
cv::rectangle(
res,
obj.rect,
{0, 0, 255},
2
);
cv::rectangle(res, obj.rect, {0, 0, 255}, 2);
char text[256];
sprintf(
text,
"person %.1f%%",
obj.prob * 100
);
sprintf(text, "person %.1f%%", obj.prob * 100);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(
text,
cv::FONT_HERSHEY_SIMPLEX,
0.4,
1,
&baseLine
);
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = (int)obj.rect.x;
int y = (int)obj.rect.y + 1;
@ -466,22 +336,9 @@ void YOLOv8_pose::draw_objects(
if (y > res.rows)
y = res.rows;
cv::rectangle(
res,
cv::Rect(x, y, label_size.width, label_size.height + baseLine),
{0, 0, 255},
-1
);
cv::putText(
res,
text,
cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX,
0.4,
{255, 255, 255},
1
);
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);
auto& kps = obj.kps;
for (int k = 0; k < num_point + 2; k++) {
@ -504,7 +361,6 @@ void YOLOv8_pose::draw_objects(
float pos1_s = kps[(ske[0] - 1) * 3 + 2];
float pos2_s = kps[(ske[1] - 1) * 3 + 2];
if (pos1_s > 0.5f && pos2_s > 0.5f) {
cv::Scalar limb_color = cv::Scalar(LIMB_COLORS[k][0], LIMB_COLORS[k][1], LIMB_COLORS[k][2]);
cv::line(res, {pos1_x, pos1_y}, {pos2_x, pos2_y}, limb_color, 2);

@ -2,12 +2,10 @@
// Created by ubuntu on 4/7/23.
//
#include "chrono"
#include "yolov8-pose.hpp"
#include "opencv2/opencv.hpp"
#include "yolov8-pose.hpp"
const std::vector<std::vector<unsigned int>> KPS_COLORS =
{{0, 255, 0},
const std::vector<std::vector<unsigned int>> KPS_COLORS = {{0, 255, 0},
{0, 255, 0},
{0, 255, 0},
{0, 255, 0},
@ -65,7 +63,8 @@ const std::vector<std::vector<unsigned int>> LIMB_COLORS = {{51, 153, 255},
{0, 255, 0},
{0, 255, 0}};
int main(int argc, char **argv) {
int main(int argc, char** argv)
{
// cuda:0
cudaSetDevice(0);
@ -82,26 +81,19 @@ int main(int argc, char **argv) {
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (
suffix == "jpg" ||
suffix == "jpeg" ||
suffix == "png"
) {
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
imagePathList.push_back(path);
} else if (
suffix == "mp4" ||
suffix == "avi" ||
suffix == "m4v" ||
suffix == "mpeg" ||
suffix == "mov" ||
suffix == "mkv"
) {
}
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
isVideo = true;
} else {
}
else {
printf("suffix %s is wrong !!!\n", suffix.c_str());
std::abort();
}
} else if (IsFolder(path)) {
}
else if (IsFolder(path)) {
cv::glob(path + "/*.jpg", imagePathList);
}
@ -130,15 +122,15 @@ int main(int argc, char **argv) {
auto end = std::chrono::system_clock::now();
yolov8_pose->postprocess(objs, score_thres, iou_thres, topk);
yolov8_pose->draw_objects(image, res, objs, SKELETON, KPS_COLORS, LIMB_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
if (cv::waitKey(10) == 'q') {
break;
}
}
} else {
}
else {
for (auto& path : imagePathList) {
objs.clear();
image = cv::imread(path);
@ -148,8 +140,7 @@ int main(int argc, char **argv) {
auto end = std::chrono::system_clock::now();
yolov8_pose->postprocess(objs, score_thres, iou_thres, topk);
yolov8_pose->draw_objects(image, res, objs, SKELETON, KPS_COLORS, LIMB_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
cv::waitKey(0);

@ -4,29 +4,25 @@
#ifndef SEGMENT_NORMAL_COMMON_HPP
#define SEGMENT_NORMAL_COMMON_HPP
#include "NvInfer.h"
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#include "NvInfer.h"
// Evaluate a CUDA runtime call exactly once. If the returned status is not
// cudaSuccess, print the source location, numeric code and error text, then
// terminate the process with exit(1).
// The argument is parenthesised so any expression can be passed safely, and the
// do { ... } while (0) wrapper makes the macro usable as a single statement.
#define CHECK(call)                                                                                                    \
    do {                                                                                                               \
        const cudaError_t error_code = (call);                                                                         \
        if (error_code != cudaSuccess) {                                                                               \
            printf("CUDA Error:\n");                                                                                   \
            printf(" File: %s\n", __FILE__);                                                                           \
            printf(" Line: %d\n", __LINE__);                                                                           \
            printf(" Error code: %d\n", error_code);                                                                   \
            printf(" Error text: %s\n", cudaGetErrorString(error_code));                                               \
            exit(1);                                                                                                   \
        }                                                                                                              \
    } while (0)
class Logger : public nvinfer1::ILogger
{
class Logger: public nvinfer1::ILogger {
public:
nvinfer1::ILogger::Severity reportableSeverity;
@ -37,12 +33,10 @@ public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity)
{
if (severity > reportableSeverity) {
return;
}
switch (severity)
{
switch (severity) {
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
@ -66,8 +60,7 @@ public:
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++)
{
for (int i = 0; i < dims.nbDims; i++) {
size *= dims.d[i];
}
return size;
@ -75,8 +68,7 @@ inline int get_size_by_dims(const nvinfer1::Dims& dims)
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType)
{
switch (dataType) {
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
@ -99,8 +91,7 @@ inline static float clamp(float val, float min, float max)
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK)
{
if (access(path.c_str(), 0) == F_OK) {
return true;
}
return false;
@ -108,8 +99,7 @@ inline bool IsPathExist(const std::string& path)
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
@ -119,39 +109,34 @@ inline bool IsFile(const std::string& path)
// Return true only when `path` names an existing directory.
// stat() already fails for a missing path, so no separate existence probe is
// needed before it; any stat() failure is reported as "not a folder".
inline bool IsFolder(const std::string& path)
{
    struct stat buffer;
    return stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode);
}
namespace seg
{
struct Binding
{
namespace seg {
struct Binding {
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object
{
struct Object {
cv::Rect_<float> rect;
int label = 0;
float prob = 0.0;
cv::Mat boxMask;
};
struct PreParam
{
struct PreParam {
float ratio = 1.0f;
float dw = 0.0f;
float dh = 0.0f;
float height = 0;
float width = 0;
};
}
} // namespace seg
#endif // SEGMENT_NORMAL_COMMON_HPP

@ -3,14 +3,13 @@
//
#ifndef SEGMENT_NORMAL_YOLOV8_SEG_HPP
#define SEGMENT_NORMAL_YOLOV8_SEG_HPP
#include <fstream>
#include "common.hpp"
#include "NvInferPlugin.h"
#include "common.hpp"
#include <fstream>
using namespace seg;
class YOLOv8_seg
{
class YOLOv8_seg {
public:
explicit YOLOv8_seg(const std::string& engine_file_path);
~YOLOv8_seg();
@ -18,29 +17,21 @@ public:
void make_pipe(bool warmup = true);
void copy_from_Mat(const cv::Mat& image);
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
);
void letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size);
void infer();
void postprocess(
std::vector<Object>& objs,
void postprocess(std::vector<Object>& objs,
float score_thres = 0.25f,
float iou_thres = 0.65f,
int topk = 100,
int seg_channels = 32,
int seg_h = 160,
int seg_w = 160
);
static void draw_objects(
const cv::Mat& image,
int seg_w = 160);
static void draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS,
const std::vector<std::vector<unsigned int>>& MASK_COLORS
);
const std::vector<std::vector<unsigned int>>& MASK_COLORS);
int num_bindings;
int num_inputs = 0;
int num_outputs = 0;
@ -50,13 +41,13 @@ public:
std::vector<void*> device_ptrs;
PreParam pparam;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{nvinfer1::ILogger::Severity::kERROR};
};
YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
@ -83,8 +74,7 @@ YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
cudaStreamCreate(&this->stream);
this->num_bindings = this->engine->getNbBindings();
for (int i = 0; i < this->num_bindings; ++i)
{
for (int i = 0; i < this->num_bindings; ++i) {
Binding binding;
nvinfer1::Dims dims;
nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
@ -93,22 +83,16 @@ YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
binding.dsize = type_to_size(dtype);
bool IsInput = engine->bindingIsInput(i);
if (IsInput)
{
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(
i,
0,
nvinfer1::OptProfileSelector::kMAX);
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
}
else
{
else {
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
@ -116,7 +100,6 @@ YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
this->num_outputs += 1;
}
}
}
YOLOv8_seg::~YOLOv8_seg()
@ -125,13 +108,11 @@ YOLOv8_seg::~YOLOv8_seg()
this->engine->destroy();
this->runtime->destroy();
cudaStreamDestroy(this->stream);
for (auto& ptr : this->device_ptrs)
{
for (auto& ptr : this->device_ptrs) {
CHECK(cudaFree(ptr));
}
for (auto& ptr : this->host_ptrs)
{
for (auto& ptr : this->host_ptrs) {
CHECK(cudaFreeHost(ptr));
}
}
@ -139,65 +120,37 @@ YOLOv8_seg::~YOLOv8_seg()
void YOLOv8_seg::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings)
{
for (auto& bindings : this->input_bindings) {
void* d_ptr;
CHECK(cudaMallocAsync(
&d_ptr,
bindings.size * bindings.dsize,
this->stream)
);
CHECK(cudaMallocAsync(&d_ptr, bindings.size * bindings.dsize, this->stream));
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings)
{
for (auto& bindings : this->output_bindings) {
void * d_ptr, *h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMallocAsync(
&d_ptr,
size,
this->stream)
);
CHECK(cudaHostAlloc(
&h_ptr,
size,
0)
);
CHECK(cudaMallocAsync(&d_ptr, size, this->stream));
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
if (warmup)
{
for (int i = 0; i < 10; i++)
{
for (auto& bindings : this->input_bindings)
{
if (warmup) {
for (int i = 0; i < 10; i++) {
for (auto& bindings : this->input_bindings) {
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
h_ptr,
size,
cudaMemcpyHostToDevice,
this->stream)
);
CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream));
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
void YOLOv8_seg::letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
)
void YOLOv8_seg::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
{
const float inp_h = size.height;
const float inp_w = size.width;
@ -209,16 +162,10 @@ void YOLOv8_seg::letterbox(
int padh = std::round(height * r);
cv::Mat tmp;
if ((int)width != padw || (int)height != padh)
{
cv::resize(
image,
tmp,
cv::Size(padw, padh)
);
if ((int)width != padw || (int)height != padh) {
cv::resize(image, tmp, cv::Size(padw, padh));
}
else
{
else {
tmp = image.clone();
}
@ -232,31 +179,15 @@ void YOLOv8_seg::letterbox(
int left = int(std::round(dw - 0.1f));
int right = int(std::round(dw + 0.1f));
cv::copyMakeBorder(
tmp,
tmp,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
{ 114, 114, 114 }
);
cv::dnn::blobFromImage(tmp,
out,
1 / 255.f,
cv::Size(),
cv::Scalar(0, 0, 0),
true,
false,
CV_32F
);
cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, {114, 114, 114});
cv::dnn::blobFromImage(tmp, out, 1 / 255.f, cv::Size(), cv::Scalar(0, 0, 0), true, false, CV_32F);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
this->pparam.height = height;
this->pparam.width = width;;
this->pparam.width = width;
;
}
void YOLOv8_seg::copy_from_Mat(const cv::Mat& image)
@ -266,85 +197,37 @@ void YOLOv8_seg::copy_from_Mat(const cv::Mat& image)
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
cv::Size size{width, height};
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{
4,
{ 1, 3, height, width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
// Letterbox `image` to `size`, tell the execution context that input 0 has shape
// {1, 3, size.height, size.width}, and enqueue the host-to-device upload of the
// resulting float blob on this object's stream.
void YOLOv8_seg::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
    cv::Mat nchw;
    this->letterbox(image, nchw, size);

    this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});

    // NOTE(review): `nchw` is destroyed when this function returns while the copy is
    // only enqueued. This relies on cudaMemcpyAsync staging pageable host memory
    // before returning - confirm against the CUDA runtime documentation.
    const size_t num_bytes = nchw.total() * nchw.elemSize();
    CHECK(cudaMemcpyAsync(this->device_ptrs[0], nchw.ptr<float>(), num_bytes, cudaMemcpyHostToDevice, this->stream));
}
// Run the engine on the buffers in device_ptrs, copy every output back into its
// host buffer, and block until the stream has finished.
void YOLOv8_seg::infer()
{
    // enqueueV2 reports failure through its return value; copying outputs after a
    // failed launch would only hand stale data to postprocess().
    if (!this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr)) {
        printf("%s:%d enqueueV2 failed\n", __FILE__, __LINE__);
        return;
    }

    for (int i = 0; i < this->num_outputs; i++) {
        size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
        // Output device buffers are stored after the input buffers in device_ptrs.
        CHECK(cudaMemcpyAsync(
            this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream));
    }
    // Checked like every other CUDA call: an error here means the host buffers are not valid.
    CHECK(cudaStreamSynchronize(this->stream));
}
void YOLOv8_seg::postprocess(std::vector<Object>& objs,
float score_thres,
float iou_thres,
int topk,
int seg_channels,
int seg_h,
int seg_w
)
void YOLOv8_seg::postprocess(
std::vector<Object>& objs, float score_thres, float iou_thres, int topk, int seg_channels, int seg_h, int seg_w)
{
objs.clear();
auto input_h = this->input_bindings[0].dims.d[2];
@ -353,11 +236,9 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
bool flag = false;
int bid;
int bcnt = -1;
for (auto& o : this->output_bindings)
{
for (auto& o : this->output_bindings) {
bcnt += 1;
if (o.dims.nbDims == 3)
{
if (o.dims.nbDims == 3) {
num_channels = o.dims.d[1];
num_anchors = o.dims.d[2];
flag = true;
@ -373,12 +254,10 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
auto& height = this->pparam.height;
auto& ratio = this->pparam.ratio;
cv::Mat output = cv::Mat(num_channels, num_anchors, CV_32F,
static_cast<float*>(this->host_ptrs[bid]));
cv::Mat output = cv::Mat(num_channels, num_anchors, CV_32F, static_cast<float*>(this->host_ptrs[bid]));
output = output.t();
cv::Mat protos = cv::Mat(seg_channels, seg_h * seg_w, CV_32F,
static_cast<float*>(this->host_ptrs[1 - bid]));
cv::Mat protos = cv::Mat(seg_channels, seg_h * seg_w, CV_32F, static_cast<float*>(this->host_ptrs[1 - bid]));
std::vector<int> labels;
std::vector<float> scores;
@ -386,16 +265,14 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
std::vector<cv::Mat> mask_confs;
std::vector<int> indices;
for (int i = 0; i < num_anchors; i++)
{
for (int i = 0; i < num_anchors; i++) {
auto row_ptr = output.row(i).ptr<float>();
auto bboxes_ptr = row_ptr;
auto scores_ptr = row_ptr + 4;
auto mask_confs_ptr = row_ptr + 4 + num_classes;
auto max_s_ptr = std::max_element(scores_ptr, scores_ptr + num_classes);
float score = *max_s_ptr;
if (score > score_thres)
{
if (score > score_thres) {
float x = *bboxes_ptr++ - dw;
float y = *bboxes_ptr++ - dh;
float w = *bboxes_ptr++;
@ -423,30 +300,15 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
}
#if defined(BATCHED_NMS)
cv::dnn::NMSBoxesBatched(
bboxes,
scores,
labels,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxesBatched(bboxes, scores, labels, score_thres, iou_thres, indices);
#else
cv::dnn::NMSBoxes(
bboxes,
scores,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxes(bboxes, scores, score_thres, iou_thres, indices);
#endif
cv::Mat masks;
int cnt = 0;
for (auto& i : indices)
{
if (cnt >= topk)
{
for (auto& i : indices) {
if (cnt >= topk) {
break;
}
cv::Rect tmp = bboxes[i];
@ -459,12 +321,10 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
cnt += 1;
}
if(masks.empty())
{
if (masks.empty()) {
// masks is empty
}
else
{
else {
cv::Mat matmulRes = (masks * protos).t();
cv::Mat maskMat = matmulRes.reshape(indices.size(), {seg_w, seg_h});
@ -473,28 +333,17 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
int scale_dw = dw / input_w * seg_w;
int scale_dh = dh / input_h * seg_h;
cv::Rect roi(
scale_dw,
scale_dh,
seg_w - 2 * scale_dw,
seg_h - 2 * scale_dh);
cv::Rect roi(scale_dw, scale_dh, seg_w - 2 * scale_dw, seg_h - 2 * scale_dh);
for (int i = 0; i < indices.size(); i++)
{
for (int i = 0; i < indices.size(); i++) {
cv::Mat dest, mask;
cv::exp(-maskChannels[i], dest);
dest = 1.0 / (1.0 + dest);
dest = dest(roi);
cv::resize(
dest,
mask,
cv::Size((int)width, (int)height),
cv::INTER_LINEAR
);
cv::resize(dest, mask, cv::Size((int)width, (int)height), cv::INTER_LINEAR);
objs[i].boxMask = mask(objs[i].rect) > 0.5f;
}
}
}
void YOLOv8_seg::draw_objects(const cv::Mat& image,
@ -502,48 +351,23 @@ void YOLOv8_seg::draw_objects(const cv::Mat& image,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS,
const std::vector<std::vector<unsigned int>>& MASK_COLORS
)
const std::vector<std::vector<unsigned int>>& MASK_COLORS)
{
res = image.clone();
cv::Mat mask = image.clone();
for (auto& obj : objs)
{
for (auto& obj : objs) {
int idx = obj.label;
cv::Scalar color = cv::Scalar(
COLORS[idx][0],
COLORS[idx][1],
COLORS[idx][2]
);
cv::Scalar mask_color = cv::Scalar(
MASK_COLORS[idx % 20][0],
MASK_COLORS[idx % 20][1],
MASK_COLORS[idx % 20][2]
);
cv::rectangle(
res,
obj.rect,
color,
2
);
cv::Scalar color = cv::Scalar(COLORS[idx][0], COLORS[idx][1], COLORS[idx][2]);
cv::Scalar mask_color =
cv::Scalar(MASK_COLORS[idx % 20][0], MASK_COLORS[idx % 20][1], MASK_COLORS[idx % 20][2]);
cv::rectangle(res, obj.rect, color, 2);
char text[256];
sprintf(
text,
"%s %.1f%%",
CLASS_NAMES[idx].c_str(),
obj.prob * 100
);
sprintf(text, "%s %.1f%%", CLASS_NAMES[idx].c_str(), obj.prob * 100);
mask(obj.rect).setTo(mask_color, obj.boxMask);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(
text,
cv::FONT_HERSHEY_SIMPLEX,
0.4,
1,
&baseLine
);
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = (int)obj.rect.x;
int y = (int)obj.rect.y + 1;
@ -551,30 +375,10 @@ void YOLOv8_seg::draw_objects(const cv::Mat& image,
if (y > res.rows)
y = res.rows;
cv::rectangle(
res,
cv::Rect(x, y, label_size.width, label_size.height + baseLine),
{ 0, 0, 255 },
-1
);
cv::putText(
res,
text,
cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX,
0.4,
{ 255, 255, 255 },
1
);
}
cv::addWeighted(
res,
0.5,
mask,
0.8,
1,
res
);
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);
}
cv::addWeighted(res, 0.5, mask, 0.8, 1, res);
}
#endif // SEGMENT_NORMAL_YOLOV8_SEG_HPP

@ -2,66 +2,43 @@
// Created by ubuntu on 2/8/23.
//
#include "chrono"
#include "yolov8-seg.hpp"
#include "opencv2/opencv.hpp"
#include "yolov8-seg.hpp"
const std::vector<std::string> CLASS_NAMES = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus",
"train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella",
"handbag", "tie", "suitcase", "frisbee", "skis",
"snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass",
"cup", "fork", "knife", "spoon", "bowl",
"banana", "apple", "sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote",
"keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase",
"scissors", "teddy bear", "hair drier", "toothbrush" };
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
"truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
"baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
"fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
"orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv",
"laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
"teddy bear", "hair drier", "toothbrush"};
const std::vector<std::vector<unsigned int>> COLORS = {
{ 0, 114, 189 }, { 217, 83, 25 }, { 237, 177, 32 },
{ 126, 47, 142 }, { 119, 172, 48 }, { 77, 190, 238 },
{ 162, 20, 47 }, { 76, 76, 76 }, { 153, 153, 153 },
{ 255, 0, 0 }, { 255, 128, 0 }, { 191, 191, 0 },
{ 0, 255, 0 }, { 0, 0, 255 }, { 170, 0, 255 },
{ 85, 85, 0 }, { 85, 170, 0 }, { 85, 255, 0 },
{ 170, 85, 0 }, { 170, 170, 0 }, { 170, 255, 0 },
{ 255, 85, 0 }, { 255, 170, 0 }, { 255, 255, 0 },
{ 0, 85, 128 }, { 0, 170, 128 }, { 0, 255, 128 },
{ 85, 0, 128 }, { 85, 85, 128 }, { 85, 170, 128 },
{ 85, 255, 128 }, { 170, 0, 128 }, { 170, 85, 128 },
{ 170, 170, 128 }, { 170, 255, 128 }, { 255, 0, 128 },
{ 255, 85, 128 }, { 255, 170, 128 }, { 255, 255, 128 },
{ 0, 85, 255 }, { 0, 170, 255 }, { 0, 255, 255 },
{ 85, 0, 255 }, { 85, 85, 255 }, { 85, 170, 255 },
{ 85, 255, 255 }, { 170, 0, 255 }, { 170, 85, 255 },
{ 170, 170, 255 }, { 170, 255, 255 }, { 255, 0, 255 },
{ 255, 85, 255 }, { 255, 170, 255 }, { 85, 0, 0 },
{ 128, 0, 0 }, { 170, 0, 0 }, { 212, 0, 0 },
{ 255, 0, 0 }, { 0, 43, 0 }, { 0, 85, 0 },
{ 0, 128, 0 }, { 0, 170, 0 }, { 0, 212, 0 },
{ 0, 255, 0 }, { 0, 0, 43 }, { 0, 0, 85 },
{ 0, 0, 128 }, { 0, 0, 170 }, { 0, 0, 212 },
{ 0, 0, 255 }, { 0, 0, 0 }, { 36, 36, 36 },
{ 73, 73, 73 }, { 109, 109, 109 }, { 146, 146, 146 },
{ 182, 182, 182 }, { 219, 219, 219 }, { 0, 114, 189 },
{ 80, 183, 189 }, { 128, 128, 0 }
};
{0, 114, 189}, {217, 83, 25}, {237, 177, 32}, {126, 47, 142}, {119, 172, 48}, {77, 190, 238},
{162, 20, 47}, {76, 76, 76}, {153, 153, 153}, {255, 0, 0}, {255, 128, 0}, {191, 191, 0},
{0, 255, 0}, {0, 0, 255}, {170, 0, 255}, {85, 85, 0}, {85, 170, 0}, {85, 255, 0},
{170, 85, 0}, {170, 170, 0}, {170, 255, 0}, {255, 85, 0}, {255, 170, 0}, {255, 255, 0},
{0, 85, 128}, {0, 170, 128}, {0, 255, 128}, {85, 0, 128}, {85, 85, 128}, {85, 170, 128},
{85, 255, 128}, {170, 0, 128}, {170, 85, 128}, {170, 170, 128}, {170, 255, 128}, {255, 0, 128},
{255, 85, 128}, {255, 170, 128}, {255, 255, 128}, {0, 85, 255}, {0, 170, 255}, {0, 255, 255},
{85, 0, 255}, {85, 85, 255}, {85, 170, 255}, {85, 255, 255}, {170, 0, 255}, {170, 85, 255},
{170, 170, 255}, {170, 255, 255}, {255, 0, 255}, {255, 85, 255}, {255, 170, 255}, {85, 0, 0},
{128, 0, 0}, {170, 0, 0}, {212, 0, 0}, {255, 0, 0}, {0, 43, 0}, {0, 85, 0},
{0, 128, 0}, {0, 170, 0}, {0, 212, 0}, {0, 255, 0}, {0, 0, 43}, {0, 0, 85},
{0, 0, 128}, {0, 0, 170}, {0, 0, 212}, {0, 0, 255}, {0, 0, 0}, {36, 36, 36},
{73, 73, 73}, {109, 109, 109}, {146, 146, 146}, {182, 182, 182}, {219, 219, 219}, {0, 114, 189},
{80, 183, 189}, {128, 128, 0}};
const std::vector<std::vector<unsigned int>> MASK_COLORS = {
{ 255, 56, 56 }, { 255, 157, 151 }, { 255, 112, 31 },
{ 255, 178, 29 }, { 207, 210, 49 }, { 72, 249, 10 },
{ 146, 204, 23 }, { 61, 219, 134 }, { 26, 147, 52 },
{ 0, 212, 187 }, { 44, 153, 168 }, { 0, 194, 255 },
{ 52, 69, 147 }, { 100, 115, 255 }, { 0, 24, 236 },
{ 132, 56, 255 }, { 82, 0, 133 }, { 203, 56, 255 },
{ 255, 149, 200 }, { 255, 55, 199 }
};
{255, 56, 56}, {255, 157, 151}, {255, 112, 31}, {255, 178, 29}, {207, 210, 49}, {72, 249, 10}, {146, 204, 23},
{61, 219, 134}, {26, 147, 52}, {0, 212, 187}, {44, 153, 168}, {0, 194, 255}, {52, 69, 147}, {100, 115, 255},
{0, 24, 236}, {132, 56, 255}, {82, 0, 133}, {203, 56, 255}, {255, 149, 200}, {255, 55, 199}};
int main(int argc, char** argv)
{
@ -79,36 +56,21 @@ int main(int argc, char** argv)
auto yolov8 = new YOLOv8_seg(engine_file_path);
yolov8->make_pipe(true);
if (IsFile(path))
{
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (
suffix == "jpg" ||
suffix == "jpeg" ||
suffix == "png"
)
{
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
imagePathList.push_back(path);
}
else if (
suffix == "mp4" ||
suffix == "avi" ||
suffix == "m4v" ||
suffix == "mpeg" ||
suffix == "mov" ||
suffix == "mkv"
)
{
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
isVideo = true;
}
else
{
else {
printf("suffix %s is wrong !!!\n", suffix.c_str());
std::abort();
}
}
else if (IsFolder(path))
{
else if (IsFolder(path)) {
cv::glob(path + "/*.jpg", imagePathList);
}
@ -125,17 +87,14 @@ int main(int argc, char** argv)
cv::namedWindow("result", cv::WINDOW_AUTOSIZE);
if (isVideo)
{
if (isVideo) {
cv::VideoCapture cap(path);
if (!cap.isOpened())
{
if (!cap.isOpened()) {
printf("can not open %s\n", path.c_str());
return -1;
}
while (cap.read(image))
{
while (cap.read(image)) {
objs.clear();
yolov8->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
@ -143,20 +102,16 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
if (cv::waitKey(10) == 'q')
{
if (cv::waitKey(10) == 'q') {
break;
}
}
}
else
{
for (auto& path : imagePathList)
{
else {
for (auto& path : imagePathList) {
objs.clear();
image = cv::imread(path);
yolov8->copy_from_Mat(image, size);
@ -165,8 +120,7 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
cv::waitKey(0);

@ -4,29 +4,25 @@
#ifndef SEGMENT_SIMPLE_COMMON_HPP
#define SEGMENT_SIMPLE_COMMON_HPP
#include "NvInfer.h"
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#include "NvInfer.h"
#define CHECK(call) \
do \
{ \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
if (error_code != cudaSuccess) { \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", \
cudaGetErrorString(error_code)); \
printf(" Error text: %s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
class Logger : public nvinfer1::ILogger
{
class Logger: public nvinfer1::ILogger {
public:
nvinfer1::ILogger::Severity reportableSeverity;
@ -37,12 +33,10 @@ public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity)
{
if (severity > reportableSeverity) {
return;
}
switch (severity)
{
switch (severity) {
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
@ -66,8 +60,7 @@ public:
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++)
{
for (int i = 0; i < dims.nbDims; i++) {
size *= dims.d[i];
}
return size;
@ -75,8 +68,7 @@ inline int get_size_by_dims(const nvinfer1::Dims& dims)
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType)
{
switch (dataType) {
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
@ -99,8 +91,7 @@ inline static float clamp(float val, float min, float max)
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK)
{
if (access(path.c_str(), 0) == F_OK) {
return true;
}
return false;
@ -108,8 +99,7 @@ inline bool IsPathExist(const std::string& path)
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
@ -119,39 +109,34 @@ inline bool IsFile(const std::string& path)
inline bool IsFolder(const std::string& path)
{
if (!IsPathExist(path))
{
if (!IsPathExist(path)) {
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
namespace seg
{
struct Binding
{
namespace seg {
struct Binding {
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object
{
struct Object {
cv::Rect_<float> rect;
int label = 0;
float prob = 0.0;
cv::Mat boxMask;
};
struct PreParam
{
struct PreParam {
float ratio = 1.0f;
float dw = 0.0f;
float dh = 0.0f;
float height = 0;
float width = 0;
};
}
} // namespace seg
#endif // SEGMENT_SIMPLE_COMMON_HPP

@ -3,14 +3,13 @@
//
#ifndef SEGMENT_SIMPLE_YOLOV8_SEG_HPP
#define SEGMENT_SIMPLE_YOLOV8_SEG_HPP
#include <fstream>
#include "common.hpp"
#include "NvInferPlugin.h"
#include "common.hpp"
#include <fstream>
using namespace seg;
class YOLOv8_seg
{
class YOLOv8_seg {
public:
explicit YOLOv8_seg(const std::string& engine_file_path);
~YOLOv8_seg();
@ -18,29 +17,21 @@ public:
void make_pipe(bool warmup = true);
void copy_from_Mat(const cv::Mat& image);
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
);
void letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size);
void infer();
void postprocess(
std::vector<Object>& objs,
void postprocess(std::vector<Object>& objs,
float score_thres = 0.25f,
float iou_thres = 0.65f,
int topk = 100,
int seg_channels = 32,
int seg_h = 160,
int seg_w = 160
);
static void draw_objects(
const cv::Mat& image,
int seg_w = 160);
static void draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS,
const std::vector<std::vector<unsigned int>>& MASK_COLORS
);
const std::vector<std::vector<unsigned int>>& MASK_COLORS);
int num_bindings;
int num_inputs = 0;
int num_outputs = 0;
@ -50,13 +41,13 @@ public:
std::vector<void*> device_ptrs;
PreParam pparam;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{nvinfer1::ILogger::Severity::kERROR};
};
YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
@ -83,8 +74,7 @@ YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
cudaStreamCreate(&this->stream);
this->num_bindings = this->engine->getNbBindings();
for (int i = 0; i < this->num_bindings; ++i)
{
for (int i = 0; i < this->num_bindings; ++i) {
Binding binding;
nvinfer1::Dims dims;
nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
@ -93,22 +83,16 @@ YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
binding.dsize = type_to_size(dtype);
bool IsInput = engine->bindingIsInput(i);
if (IsInput)
{
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(
i,
0,
nvinfer1::OptProfileSelector::kMAX);
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
}
else
{
else {
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
@ -116,7 +100,6 @@ YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
this->num_outputs += 1;
}
}
}
YOLOv8_seg::~YOLOv8_seg()
@ -125,13 +108,11 @@ YOLOv8_seg::~YOLOv8_seg()
this->engine->destroy();
this->runtime->destroy();
cudaStreamDestroy(this->stream);
for (auto& ptr : this->device_ptrs)
{
for (auto& ptr : this->device_ptrs) {
CHECK(cudaFree(ptr));
}
for (auto& ptr : this->host_ptrs)
{
for (auto& ptr : this->host_ptrs) {
CHECK(cudaFreeHost(ptr));
}
}
@ -139,65 +120,37 @@ YOLOv8_seg::~YOLOv8_seg()
void YOLOv8_seg::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings)
{
for (auto& bindings : this->input_bindings) {
void* d_ptr;
CHECK(cudaMallocAsync(
&d_ptr,
bindings.size * bindings.dsize,
this->stream)
);
CHECK(cudaMallocAsync(&d_ptr, bindings.size * bindings.dsize, this->stream));
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings)
{
for (auto& bindings : this->output_bindings) {
void * d_ptr, *h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMallocAsync(
&d_ptr,
size,
this->stream)
);
CHECK(cudaHostAlloc(
&h_ptr,
size,
0)
);
CHECK(cudaMallocAsync(&d_ptr, size, this->stream));
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
if (warmup)
{
for (int i = 0; i < 10; i++)
{
for (auto& bindings : this->input_bindings)
{
if (warmup) {
for (int i = 0; i < 10; i++) {
for (auto& bindings : this->input_bindings) {
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
h_ptr,
size,
cudaMemcpyHostToDevice,
this->stream)
);
CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream));
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
void YOLOv8_seg::letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
)
void YOLOv8_seg::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size)
{
const float inp_h = size.height;
const float inp_w = size.width;
@ -209,16 +162,10 @@ void YOLOv8_seg::letterbox(
int padh = std::round(height * r);
cv::Mat tmp;
if ((int)width != padw || (int)height != padh)
{
cv::resize(
image,
tmp,
cv::Size(padw, padh)
);
if ((int)width != padw || (int)height != padh) {
cv::resize(image, tmp, cv::Size(padw, padh));
}
else
{
else {
tmp = image.clone();
}
@ -232,31 +179,15 @@ void YOLOv8_seg::letterbox(
int left = int(std::round(dw - 0.1f));
int right = int(std::round(dw + 0.1f));
cv::copyMakeBorder(
tmp,
tmp,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
{ 114, 114, 114 }
);
cv::dnn::blobFromImage(tmp,
out,
1 / 255.f,
cv::Size(),
cv::Scalar(0, 0, 0),
true,
false,
CV_32F
);
cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, {114, 114, 114});
cv::dnn::blobFromImage(tmp, out, 1 / 255.f, cv::Size(), cv::Scalar(0, 0, 0), true, false, CV_32F);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
this->pparam.height = height;
this->pparam.width = width;;
this->pparam.width = width;
;
}
void YOLOv8_seg::copy_from_Mat(const cv::Mat& image)
@ -266,85 +197,37 @@ void YOLOv8_seg::copy_from_Mat(const cv::Mat& image)
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
cv::Size size{width, height};
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{
4,
{ 1, 3, height, width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_seg::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
cv::Mat nchw;
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{ 4,
{ 1, 3, size.height, size.width }
}
);
this->letterbox(image, nchw, size);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_seg::infer()
{
this->context->enqueueV2(
this->device_ptrs.data(),
this->stream,
nullptr
);
for (int i = 0; i < this->num_outputs; i++)
{
this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr);
for (int i = 0; i < this->num_outputs; i++) {
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(this->host_ptrs[i],
this->device_ptrs[i + this->num_inputs],
osize,
cudaMemcpyDeviceToHost,
this->stream)
);
CHECK(cudaMemcpyAsync(
this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream));
}
cudaStreamSynchronize(this->stream);
}
void YOLOv8_seg::postprocess(std::vector<Object>& objs,
float score_thres,
float iou_thres,
int topk,
int seg_channels,
int seg_h,
int seg_w
)
void YOLOv8_seg::postprocess(
std::vector<Object>& objs, float score_thres, float iou_thres, int topk, int seg_channels, int seg_h, int seg_w)
{
objs.clear();
auto input_h = this->input_bindings[0].dims.d[2];
@ -359,8 +242,7 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
auto& ratio = this->pparam.ratio;
auto* output = static_cast<float*>(this->host_ptrs[0]);
cv::Mat protos = cv::Mat(seg_channels, seg_h * seg_w, CV_32F,
static_cast<float*>(this->host_ptrs[1]));
cv::Mat protos = cv::Mat(seg_channels, seg_h * seg_w, CV_32F, static_cast<float*>(this->host_ptrs[1]));
std::vector<int> labels;
std::vector<float> scores;
@ -368,12 +250,10 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
std::vector<cv::Mat> mask_confs;
std::vector<int> indices;
for (int i = 0; i < num_anchors; i++)
{
for (int i = 0; i < num_anchors; i++) {
float* ptr = output + i * num_channels;
float score = *(ptr + 4);
if (score > score_thres)
{
if (score > score_thres) {
float x0 = *ptr++ - dw;
float y0 = *ptr++ - dh;
float x1 = *ptr++ - dw;
@ -390,35 +270,19 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
labels.push_back(label);
scores.push_back(score);
bboxes.push_back(cv::Rect_<float>(x0, y0, x1 - x0, y1 - y0));
}
}
#if defined(BATCHED_NMS)
cv::dnn::NMSBoxesBatched(
bboxes,
scores,
labels,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxesBatched(bboxes, scores, labels, score_thres, iou_thres, indices);
#else
cv::dnn::NMSBoxes(
bboxes,
scores,
score_thres,
iou_thres,
indices
);
cv::dnn::NMSBoxes(bboxes, scores, score_thres, iou_thres, indices);
#endif
cv::Mat masks;
int cnt = 0;
for (auto& i : indices)
{
if (cnt >= topk)
{
for (auto& i : indices) {
if (cnt >= topk) {
break;
}
cv::Rect tmp = bboxes[i];
@ -431,12 +295,10 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
cnt += 1;
}
if(masks.empty())
{
if (masks.empty()) {
// masks is empty
}
else
{
else {
cv::Mat matmulRes = (masks * protos).t();
cv::Mat maskMat = matmulRes.reshape(indices.size(), {seg_w, seg_h});
@ -445,24 +307,14 @@ void YOLOv8_seg::postprocess(std::vector<Object>& objs,
int scale_dw = dw / input_w * seg_w;
int scale_dh = dh / input_h * seg_h;
cv::Rect roi(
scale_dw,
scale_dh,
seg_w - 2 * scale_dw,
seg_h - 2 * scale_dh);
cv::Rect roi(scale_dw, scale_dh, seg_w - 2 * scale_dw, seg_h - 2 * scale_dh);
for (int i = 0; i < indices.size(); i++)
{
for (int i = 0; i < indices.size(); i++) {
cv::Mat dest, mask;
cv::exp(-maskChannels[i], dest);
dest = 1.0 / (1.0 + dest);
dest = dest(roi);
cv::resize(
dest,
mask,
cv::Size((int)width, (int)height),
cv::INTER_LINEAR
);
cv::resize(dest, mask, cv::Size((int)width, (int)height), cv::INTER_LINEAR);
objs[i].boxMask = mask(objs[i].rect) > 0.5f;
}
}
@ -473,48 +325,23 @@ void YOLOv8_seg::draw_objects(const cv::Mat& image,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS,
const std::vector<std::vector<unsigned int>>& MASK_COLORS
)
const std::vector<std::vector<unsigned int>>& MASK_COLORS)
{
res = image.clone();
cv::Mat mask = image.clone();
for (auto& obj : objs)
{
for (auto& obj : objs) {
int idx = obj.label;
cv::Scalar color = cv::Scalar(
COLORS[idx][0],
COLORS[idx][1],
COLORS[idx][2]
);
cv::Scalar mask_color = cv::Scalar(
MASK_COLORS[idx % 20][0],
MASK_COLORS[idx % 20][1],
MASK_COLORS[idx % 20][2]
);
cv::rectangle(
res,
obj.rect,
color,
2
);
cv::Scalar color = cv::Scalar(COLORS[idx][0], COLORS[idx][1], COLORS[idx][2]);
cv::Scalar mask_color =
cv::Scalar(MASK_COLORS[idx % 20][0], MASK_COLORS[idx % 20][1], MASK_COLORS[idx % 20][2]);
cv::rectangle(res, obj.rect, color, 2);
char text[256];
sprintf(
text,
"%s %.1f%%",
CLASS_NAMES[idx].c_str(),
obj.prob * 100
);
sprintf(text, "%s %.1f%%", CLASS_NAMES[idx].c_str(), obj.prob * 100);
mask(obj.rect).setTo(mask_color, obj.boxMask);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(
text,
cv::FONT_HERSHEY_SIMPLEX,
0.4,
1,
&baseLine
);
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = (int)obj.rect.x;
int y = (int)obj.rect.y + 1;
@ -522,30 +349,10 @@ void YOLOv8_seg::draw_objects(const cv::Mat& image,
if (y > res.rows)
y = res.rows;
cv::rectangle(
res,
cv::Rect(x, y, label_size.width, label_size.height + baseLine),
{ 0, 0, 255 },
-1
);
cv::putText(
res,
text,
cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX,
0.4,
{ 255, 255, 255 },
1
);
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);
}
cv::addWeighted(
res,
0.5,
mask,
0.8,
1,
res
);
cv::addWeighted(res, 0.5, mask, 0.8, 1, res);
}
#endif // SEGMENT_SIMPLE_YOLOV8_SEG_HPP

@ -2,66 +2,43 @@
// Created by ubuntu on 1/20/23.
//
#include "chrono"
#include "yolov8-seg.hpp"
#include "opencv2/opencv.hpp"
#include "yolov8-seg.hpp"
const std::vector<std::string> CLASS_NAMES = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus",
"train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella",
"handbag", "tie", "suitcase", "frisbee", "skis",
"snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass",
"cup", "fork", "knife", "spoon", "bowl",
"banana", "apple", "sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote",
"keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase",
"scissors", "teddy bear", "hair drier", "toothbrush" };
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
"truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
"baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
"fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
"orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv",
"laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
"teddy bear", "hair drier", "toothbrush"};
const std::vector<std::vector<unsigned int>> COLORS = {
{ 0, 114, 189 }, { 217, 83, 25 }, { 237, 177, 32 },
{ 126, 47, 142 }, { 119, 172, 48 }, { 77, 190, 238 },
{ 162, 20, 47 }, { 76, 76, 76 }, { 153, 153, 153 },
{ 255, 0, 0 }, { 255, 128, 0 }, { 191, 191, 0 },
{ 0, 255, 0 }, { 0, 0, 255 }, { 170, 0, 255 },
{ 85, 85, 0 }, { 85, 170, 0 }, { 85, 255, 0 },
{ 170, 85, 0 }, { 170, 170, 0 }, { 170, 255, 0 },
{ 255, 85, 0 }, { 255, 170, 0 }, { 255, 255, 0 },
{ 0, 85, 128 }, { 0, 170, 128 }, { 0, 255, 128 },
{ 85, 0, 128 }, { 85, 85, 128 }, { 85, 170, 128 },
{ 85, 255, 128 }, { 170, 0, 128 }, { 170, 85, 128 },
{ 170, 170, 128 }, { 170, 255, 128 }, { 255, 0, 128 },
{ 255, 85, 128 }, { 255, 170, 128 }, { 255, 255, 128 },
{ 0, 85, 255 }, { 0, 170, 255 }, { 0, 255, 255 },
{ 85, 0, 255 }, { 85, 85, 255 }, { 85, 170, 255 },
{ 85, 255, 255 }, { 170, 0, 255 }, { 170, 85, 255 },
{ 170, 170, 255 }, { 170, 255, 255 }, { 255, 0, 255 },
{ 255, 85, 255 }, { 255, 170, 255 }, { 85, 0, 0 },
{ 128, 0, 0 }, { 170, 0, 0 }, { 212, 0, 0 },
{ 255, 0, 0 }, { 0, 43, 0 }, { 0, 85, 0 },
{ 0, 128, 0 }, { 0, 170, 0 }, { 0, 212, 0 },
{ 0, 255, 0 }, { 0, 0, 43 }, { 0, 0, 85 },
{ 0, 0, 128 }, { 0, 0, 170 }, { 0, 0, 212 },
{ 0, 0, 255 }, { 0, 0, 0 }, { 36, 36, 36 },
{ 73, 73, 73 }, { 109, 109, 109 }, { 146, 146, 146 },
{ 182, 182, 182 }, { 219, 219, 219 }, { 0, 114, 189 },
{ 80, 183, 189 }, { 128, 128, 0 }
};
{0, 114, 189}, {217, 83, 25}, {237, 177, 32}, {126, 47, 142}, {119, 172, 48}, {77, 190, 238},
{162, 20, 47}, {76, 76, 76}, {153, 153, 153}, {255, 0, 0}, {255, 128, 0}, {191, 191, 0},
{0, 255, 0}, {0, 0, 255}, {170, 0, 255}, {85, 85, 0}, {85, 170, 0}, {85, 255, 0},
{170, 85, 0}, {170, 170, 0}, {170, 255, 0}, {255, 85, 0}, {255, 170, 0}, {255, 255, 0},
{0, 85, 128}, {0, 170, 128}, {0, 255, 128}, {85, 0, 128}, {85, 85, 128}, {85, 170, 128},
{85, 255, 128}, {170, 0, 128}, {170, 85, 128}, {170, 170, 128}, {170, 255, 128}, {255, 0, 128},
{255, 85, 128}, {255, 170, 128}, {255, 255, 128}, {0, 85, 255}, {0, 170, 255}, {0, 255, 255},
{85, 0, 255}, {85, 85, 255}, {85, 170, 255}, {85, 255, 255}, {170, 0, 255}, {170, 85, 255},
{170, 170, 255}, {170, 255, 255}, {255, 0, 255}, {255, 85, 255}, {255, 170, 255}, {85, 0, 0},
{128, 0, 0}, {170, 0, 0}, {212, 0, 0}, {255, 0, 0}, {0, 43, 0}, {0, 85, 0},
{0, 128, 0}, {0, 170, 0}, {0, 212, 0}, {0, 255, 0}, {0, 0, 43}, {0, 0, 85},
{0, 0, 128}, {0, 0, 170}, {0, 0, 212}, {0, 0, 255}, {0, 0, 0}, {36, 36, 36},
{73, 73, 73}, {109, 109, 109}, {146, 146, 146}, {182, 182, 182}, {219, 219, 219}, {0, 114, 189},
{80, 183, 189}, {128, 128, 0}};
const std::vector<std::vector<unsigned int>> MASK_COLORS = {
{ 255, 56, 56 }, { 255, 157, 151 }, { 255, 112, 31 },
{ 255, 178, 29 }, { 207, 210, 49 }, { 72, 249, 10 },
{ 146, 204, 23 }, { 61, 219, 134 }, { 26, 147, 52 },
{ 0, 212, 187 }, { 44, 153, 168 }, { 0, 194, 255 },
{ 52, 69, 147 }, { 100, 115, 255 }, { 0, 24, 236 },
{ 132, 56, 255 }, { 82, 0, 133 }, { 203, 56, 255 },
{ 255, 149, 200 }, { 255, 55, 199 }
};
{255, 56, 56}, {255, 157, 151}, {255, 112, 31}, {255, 178, 29}, {207, 210, 49}, {72, 249, 10}, {146, 204, 23},
{61, 219, 134}, {26, 147, 52}, {0, 212, 187}, {44, 153, 168}, {0, 194, 255}, {52, 69, 147}, {100, 115, 255},
{0, 24, 236}, {132, 56, 255}, {82, 0, 133}, {203, 56, 255}, {255, 149, 200}, {255, 55, 199}};
int main(int argc, char** argv)
{
@ -79,36 +56,21 @@ int main(int argc, char** argv)
auto yolov8 = new YOLOv8_seg(engine_file_path);
yolov8->make_pipe(true);
if (IsFile(path))
{
if (IsFile(path)) {
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (
suffix == "jpg" ||
suffix == "jpeg" ||
suffix == "png"
)
{
if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") {
imagePathList.push_back(path);
}
else if (
suffix == "mp4" ||
suffix == "avi" ||
suffix == "m4v" ||
suffix == "mpeg" ||
suffix == "mov" ||
suffix == "mkv"
)
{
else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov"
|| suffix == "mkv") {
isVideo = true;
}
else
{
else {
printf("suffix %s is wrong !!!\n", suffix.c_str());
std::abort();
}
}
else if (IsFolder(path))
{
else if (IsFolder(path)) {
cv::glob(path + "/*.jpg", imagePathList);
}
@ -125,17 +87,14 @@ int main(int argc, char** argv)
cv::namedWindow("result", cv::WINDOW_AUTOSIZE);
if (isVideo)
{
if (isVideo) {
cv::VideoCapture cap(path);
if (!cap.isOpened())
{
if (!cap.isOpened()) {
printf("can not open %s\n", path.c_str());
return -1;
}
while (cap.read(image))
{
while (cap.read(image)) {
objs.clear();
yolov8->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
@ -143,20 +102,16 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
if (cv::waitKey(10) == 'q')
{
if (cv::waitKey(10) == 'q') {
break;
}
}
}
else
{
for (auto& path : imagePathList)
{
else {
for (auto& path : imagePathList) {
objs.clear();
image = cv::imread(path);
yolov8->copy_from_Mat(image, size);
@ -165,8 +120,7 @@ int main(int argc, char** argv)
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
cv::waitKey(0);

@ -3,6 +3,7 @@
Only test on `Jetson-NX 4GB`
ENVS:
- Jetpack 4.6.3
- CUDA-10.2
- CUDNN-8.2.1
@ -20,7 +21,8 @@ If you have other environment-related issues, please discuss in issue.
`yolov8s.pt` is your trained pytorch model, or the official pre-trained model.
Do not use any model other than pytorch model.
Do not use [`build.py`](../build.py) to export engine if you don't know how to install pytorch and other environments on jetson.
Do not use [`build.py`](../build.py) to export engine if you don't know how to install pytorch and other environments on
jetson.
***!!! Please use the PC to execute the following script !!!***
@ -79,7 +81,8 @@ Usage:
`yolov8s-seg.pt` is your trained pytorch model, or the official pre-trained model.
Do not use any model other than pytorch model.
Do not use [`build.py`](../build.py) to export engine if you don't know how to install pytorch and other environments on jetson.
Do not use [`build.py`](../build.py) to export engine if you don't know how to install pytorch and other environments on
jetson.
***!!! Please use the PC to execute the following script !!!***
@ -106,7 +109,8 @@ Here is a demo: [`csrc/jetson/segment`](../csrc/jetson/segment) .
#### Build:
Please modify `CLASS_NAMES` and `COLORS` and postprocess parameters in [`main.cpp`](../csrc/jetson/segment/main.cpp) for yourself.
Please modify `CLASS_NAMES` and `COLORS` and postprocess parameters in [`main.cpp`](../csrc/jetson/segment/main.cpp) for
yourself.
```c++
int topk = 100;
@ -140,8 +144,6 @@ Usage:
./yolov8-seg yolov8s-seg.engine data/test.mp4 # the video path
```
## Normal Posture
### 1. Export Posture Normal ONNX
@ -149,7 +151,8 @@ Usage:
`yolov8s-pose.pt` is your trained pytorch model, or the official pre-trained model.
Do not use any model other than pytorch model.
Do not use [`build.py`](../build.py) to export engine if you don't know how to install pytorch and other environments on jetson.
Do not use [`build.py`](../build.py) to export engine if you don't know how to install pytorch and other environments on
jetson.
***!!! Please use the PC to execute the following script !!!***
@ -176,7 +179,8 @@ Here is a demo: [`csrc/jetson/pose`](../csrc/jetson/pose) .
#### Build:
Please modify `KPS_COLORS` and `SKELETON` and `LIMB_COLORS` and postprocess parameters in [`main.cpp`](../csrc/jetson/pose/main.cpp) for yourself.
Please modify `KPS_COLORS` and `SKELETON` and `LIMB_COLORS` and postprocess parameters
in [`main.cpp`](../csrc/jetson/pose/main.cpp) for yourself.
```c++
int topk = 100;

@ -2,23 +2,39 @@
## Export TensorRT Engine
### 1. Python script
### 1. ONNX -> TensorRT
Usage:
You can export your onnx model by `ultralytics` API.
```python
``` shell
yolo export model=yolov8s.pt format=onnx opset=11 simplify=True
```
or run this python script:
```python
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8s.pt") # load a pretrained model (recommended for training)
success = model.export(format="engine", device=0) # export the model to engine format
success = model.export(format="onnx", opset=11, simplify=True) # export the model to onnx format
assert success
```
After executing the above script, you will get an engine named `yolov8s.engine` .
Then build engine by Trtexec Tools.
You can export TensorRT engine by [`trtexec`](https://github.com/NVIDIA/TensorRT/tree/main/samples/trtexec) tools.
Usage:
``` shell
/usr/src/tensorrt/bin/trtexec \
--onnx=yolov8s.onnx \
--saveEngine=yolov8s.engine \
--fp16
```
### 2. CLI tools
### 2. Direct to TensorRT (NOT RECOMMENDED!!)
Usage:
@ -26,7 +42,19 @@ Usage:
yolo export model=yolov8s.pt format=engine device=0
```
After executing the above command, you will get an engine named `yolov8s.engine` too.
or run python script:
```python
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8s.pt") # load a pretrained model (recommended for training)
success = model.export(format="engine", device=0) # export the model to engine format
assert success
```
After executing the above script, you will get an engine named `yolov8s.engine` .
## Inference with c++
@ -34,9 +62,11 @@ You can infer with c++ in [`csrc/detect/normal`](../csrc/detect/normal) .
### Build:
Please set you own librarys in [`CMakeLists.txt`](../csrc/detect/normal/CMakeLists.txt) and modify `CLASS_NAMES` and `COLORS` in [`main.cpp`](../csrc/detect/normal/main.cpp).
Please set your own libraries in [`CMakeLists.txt`](../csrc/detect/normal/CMakeLists.txt) and modify `CLASS_NAMES`
and `COLORS` in [`main.cpp`](../csrc/detect/normal/main.cpp).
Besides, you can modify the postprocess parameters such as `num_labels` and `score_thres` and `iou_thres` and `topk` in [`main.cpp`](../csrc/detect/normal/main.cpp).
Besides, you can modify the postprocess parameters such as `num_labels` and `score_thres` and `iou_thres` and `topk`
in [`main.cpp`](../csrc/detect/normal/main.cpp).
```c++
int num_labels = 80;

@ -9,10 +9,48 @@ The yolov8-pose model conversion route is :
You can leave this repo and use the original `ultralytics` repo for onnx export.
### 1. Python script
### 1. ONNX -> TensorRT
You can export your onnx model by `ultralytics` API.
``` shell
yolo export model=yolov8s-pose.pt format=onnx opset=11 simplify=True
```
or run this python script:
```python
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8s-pose.pt") # load a pretrained model (recommended for training)
success = model.export(format="onnx", opset=11, simplify=True) # export the model to onnx format
assert success
```
Then build engine by Trtexec Tools.
You can export TensorRT engine by [`trtexec`](https://github.com/NVIDIA/TensorRT/tree/main/samples/trtexec) tools.
Usage:
``` shell
/usr/src/tensorrt/bin/trtexec \
--onnx=yolov8s-pose.onnx \
--saveEngine=yolov8s-pose.engine \
--fp16
```
### 2. Direct to TensorRT (NOT RECOMMENDED!!)
Usage:
```shell
yolo export model=yolov8s-pose.pt format=engine device=0
```
or run python script:
```python
from ultralytics import YOLO
@ -24,15 +62,31 @@ assert success
After executing the above script, you will get an engine named `yolov8s-pose.engine` .
### 2. CLI tools
# Inference
## Infer with python script
You can infer images with the engine by [`infer-pose.py`](../infer-pose.py) .
Usage:
``` shell
yolo export model=yolov8s-pose.pt format=engine device=0
python3 infer-pose.py \
--engine yolov8s-pose.engine \
--imgs data \
--show \
--out-dir outputs \
--device cuda:0
```
After executing the above command, you will get an engine named `yolov8s-pose.engine` too.
#### Description of all arguments
- `--engine` : The Engine you export.
- `--imgs` : The images path you want to detect.
- `--show` : Whether to show detection results.
- `--out-dir` : Where to save the detection result images. It has no effect when the `--show` flag is used.
- `--device` : The CUDA device you use.
- `--profile` : Profile the TensorRT engine.
## Inference with c++
@ -40,9 +94,11 @@ You can infer with c++ in [`csrc/pose/normal`](../csrc/pose/normal) .
### Build:
Please set you own librarys in [`CMakeLists.txt`](../csrc/pose/normal/CMakeLists.txt) and modify `KPS_COLORS` and `SKELETON` and `LIMB_COLORS` in [`main.cpp`](../csrc/pose/normal/main.cpp).
Please set your own libraries in [`CMakeLists.txt`](../csrc/pose/normal/CMakeLists.txt) and modify `KPS_COLORS`
and `SKELETON` and `LIMB_COLORS` in [`main.cpp`](../csrc/pose/normal/main.cpp).
Besides, you can modify the postprocess parameters such as `score_thres` and `iou_thres` and `topk` in [`main.cpp`](../csrc/pose/normal/main.cpp).
Besides, you can modify the postprocess parameters such as `score_thres` and `iou_thres` and `topk`
in [`main.cpp`](../csrc/pose/normal/main.cpp).
```c++
int topk = 100;

@ -96,7 +96,9 @@ You can infer segment engine with c++ in [`csrc/segment/simple`](../csrc/segment
### Build:
Please set you own librarys in [`CMakeLists.txt`](../csrc/segment/simple/CMakeLists.txt) and modify you own config in [`main.cpp`](../csrc/segment/simple/main.cpp) such as `CLASS_NAMES`, `COLORS`, `MASK_COLORS` and postprocess parameters .
Please set your own libraries in [`CMakeLists.txt`](../csrc/segment/simple/CMakeLists.txt) and modify your own config
in [`main.cpp`](../csrc/segment/simple/main.cpp) such as `CLASS_NAMES`, `COLORS`, `MASK_COLORS` and postprocess
parameters.
```c++
int topk = 100;
@ -119,7 +121,8 @@ cd ${root}
***Notice !!!***
If you have build OpenCV(>=4.7.0) by yourself, it provides a new api [`cv::dnn::NMSBoxesBatched`](https://docs.opencv.org/4.x/d6/d0f/group__dnn.html#ga977aae09fbf7c804e003cfea1d4e928c) .
If you have built OpenCV (>=4.7.0) yourself, it provides a new
api [`cv::dnn::NMSBoxesBatched`](https://docs.opencv.org/4.x/d6/d0f/group__dnn.html#ga977aae09fbf7c804e003cfea1d4e928c) .
It is a great API for efficient per-class NMS. It will be used by default!
***!!!***
@ -139,22 +142,39 @@ Usage:
You can leave this repo and use the original `ultralytics` repo for onnx export.
### 1. Python script
### 1. ONNX -> TensorRT
Usage:
You can export your onnx model by `ultralytics` API.
``` shell
yolo export model=yolov8s-seg.pt format=onnx opset=11 simplify=True
```
or run this python script:
```python
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8s-seg.pt") # load a pretrained model (recommended for training)
success = model.export(format="engine", device=0) # export the model to engine format
success = model.export(format="onnx", opset=11, simplify=True) # export the model to onnx format
assert success
```
After executing the above script, you will get an engine named `yolov8s-seg.engine` .
Then build engine by Trtexec Tools.
### 2. CLI tools
You can export TensorRT engine by [`trtexec`](https://github.com/NVIDIA/TensorRT/tree/main/samples/trtexec) tools.
Usage:
``` shell
/usr/src/tensorrt/bin/trtexec \
--onnx=yolov8s-seg.onnx \
--saveEngine=yolov8s-seg.engine \
--fp16
```
### 2. Direct to TensorRT (NOT RECOMMENDED!!)
Usage:
@ -162,7 +182,18 @@ Usage:
yolo export model=yolov8s-seg.pt format=engine device=0
```
After executing the above command, you will get an engine named `yolov8s-seg.engine` too.
or run python script:
```python
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8s-seg.pt") # load a pretrained model (recommended for training)
success = model.export(format="engine", device=0) # export the model to engine format
assert success
```
After executing the above script, you will get an engine named `yolov8s-seg.engine` .
## Inference with c++
@ -170,9 +201,11 @@ You can infer with c++ in [`csrc/segment/normal`](../csrc/segment/normal) .
### Build:
Please set you own librarys in [`CMakeLists.txt`](../csrc/segment/normal/CMakeLists.txt) and modify `CLASS_NAMES` and `COLORS` in [`main.cpp`](../csrc/segment/normal/main.cpp).
Please set your own libraries in [`CMakeLists.txt`](../csrc/segment/normal/CMakeLists.txt) and modify `CLASS_NAMES`
and `COLORS` in [`main.cpp`](../csrc/segment/normal/main.cpp).
Besides, you can modify the postprocess parameters such as `num_labels` and `score_thres` and `iou_thres` and `topk` in [`main.cpp`](../csrc/segment/normal/main.cpp).
Besides, you can modify the postprocess parameters such as `num_labels` and `score_thres` and `iou_thres` and `topk`
in [`main.cpp`](../csrc/segment/normal/main.cpp).
```c++
int topk = 100;

@ -38,6 +38,10 @@ def main(args: argparse.Namespace) -> None:
data = Engine(tensor)
bboxes, scores, labels = det_postprocess(data)
if bboxes.size == 0:
# if no bounding box
print(f'{image}: no object!')
continue
bboxes -= dwdh
bboxes /= ratio

@ -37,6 +37,10 @@ def main(args: argparse.Namespace) -> None:
data = Engine(tensor)
bboxes, scores, labels = det_postprocess(data)
if bboxes.numel() == 0:
# if no bounding box
print(f'{image}: no object!')
continue
bboxes -= dwdh
bboxes /= ratio

@ -0,0 +1,116 @@
import argparse
from pathlib import Path
import cv2
import numpy as np
from config import COLORS, KPS_COLORS, LIMB_COLORS, SKELETON
from models.utils import blob, letterbox, path_to_list, pose_postprocess
def main(args: argparse.Namespace) -> None:
    """Run pose inference on every input image and draw the results.

    For each image the boxes, keypoints and skeleton limbs returned by
    ``pose_postprocess`` are drawn on a copy of the original image, which is
    then either displayed (``args.show``) or written to ``args.out_dir``.

    Args:
        args: Parsed command line options; reads ``method``, ``engine``,
            ``imgs``, ``out_dir``, ``show``, ``conf_thres`` and ``iou_thres``.

    Raises:
        NotImplementedError: If ``args.method`` is neither ``'cudart'`` nor
            ``'pycuda'``.
    """
    # Import lazily so that only the selected CUDA binding has to be present.
    if args.method == 'cudart':
        from models.cudart_api import TRTEngine
    elif args.method == 'pycuda':
        from models.pycuda_api import TRTEngine
    else:
        raise NotImplementedError(f'Unsupported method: {args.method}')

    Engine = TRTEngine(args.engine)
    # The last two dims of the first engine input are used as the target size.
    H, W = Engine.inp_info[0].shape[-2:]

    images = path_to_list(args.imgs)
    save_path = Path(args.out_dir)

    if not args.show and not save_path.exists():
        save_path.mkdir(parents=True, exist_ok=True)

    for image in images:
        save_image = save_path / image.name
        bgr = cv2.imread(str(image))
        if bgr is None:
            # cv2.imread returns None for unreadable/unsupported files;
            # skip them instead of crashing on `bgr.copy()`.
            print(f'{image}: failed to read image!')
            continue
        draw = bgr.copy()
        bgr, ratio, dwdh = letterbox(bgr, (W, H))
        # NOTE(review): dw/dh look like the letterbox padding offsets in
        # pixels -- confirm against `letterbox`.
        dw, dh = int(dwdh[0]), int(dwdh[1])
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        tensor = blob(rgb, return_seg=False)
        # Repeat the offsets so they line up with the 4 box coordinates.
        dwdh = np.array(dwdh * 2, dtype=np.float32)
        tensor = np.ascontiguousarray(tensor)
        # inference
        data = Engine(tensor)

        bboxes, scores, kpts = pose_postprocess(data, args.conf_thres,
                                                args.iou_thres)
        if bboxes.size == 0:
            # if no bounding box
            print(f'{image}: no object!')
            continue
        # Map boxes from the letterboxed input back onto the original image.
        bboxes -= dwdh
        bboxes /= ratio

        for (bbox, score, kpt) in zip(bboxes, scores, kpts):
            bbox = bbox.round().astype(np.int32).tolist()
            color = COLORS['person']
            cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
            cv2.putText(draw,
                        f'person:{score:.3f}', (bbox[0], bbox[1] - 2),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.75, [225, 255, 255],
                        thickness=2)
            # One pass handles both the first 17 keypoints and the 19
            # SKELETON entries, keeping the original drawing order.
            for i in range(19):
                if i < 17:
                    px, py, ps = kpt[i]
                    # Only draw keypoints whose score exceeds 0.5.
                    if ps > 0.5:
                        kcolor = KPS_COLORS[i]
                        px = round(float(px - dw) / ratio)
                        py = round(float(py - dh) / ratio)
                        cv2.circle(draw, (px, py), 5, kcolor, -1)
                # SKELETON entries are used as 1-based keypoint indices.
                xi, yi = SKELETON[i]
                pos1_s = kpt[xi - 1][2]
                pos2_s = kpt[yi - 1][2]
                # A limb is drawn only when both of its end points pass 0.5.
                if pos1_s > 0.5 and pos2_s > 0.5:
                    limb_color = LIMB_COLORS[i]
                    pos1_x = round(float(kpt[xi - 1][0] - dw) / ratio)
                    pos1_y = round(float(kpt[xi - 1][1] - dh) / ratio)

                    pos2_x = round(float(kpt[yi - 1][0] - dw) / ratio)
                    pos2_y = round(float(kpt[yi - 1][1] - dh) / ratio)

                    cv2.line(draw, (pos1_x, pos1_y), (pos2_x, pos2_y),
                             limb_color, 2)
        if args.show:
            cv2.imshow('result', draw)
            cv2.waitKey(0)
        else:
            cv2.imwrite(str(save_image), draw)
def parse_args() -> argparse.Namespace:
    """Parse the command line options of the torch-free pose demo.

    Returns:
        Namespace with ``engine``, ``imgs``, ``show``, ``out_dir``,
        ``conf_thres``, ``iou_thres`` and ``method``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--engine', type=str, help='Engine file')
    parser.add_argument('--imgs', type=str, help='Images file')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show the detection results')
    parser.add_argument('--out-dir',
                        type=str,
                        default='./output',
                        help='Path to output file')
    parser.add_argument('--conf-thres',
                        type=float,
                        default=0.25,
                        help='Confidence threshold')
    # The help text used to be a copy of the --conf-thres one.
    parser.add_argument('--iou-thres',
                        type=float,
                        default=0.65,
                        help='IoU threshold for NMS')
    # main() accepts exactly these two values for --method.
    parser.add_argument('--method',
                        type=str,
                        default='cudart',
                        help="CUDA binding to use: 'cudart' or 'pycuda'")
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    # Script entry point: parse the CLI options and hand them to main().
    main(parse_args())

@ -0,0 +1,112 @@
from models import TRTModule # isort:skip
import argparse
from pathlib import Path
import cv2
import torch
from config import COLORS, KPS_COLORS, LIMB_COLORS, SKELETON
from models.torch_utils import pose_postprocess
from models.utils import blob, letterbox, path_to_list
def main(args: argparse.Namespace) -> None:
    """Run pose inference with the torch backend and draw the results.

    For each image the boxes, keypoints and skeleton limbs returned by
    ``pose_postprocess`` are drawn on a copy of the original image, which is
    then either displayed (``args.show``) or written to ``args.out_dir``.

    Args:
        args: Parsed command line options; reads ``device``, ``engine``,
            ``imgs``, ``out_dir``, ``show``, ``conf_thres`` and ``iou_thres``.
    """
    device = torch.device(args.device)
    Engine = TRTModule(args.engine, device)
    # The last two dims of the first engine input are used as the target size.
    H, W = Engine.inp_info[0].shape[-2:]

    images = path_to_list(args.imgs)
    save_path = Path(args.out_dir)

    if not args.show and not save_path.exists():
        save_path.mkdir(parents=True, exist_ok=True)

    for image in images:
        save_image = save_path / image.name
        bgr = cv2.imread(str(image))
        if bgr is None:
            # cv2.imread returns None for unreadable/unsupported files;
            # skip them instead of crashing on `bgr.copy()`.
            print(f'{image}: failed to read image!')
            continue
        draw = bgr.copy()
        bgr, ratio, dwdh = letterbox(bgr, (W, H))
        # NOTE(review): dw/dh look like the letterbox padding offsets in
        # pixels -- confirm against `letterbox`.
        dw, dh = int(dwdh[0]), int(dwdh[1])
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        tensor = blob(rgb, return_seg=False)
        # Repeat the offsets so they line up with the 4 box coordinates.
        dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
        tensor = torch.asarray(tensor, device=device)
        # inference
        data = Engine(tensor)

        bboxes, scores, kpts = pose_postprocess(data, args.conf_thres,
                                                args.iou_thres)
        if bboxes.numel() == 0:
            # if no bounding box
            print(f'{image}: no object!')
            continue
        # Map boxes from the letterboxed input back onto the original image.
        bboxes -= dwdh
        bboxes /= ratio

        for (bbox, score, kpt) in zip(bboxes, scores, kpts):
            bbox = bbox.round().int().tolist()
            color = COLORS['person']
            cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
            cv2.putText(draw,
                        f'person:{score:.3f}', (bbox[0], bbox[1] - 2),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.75, [225, 255, 255],
                        thickness=2)
            # One pass handles both the first 17 keypoints and the 19
            # SKELETON entries, keeping the original drawing order.
            for i in range(19):
                if i < 17:
                    px, py, ps = kpt[i]
                    # Only draw keypoints whose score exceeds 0.5.
                    if ps > 0.5:
                        kcolor = KPS_COLORS[i]
                        px = round(float(px - dw) / ratio)
                        py = round(float(py - dh) / ratio)
                        cv2.circle(draw, (px, py), 5, kcolor, -1)
                # SKELETON entries are used as 1-based keypoint indices.
                xi, yi = SKELETON[i]
                pos1_s = kpt[xi - 1][2]
                pos2_s = kpt[yi - 1][2]
                # A limb is drawn only when both of its end points pass 0.5.
                if pos1_s > 0.5 and pos2_s > 0.5:
                    limb_color = LIMB_COLORS[i]
                    pos1_x = round(float(kpt[xi - 1][0] - dw) / ratio)
                    pos1_y = round(float(kpt[xi - 1][1] - dh) / ratio)

                    pos2_x = round(float(kpt[yi - 1][0] - dw) / ratio)
                    pos2_y = round(float(kpt[yi - 1][1] - dh) / ratio)

                    cv2.line(draw, (pos1_x, pos1_y), (pos2_x, pos2_y),
                             limb_color, 2)
        if args.show:
            cv2.imshow('result', draw)
            cv2.waitKey(0)
        else:
            cv2.imwrite(str(save_image), draw)
def parse_args() -> argparse.Namespace:
    """Parse the command line options of the torch pose demo.

    Returns:
        Namespace with ``engine``, ``imgs``, ``show``, ``out_dir``,
        ``conf_thres``, ``iou_thres`` and ``device``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--engine', type=str, help='Engine file')
    parser.add_argument('--imgs', type=str, help='Images file')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show the detection results')
    parser.add_argument('--out-dir',
                        type=str,
                        default='./output',
                        help='Path to output file')
    parser.add_argument('--conf-thres',
                        type=float,
                        default=0.25,
                        help='Confidence threshold')
    # The help text used to be a copy of the --conf-thres one.
    parser.add_argument('--iou-thres',
                        type=float,
                        default=0.65,
                        help='IoU threshold for NMS')
    parser.add_argument('--device',
                        type=str,
                        default='cuda:0',
                        help='TensorRT infer device')
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    # Script entry point: parse the CLI options and hand them to main().
    main(parse_args())

@ -41,6 +41,10 @@ def main(args: argparse.Namespace) -> None:
seg_img = seg_img[dh:H - dh, dw:W - dw, [2, 1, 0]]
bboxes, scores, labels, masks = seg_postprocess(
data, bgr.shape[:2], args.conf_thres, args.iou_thres)
if bboxes.size == 0:
# if no bounding box
print(f'{image}: no object!')
continue
masks = masks[:, dh:H - dh, dw:W - dw, :]
mask_colors = MASK_COLORS[labels % len(MASK_COLORS)]
mask_colors = mask_colors.reshape(-1, 1, 1, 3) * ALPHA

@ -42,10 +42,9 @@ def main(args: argparse.Namespace) -> None:
device=device)
bboxes, scores, labels, masks = seg_postprocess(
data, bgr.shape[:2], args.conf_thres, args.iou_thres)
if bboxes is None:
# if no bounding box or others save original image
if not args.show:
cv2.imwrite(str(save_image), draw)
if bboxes.numel() == 0:
# if no bounding box
print(f'{image}: no object!')
continue
masks = masks[:, dh:H - dh, dw:W - dw, :]
indices = (labels % len(MASK_COLORS)).long()

@ -3,7 +3,7 @@ from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import batched_nms
from torchvision.ops import batched_nms, nms
def seg_postprocess(
@ -14,12 +14,13 @@ def seg_postprocess(
-> Tuple[Tensor, Tensor, Tensor, Tensor]:
assert len(data) == 2
h, w = shape[0] // 4, shape[1] // 4 # 4x downsampling
outputs, proto = (i[0] for i in data)
outputs, proto = data[0][0], data[1][0]
bboxes, scores, labels, maskconf = outputs.split([4, 1, 1, 32], 1)
scores, labels = scores.squeeze(), labels.squeeze()
idx = scores > conf_thres
if idx.sum() == 0: # no bounding boxes or seg were created
return None, None, None, None
if not idx.any(): # no bounding boxes or seg were created
return bboxes.new_zeros((0, 4)), scores.new_zeros(
(0, )), labels.new_zeros((0, )), bboxes.new_zeros((0, 0, 0, 0))
bboxes, scores, labels, maskconf = \
bboxes[idx], scores[idx], labels[idx], maskconf[idx]
idx = batched_nms(bboxes, scores, labels, iou_thres)
@ -35,10 +36,37 @@ def seg_postprocess(
return bboxes, scores, labels, masks
def pose_postprocess(
        data: Union[Tuple, Tensor],
        conf_thres: float = 0.25,
        iou_thres: float = 0.65) \
        -> Tuple[Tensor, Tensor, Tensor]:
    """Filter raw pose-head output by score and NMS.

    Args:
        data: Raw engine output, or a 1-tuple wrapping it. Only the first
            batch element is processed; it is transposed so that each row is
            one candidate laid out as 4 box values, 1 score, then keypoint
            values in groups of 3.
        conf_thres: Candidates whose score is not above this are dropped.
        iou_thres: IoU threshold passed to ``nms``.

    Returns:
        ``(bboxes, scores, kpts)`` with shapes ``(N, 4)``, ``(N,)`` and
        ``(N, K, 3)``; boxes are corner-format (x1, y1, x2, y2). When nothing
        passes ``conf_thres`` the same layout is returned with ``N == 0``.
    """
    if isinstance(data, tuple):
        assert len(data) == 1
        data = data[0]
    outputs = torch.transpose(data[0], 0, 1).contiguous()
    # Everything after the 4 box values and the score is keypoint data
    # (51 values for the stock 17-keypoint model).
    num_kpt_values = outputs.shape[1] - 5
    bboxes, scores, kpts = outputs.split([4, 1, num_kpt_values], 1)
    # Squeeze only the channel axis so a single candidate keeps its row dim.
    scores = scores.squeeze(1)
    idx = scores > conf_thres
    if not idx.any():  # no bounding boxes were created
        # Keep the (N, K, 3) keypoint layout even when N == 0.
        return bboxes.new_zeros((0, 4)), scores.new_zeros(
            (0, )), kpts.new_zeros((0, num_kpt_values // 3, 3))
    bboxes, scores, kpts = bboxes[idx], scores[idx], kpts[idx]
    # Convert (cx, cy, w, h) to (x1, y1, x2, y2) for nms.
    xycenter, wh = bboxes.chunk(2, -1)
    bboxes = torch.cat([xycenter - 0.5 * wh, xycenter + 0.5 * wh], -1)
    idx = nms(bboxes, scores, iou_thres)
    bboxes, scores, kpts = bboxes[idx], scores[idx], kpts[idx]
    return bboxes, scores, kpts.reshape(idx.shape[0], -1, 3)
def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]):
assert len(data) == 4
num_dets, bboxes, scores, labels = (i[0] for i in data)
num_dets, bboxes, scores, labels = data[0][0], data[1][0], data[2][
0], data[3][0]
nums = num_dets.item()
if nums == 0:
return bboxes.new_zeros((0, 4)), scores.new_zeros(
(0, )), labels.new_zeros((0, ))
bboxes = bboxes[:nums]
scores = scores[:nums]
labels = labels[:nums]

@ -45,6 +45,7 @@ def letterbox(im: ndarray,
def blob(im: ndarray, return_seg: bool = False) -> Union[ndarray, Tuple]:
seg = None
if return_seg:
seg = im.astype(np.float32) / 255
im = im.transpose([2, 0, 1])
@ -88,6 +89,9 @@ def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]):
assert len(data) == 4
num_dets, bboxes, scores, labels = (i[0] for i in data)
nums = num_dets.item()
if nums == 0:
return np.empty((0, 4), dtype=np.float32), np.empty(
(0, ), dtype=np.float32), np.empty((0, ), dtype=np.int32)
bboxes = bboxes[:nums]
scores = scores[:nums]
labels = labels[:nums]
@ -106,6 +110,12 @@ def seg_postprocess(
bboxes, scores, labels, maskconf = np.split(outputs, [4, 5, 6], 1)
scores, labels = scores.squeeze(), labels.squeeze()
idx = scores > conf_thres
if not idx.any(): # no bounding boxes or seg were created
return np.empty((0, 4), dtype=np.float32), \
np.empty((0,), dtype=np.float32), \
np.empty((0,), dtype=np.int32), \
np.empty((0, 0, 0, 0), dtype=np.int32)
bboxes, scores, labels, maskconf = \
bboxes[idx], scores[idx], labels[idx], maskconf[idx]
cvbboxes = np.concatenate([bboxes[:, :2], bboxes[:, 2:] - bboxes[:, :2]],
@ -128,3 +138,29 @@ def seg_postprocess(
masks = masks.transpose(2, 0, 1)
masks = np.ascontiguousarray((masks > 0.5)[..., None], dtype=np.float32)
return bboxes, scores, labels, masks
def pose_postprocess(
        data: Union[Tuple, ndarray],
        conf_thres: float = 0.25,
        iou_thres: float = 0.65) \
        -> Tuple[ndarray, ndarray, ndarray]:
    """Filter raw pose-head output by score and NMS (numpy / OpenCV).

    Args:
        data: Raw engine output, or a 1-tuple wrapping it. Only the first
            batch element is processed; it is transposed so that each row is
            one candidate laid out as 4 box values, 1 score, then keypoint
            values in groups of 3.
        conf_thres: Candidates whose score is not above this are dropped;
            also passed to ``cv2.dnn.NMSBoxes`` as its score threshold.
        iou_thres: NMS threshold passed to ``cv2.dnn.NMSBoxes``.

    Returns:
        ``(bboxes, scores, kpts)`` with shapes ``(N, 4)``, ``(N,)`` and
        ``(N, K, 3)``; boxes are corner-format (x1, y1, x2, y2). When nothing
        passes ``conf_thres`` the same layout is returned with ``N == 0``.
    """
    if isinstance(data, tuple):
        assert len(data) == 1
        data = data[0]
    outputs = np.transpose(data[0], (1, 0))
    bboxes, scores, kpts = np.split(outputs, [4, 5], 1)
    # Squeeze only the channel axis so a single candidate keeps its row dim.
    scores = scores.squeeze(1)
    idx = scores > conf_thres
    if not idx.any():  # no bounding boxes were created
        # Keep the (N, K, 3) keypoint layout even when N == 0.
        return np.empty((0, 4), dtype=np.float32), np.empty(
            (0, ), dtype=np.float32), np.empty(
                (0, kpts.shape[1] // 3, 3), dtype=np.float32)
    bboxes, scores, kpts = bboxes[idx], scores[idx], kpts[idx]
    xycenter, wh = np.split(bboxes, [
        2,
    ], -1)
    # NMSBoxes is fed (x, y, w, h) boxes with a top-left origin.
    cvbboxes = np.concatenate([xycenter - 0.5 * wh, wh], -1)
    idx = cv2.dnn.NMSBoxes(cvbboxes, scores, conf_thres, iou_thres)
    # Depending on the OpenCV version the result may be shaped (N,), (N, 1)
    # or be an empty tuple; normalise it to a flat integer index array.
    idx = np.asarray(idx, dtype=np.int64).reshape(-1)
    cvbboxes, scores, kpts = cvbboxes[idx], scores[idx], kpts[idx]
    # Turn (x, y, w, h) into (x1, y1, x2, y2).
    cvbboxes[:, 2:] += cvbboxes[:, :2]
    return cvbboxes, scores, kpts.reshape(idx.shape[0], -1, 3)

Loading…
Cancel
Save