version 0.2.0

triplemu/seg
triple-Mu 2 years ago
parent 2f6620e271
commit ca8e47515d
Changed files (lines changed per file):
  1. README.md (40)
  2. csrc/detect/normal/include/yolov8.hpp (1)
  3. csrc/segment/normal/CMakeLists.txt (0)
  4. csrc/segment/normal/include/common.hpp (8)
  5. csrc/segment/normal/include/yolov8-seg.hpp (573)
  6. csrc/segment/normal/main.cpp (178)
  7. csrc/segment/simple/CMakeLists.txt (60)
  8. csrc/segment/simple/include/common.hpp (157)
  9. csrc/segment/simple/include/yolov8-seg.hpp (6)
  10. csrc/segment/simple/main.cpp (0)
  11. docs/Normal.md (10)
  12. docs/Segment.md (85)
  13. infer-no-torch.py (254)
  14. infer.py (8)
  15. models/cudart_api.py (160)
  16. models/pycuda_api.py (2)
  17. requirements.txt (2)

@@ -166,6 +166,7 @@ python3 infer.py \
- `--engine` : The Engine you export.
- `--imgs` : The images path you want to detect.
- `--show` : Whether to show detection results.
- `--seg` : Whether to infer with segment model.
- `--out-dir` : Where to save the detection result images. It has no effect when the `--show` flag is used.
- `--device` : The CUDA device you use.
- `--profile` : Profile the TensorRT engine.
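For example, a minimal run of a segmentation engine with the new `--seg` flag (the engine file name here is illustrative):
``` shell
python3 infer.py \
--engine yolov8s-seg.engine \
--imgs data \
--seg \
--show
```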
@@ -207,7 +208,6 @@ Please see more information in [`Segment.md`](docs/Segment.md)
See more in [`README.md`](csrc/deepstream/README.md)
# Profile your engine
If you want to profile the TensorRT engine:
@@ -217,3 +217,41 @@ Usage:
``` shell
python3 infer.py --engine yolov8s.engine --profile
```
# Refuse To Use PyTorch for model inference !!!
If you need to break away from PyTorch and run inference with TensorRT alone,
see [`infer-no-torch.py`](infer-no-torch.py).
Its usage is the same as the PyTorch version, but its performance is much worse.
You can use either `cuda-python` or `pycuda` for inference.
Install one of them with:
```shell
pip install cuda-python
# or
pip install pycuda
```
Usage:
#### Detection
``` shell
python3 infer-no-torch.py \
--engine yolov8s.engine \
--imgs data \
--show \
--out-dir outputs \
--method cudart
```
#### Description of all arguments
- `--engine` : The Engine you export.
- `--imgs` : The images path you want to detect.
- `--show` : Whether to show detection results.
- `--seg` : Whether to infer with segment model.
- `--out-dir` : Where to save the detection result images. It has no effect when the `--show` flag is used.
- `--method` : Choose `cudart` or `pycuda`, default is `cudart`.
- `--profile` : Profile the TensorRT engine.
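If you would rather call the wrapper from your own script, here is a minimal sketch (assuming `models/cudart_api.py` from this commit is importable and `data/bus.jpg` exists; a plain resize stands in for letterboxing):
```python
import cv2
import numpy as np

from models.cudart_api import TRTEngine

engine = TRTEngine('yolov8s.engine')  # deserializes the engine and warms it up
H, W = engine.inp_info[0].shape[-2:]  # network input height and width
img = cv2.cvtColor(cv2.imread('data/bus.jpg'), cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (W, H)).astype(np.float32) / 255.0
tensor = np.ascontiguousarray(img.transpose(2, 0, 1)[None])
outputs = engine(tensor)  # tuple of ndarrays (a single ndarray if one output)
```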

@@ -372,7 +372,6 @@ void YOLOv8::postprocess(
float score = *max_s_ptr;
if (score > score_thres)
{
std::cout << score << std::endl;
float x = *bboxes_ptr++ - dw;
float y = *bboxes_ptr++ - dh;
float w = *bboxes_ptr++;

@@ -1,9 +1,9 @@
//
// Created by ubuntu on 1/24/23.
// Created by ubuntu on 2/8/23.
//
#ifndef SEGMENT_COMMON_HPP
#define SEGMENT_COMMON_HPP
#ifndef SEGMENT_NORMAL_COMMON_HPP
#define SEGMENT_NORMAL_COMMON_HPP
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
@@ -154,4 +154,4 @@ namespace seg
float width = 0;
};
}
#endif //SEGMENT_COMMON_HPP
#endif //SEGMENT_NORMAL_COMMON_HPP

@@ -0,0 +1,573 @@
//
// Created by ubuntu on 2/8/23.
//
#ifndef SEGMENT_NORMAL_YOLOV8_SEG_HPP
#define SEGMENT_NORMAL_YOLOV8_SEG_HPP
#include <fstream>
#include "common.hpp"
#include "NvInferPlugin.h"
using namespace seg;
class YOLOv8_seg
{
public:
explicit YOLOv8_seg(const std::string& engine_file_path);
~YOLOv8_seg();
void make_pipe(bool warmup = true);
void copy_from_Mat(const cv::Mat& image);
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
);
void infer();
void postprocess(
std::vector<Object>& objs,
float score_thres = 0.25f,
float iou_thres = 0.65f,
int topk = 100,
int seg_channels = 32,
int seg_h = 160,
int seg_w = 160
);
static void draw_objects(
const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS,
const std::vector<std::vector<unsigned int>>& MASK_COLORS
);
int num_bindings;
int num_inputs = 0;
int num_outputs = 0;
std::vector<Binding> input_bindings;
std::vector<Binding> output_bindings;
std::vector<void*> host_ptrs;
std::vector<void*> device_ptrs;
PreParam pparam;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{ nvinfer1::ILogger::Severity::kERROR };
};
YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
{
std::ifstream file(engine_file_path, std::ios::binary);
assert(file.good());
file.seekg(0, std::ios::end);
auto size = file.tellg();
file.seekg(0, std::ios::beg);
char* trtModelStream = new char[size];
assert(trtModelStream);
file.read(trtModelStream, size);
file.close();
initLibNvInferPlugins(&this->gLogger, "");
this->runtime = nvinfer1::createInferRuntime(this->gLogger);
assert(this->runtime != nullptr);
this->engine = this->runtime->deserializeCudaEngine(trtModelStream, size);
assert(this->engine != nullptr);
this->context = this->engine->createExecutionContext();
assert(this->context != nullptr);
cudaStreamCreate(&this->stream);
this->num_bindings = this->engine->getNbBindings();
for (int i = 0; i < this->num_bindings; ++i)
{
Binding binding;
nvinfer1::Dims dims;
nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
std::string name = this->engine->getBindingName(i);
binding.name = name;
binding.dsize = type_to_size(dtype);
bool IsInput = engine->bindingIsInput(i);
if (IsInput)
{
this->num_inputs += 1;
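// size this input with the kMAX shape of optimization profile 0 so the buffer fits any runtime shape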
dims = this->engine->getProfileDimensions(
i,
0,
nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
}
else
{
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->output_bindings.push_back(binding);
this->num_outputs += 1;
}
}
}
YOLOv8_seg::~YOLOv8_seg()
{
this->context->destroy();
this->engine->destroy();
this->runtime->destroy();
cudaStreamDestroy(this->stream);
for (auto& ptr : this->device_ptrs)
{
CHECK(cudaFree(ptr));
}
for (auto& ptr : this->host_ptrs)
{
CHECK(cudaFreeHost(ptr));
}
}
void YOLOv8_seg::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings)
{
void* d_ptr;
CHECK(cudaMallocAsync(
&d_ptr,
bindings.size * bindings.dsize,
this->stream)
);
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings)
{
void* d_ptr, * h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMallocAsync(
&d_ptr,
size,
this->stream)
);
CHECK(cudaHostAlloc(
&h_ptr,
size,
0)
);
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
if (warmup)
{
for (int i = 0; i < 10; i++)
{
for (auto& bindings : this->input_bindings)
{
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
h_ptr,
size,
cudaMemcpyHostToDevice,
this->stream)
);
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
void YOLOv8_seg::letterbox(
const cv::Mat& image,
cv::Mat& out,
cv::Size& size
)
{
const float inp_h = size.height;
const float inp_w = size.width;
float height = image.rows;
float width = image.cols;
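// scale factor that fits the whole image inside the network input while keeping aspect ratio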
float r = std::min(inp_h / height, inp_w / width);
int padw = std::round(width * r);
int padh = std::round(height * r);
cv::Mat tmp;
if ((int)width != padw || (int)height != padh)
{
cv::resize(
image,
tmp,
cv::Size(padw, padh)
);
}
else
{
tmp = image.clone();
}
float dw = inp_w - padw;
float dh = inp_h - padh;
dw /= 2.0f;
dh /= 2.0f;
int top = int(std::round(dh - 0.1f));
int bottom = int(std::round(dh + 0.1f));
int left = int(std::round(dw - 0.1f));
int right = int(std::round(dw + 0.1f));
cv::copyMakeBorder(
tmp,
tmp,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
{ 114, 114, 114 }
);
cv::dnn::blobFromImage(tmp,
out,
1 / 255.f,
cv::Size(),
cv::Scalar(0, 0, 0),
true,
false,
CV_32F
);
this->pparam.ratio = 1 / r;
this->pparam.dw = dw;
this->pparam.dh = dh;
this->pparam.height = height;
this->pparam.width = width;
}
void YOLOv8_seg::copy_from_Mat(const cv::Mat& image)
{
cv::Mat nchw;
auto& in_binding = this->input_bindings[0];
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
cv::Size size{ width, height };
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{
4,
{ 1, 3, height, width }
}
);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
}
void YOLOv8_seg::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
cv::Mat nchw;
this->letterbox(
image,
nchw,
size
);
this->context->setBindingDimensions(
0,
nvinfer1::Dims
{ 4,
{ 1, 3, size.height, size.width }
}
);
CHECK(cudaMemcpyAsync(
this->device_ptrs[0],
nchw.ptr<float>(),
nchw.total() * nchw.elemSize(),
cudaMemcpyHostToDevice,
this->stream)
);
}
void YOLOv8_seg::infer()
{
this->context->enqueueV2(
this->device_ptrs.data(),
this->stream,
nullptr
);
for (int i = 0; i < this->num_outputs; i++)
{
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(this->host_ptrs[i],
this->device_ptrs[i + this->num_inputs],
osize,
cudaMemcpyDeviceToHost,
this->stream)
);
}
cudaStreamSynchronize(this->stream);
}
void YOLOv8_seg::postprocess(std::vector<Object>& objs,
float score_thres,
float iou_thres,
int topk,
int seg_channels,
int seg_h,
int seg_w
)
{
objs.clear();
auto input_h = this->input_bindings[0].dims.d[2];
auto input_w = this->input_bindings[0].dims.d[3];
int num_channels, num_anchors, num_classes;
bool flag = false;
int bid;
int bcnt = -1;
for (auto& o : this->output_bindings)
{
bcnt += 1;
if (o.dims.nbDims == 3)
{
num_channels = o.dims.d[1];
num_anchors = o.dims.d[2];
flag = true;
bid = bcnt;
}
}
assert(flag);
num_classes = num_channels - seg_channels - 4;
auto& dw = this->pparam.dw;
auto& dh = this->pparam.dh;
auto& width = this->pparam.width;
auto& height = this->pparam.height;
auto& ratio = this->pparam.ratio;
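// raw output is (num_channels x num_anchors); transpose so each row is one anchor: [cx, cy, w, h, class scores..., seg_channels mask coefficients]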
cv::Mat output = cv::Mat(num_channels, num_anchors, CV_32F,
static_cast<float*>(this->host_ptrs[bid]));
output = output.t();
cv::Mat protos = cv::Mat(seg_channels, seg_h * seg_w, CV_32F,
static_cast<float*>(this->host_ptrs[1 - bid]));
std::vector<int> labels;
std::vector<float> scores;
std::vector<cv::Rect> bboxes;
std::vector<cv::Mat> mask_confs;
std::vector<int> indices;
for (int i = 0; i < num_anchors; i++)
{
auto row_ptr = output.row(i).ptr<float>();
auto bboxes_ptr = row_ptr;
auto scores_ptr = row_ptr + 4;
auto mask_confs_ptr = row_ptr + 4 + num_classes;
auto max_s_ptr = std::max_element(scores_ptr, scores_ptr + num_classes);
float score = *max_s_ptr;
if (score > score_thres)
{
float x = *bboxes_ptr++ - dw;
float y = *bboxes_ptr++ - dh;
float w = *bboxes_ptr++;
float h = *bboxes_ptr;
float x0 = clamp((x - 0.5f * w) * ratio, 0.f, width);
float y0 = clamp((y - 0.5f * h) * ratio, 0.f, height);
float x1 = clamp((x + 0.5f * w) * ratio, 0.f, width);
float y1 = clamp((y + 0.5f * h) * ratio, 0.f, height);
int label = max_s_ptr - scores_ptr;
cv::Rect_<float> bbox;
bbox.x = x0;
bbox.y = y0;
bbox.width = x1 - x0;
bbox.height = y1 - y0;
cv::Mat mask_conf = cv::Mat(1, seg_channels, CV_32F, mask_confs_ptr);
bboxes.push_back(bbox);
labels.push_back(label);
scores.push_back(score);
mask_confs.push_back(mask_conf);
}
}
#if defined(BATCHED_NMS)
cv::dnn::NMSBoxesBatched(
bboxes,
scores,
labels,
score_thres,
iou_thres,
indices
);
#else
cv::dnn::NMSBoxes(
bboxes,
scores,
score_thres,
iou_thres,
indices
);
#endif
cv::Mat masks;
int cnt = 0;
for (auto& i : indices)
{
if (cnt >= topk)
{
break;
}
cv::Rect tmp = bboxes[i];
Object obj;
obj.label = labels[i];
obj.rect = tmp;
obj.prob = scores[i];
masks.push_back(mask_confs[i]);
objs.push_back(obj);
cnt += 1;
}
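// (num_kept x seg_channels) * (seg_channels x seg_h*seg_w): one mask-logit map per kept object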
cv::Mat matmulRes = (masks * protos).t();
cv::Mat maskMat = matmulRes.reshape(indices.size(), { seg_w, seg_h });
std::vector<cv::Mat> maskChannels;
cv::split(maskMat, maskChannels);
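// map the letterbox padding from network-input scale down to proto scale, then crop it away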
int scale_dw = dw / input_w * seg_w;
int scale_dh = dh / input_h * seg_h;
cv::Rect roi(
scale_dw,
scale_dh,
seg_w - 2 * scale_dw,
seg_h - 2 * scale_dh);
for (int i = 0; i < indices.size(); i++)
{
cv::Mat dest, mask;
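// sigmoid: 1 / (1 + exp(-logit)) turns mask logits into probabilities before thresholding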
cv::exp(-maskChannels[i], dest);
dest = 1.0 / (1.0 + dest);
dest = dest(roi);
cv::resize(
dest,
mask,
cv::Size((int)width, (int)height),
cv::INTER_LINEAR
);
objs[i].boxMask = mask(objs[i].rect) > 0.5f;
}
}
void YOLOv8_seg::draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES,
const std::vector<std::vector<unsigned int>>& COLORS,
const std::vector<std::vector<unsigned int>>& MASK_COLORS
)
{
res = image.clone();
cv::Mat mask = image.clone();
for (auto& obj : objs)
{
int idx = obj.label;
cv::Scalar color = cv::Scalar(
COLORS[idx][0],
COLORS[idx][1],
COLORS[idx][2]
);
cv::Scalar mask_color = cv::Scalar(
MASK_COLORS[idx % 20][0],
MASK_COLORS[idx % 20][1],
MASK_COLORS[idx % 20][2]
);
cv::rectangle(
res,
obj.rect,
color,
2
);
char text[256];
sprintf(
text,
"%s %.1f%%",
CLASS_NAMES[idx].c_str(),
obj.prob * 100
);
mask(obj.rect).setTo(mask_color, obj.boxMask);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(
text,
cv::FONT_HERSHEY_SIMPLEX,
0.4,
1,
&baseLine
);
int x = (int)obj.rect.x;
int y = (int)obj.rect.y + 1;
if (y > res.rows)
y = res.rows;
cv::rectangle(
res,
cv::Rect(x, y, label_size.width, label_size.height + baseLine),
{ 0, 0, 255 },
-1
);
cv::putText(
res,
text,
cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX,
0.4,
{ 255, 255, 255 },
1
);
}
cv::addWeighted(
res,
0.5,
mask,
0.8,
1,
res
);
}
#endif //SEGMENT_NORMAL_YOLOV8_SEG_HPP

@@ -0,0 +1,178 @@
//
// Created by ubuntu on 2/8/23.
//
#include "chrono"
#include "yolov8-seg.hpp"
#include "opencv2/opencv.hpp"
const std::vector<std::string> CLASS_NAMES = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus",
"train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella",
"handbag", "tie", "suitcase", "frisbee", "skis",
"snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass",
"cup", "fork", "knife", "spoon", "bowl",
"banana", "apple", "sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote",
"keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase",
"scissors", "teddy bear", "hair drier", "toothbrush" };
const std::vector<std::vector<unsigned int>> COLORS = {
{ 0, 114, 189 }, { 217, 83, 25 }, { 237, 177, 32 },
{ 126, 47, 142 }, { 119, 172, 48 }, { 77, 190, 238 },
{ 162, 20, 47 }, { 76, 76, 76 }, { 153, 153, 153 },
{ 255, 0, 0 }, { 255, 128, 0 }, { 191, 191, 0 },
{ 0, 255, 0 }, { 0, 0, 255 }, { 170, 0, 255 },
{ 85, 85, 0 }, { 85, 170, 0 }, { 85, 255, 0 },
{ 170, 85, 0 }, { 170, 170, 0 }, { 170, 255, 0 },
{ 255, 85, 0 }, { 255, 170, 0 }, { 255, 255, 0 },
{ 0, 85, 128 }, { 0, 170, 128 }, { 0, 255, 128 },
{ 85, 0, 128 }, { 85, 85, 128 }, { 85, 170, 128 },
{ 85, 255, 128 }, { 170, 0, 128 }, { 170, 85, 128 },
{ 170, 170, 128 }, { 170, 255, 128 }, { 255, 0, 128 },
{ 255, 85, 128 }, { 255, 170, 128 }, { 255, 255, 128 },
{ 0, 85, 255 }, { 0, 170, 255 }, { 0, 255, 255 },
{ 85, 0, 255 }, { 85, 85, 255 }, { 85, 170, 255 },
{ 85, 255, 255 }, { 170, 0, 255 }, { 170, 85, 255 },
{ 170, 170, 255 }, { 170, 255, 255 }, { 255, 0, 255 },
{ 255, 85, 255 }, { 255, 170, 255 }, { 85, 0, 0 },
{ 128, 0, 0 }, { 170, 0, 0 }, { 212, 0, 0 },
{ 255, 0, 0 }, { 0, 43, 0 }, { 0, 85, 0 },
{ 0, 128, 0 }, { 0, 170, 0 }, { 0, 212, 0 },
{ 0, 255, 0 }, { 0, 0, 43 }, { 0, 0, 85 },
{ 0, 0, 128 }, { 0, 0, 170 }, { 0, 0, 212 },
{ 0, 0, 255 }, { 0, 0, 0 }, { 36, 36, 36 },
{ 73, 73, 73 }, { 109, 109, 109 }, { 146, 146, 146 },
{ 182, 182, 182 }, { 219, 219, 219 }, { 0, 114, 189 },
{ 80, 183, 189 }, { 128, 128, 0 }
};
const std::vector<std::vector<unsigned int>> MASK_COLORS = {
{ 255, 56, 56 }, { 255, 157, 151 }, { 255, 112, 31 },
{ 255, 178, 29 }, { 207, 210, 49 }, { 72, 249, 10 },
{ 146, 204, 23 }, { 61, 219, 134 }, { 26, 147, 52 },
{ 0, 212, 187 }, { 44, 153, 168 }, { 0, 194, 255 },
{ 52, 69, 147 }, { 100, 115, 255 }, { 0, 24, 236 },
{ 132, 56, 255 }, { 82, 0, 133 }, { 203, 56, 255 },
{ 255, 149, 200 }, { 255, 55, 199 }
};
int main(int argc, char** argv)
{
// cuda:0
cudaSetDevice(0);
assert(argc == 3);  // validate argc before touching argv
const std::string engine_file_path{ argv[1] };
const std::string path{ argv[2] };
std::vector<std::string> imagePathList;
bool isVideo{ false };
auto yolov8 = new YOLOv8_seg(engine_file_path);
yolov8->make_pipe(true);
if (IsFile(path))
{
std::string suffix = path.substr(path.find_last_of('.') + 1);
if (
suffix == "jpg" ||
suffix == "jpeg" ||
suffix == "png"
)
{
imagePathList.push_back(path);
}
else if (
suffix == "mp4" ||
suffix == "avi" ||
suffix == "m4v" ||
suffix == "mpeg" ||
suffix == "mov" ||
suffix == "mkv"
)
{
isVideo = true;
}
else
{
printf("suffix %s is wrong !!!\n", suffix.c_str());
std::abort();
}
}
else if (IsFolder(path))
{
cv::glob(path + "/*.jpg", imagePathList);
}
cv::Mat res, image;
cv::Size size = cv::Size{ 640, 640 };
int topk = 100;
int seg_h = 160;
int seg_w = 160;
int seg_channels = 32;
float score_thres = 0.25f;
float iou_thres = 0.65f;
std::vector<Object> objs;
cv::namedWindow("result", cv::WINDOW_AUTOSIZE);
if (isVideo)
{
cv::VideoCapture cap(path);
if (!cap.isOpened())
{
printf("can not open %s\n", path.c_str());
return -1;
}
while (cap.read(image))
{
objs.clear();
yolov8->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
yolov8->infer();
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
if (cv::waitKey(10) == 'q')
{
break;
}
}
}
else
{
for (auto& path : imagePathList)
{
objs.clear();
image = cv::imread(path);
yolov8->copy_from_Mat(image, size);
auto start = std::chrono::system_clock::now();
yolov8->infer();
auto end = std::chrono::system_clock::now();
yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w);
yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS);
auto tc = (double)
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
printf("cost %2.4lf ms\n", tc);
cv::imshow("result", res);
cv::waitKey(0);
}
}
cv::destroyAllWindows();
delete yolov8;
return 0;
}

@@ -0,0 +1,60 @@
# target_link_directories() below requires CMake >= 3.13
cmake_minimum_required(VERSION 3.13)
set(CMAKE_CUDA_ARCHITECTURES 60 61 62 70 72 75 86)
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
project(yolov8-seg LANGUAGES CXX CUDA)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -O3 -g")
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_BUILD_TYPE Release)
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
# CUDA
find_package(CUDA REQUIRED)
message(STATUS "CUDA Libs: \n${CUDA_LIBRARIES}\n")
message(STATUS "CUDA Headers: \n${CUDA_INCLUDE_DIRS}\n")
# OpenCV
find_package(OpenCV REQUIRED)
message(STATUS "OpenCV Libs: \n${OpenCV_LIBS}\n")
message(STATUS "OpenCV Libraries: \n${OpenCV_LIBRARIES}\n")
message(STATUS "OpenCV Headers: \n${OpenCV_INCLUDE_DIRS}\n")
# TensorRT
set(TensorRT_INCLUDE_DIRS /usr/include/x86_64-linux-gnu)
set(TensorRT_LIBRARIES /usr/lib/x86_64-linux-gnu)
message(STATUS "TensorRT Libs: \n${TensorRT_LIBRARIES}\n")
message(STATUS "TensorRT Headers: \n${TensorRT_INCLUDE_DIRS}\n")
list(APPEND INCLUDE_DIRS
${CUDA_INCLUDE_DIRS}
${OpenCV_INCLUDE_DIRS}
${TensorRT_INCLUDE_DIRS}
./include
)
list(APPEND ALL_LIBS
${CUDA_LIBRARIES}
${OpenCV_LIBRARIES}
${TensorRT_LIBRARIES}
)
include_directories(${INCLUDE_DIRS})
add_executable(${PROJECT_NAME}
main.cpp
include/yolov8-seg.hpp
include/common.hpp
)
target_link_directories(${PROJECT_NAME} PUBLIC ${ALL_LIBS})
target_link_libraries(${PROJECT_NAME} PRIVATE nvinfer nvinfer_plugin cudart ${OpenCV_LIBS})
if(${OpenCV_VERSION} VERSION_GREATER_EQUAL 4.7.0)
message(STATUS "Build with -DBATCHED_NMS")
add_definitions(-DBATCHED_NMS)
endif()

@@ -0,0 +1,157 @@
//
// Created by ubuntu on 2/9/23.
//
#ifndef SEGMENT_SIMPLE_COMMON_HPP
#define SEGMENT_SIMPLE_COMMON_HPP
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#include "NvInfer.h"
#define CHECK(call) \
do \
{ \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", \
cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
class Logger : public nvinfer1::ILogger
{
public:
nvinfer1::ILogger::Severity reportableSeverity;
explicit Logger(nvinfer1::ILogger::Severity severity = nvinfer1::ILogger::Severity::kINFO) :
reportableSeverity(severity)
{
}
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity)
{
return;
}
switch (severity)
{
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
case nvinfer1::ILogger::Severity::kERROR:
std::cerr << "ERROR: ";
break;
case nvinfer1::ILogger::Severity::kWARNING:
std::cerr << "WARNING: ";
break;
case nvinfer1::ILogger::Severity::kINFO:
std::cerr << "INFO: ";
break;
default:
std::cerr << "VERBOSE: ";
break;
}
std::cerr << msg << std::endl;
}
};
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++)
{
size *= dims.d[i];
}
return size;
}
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType)
{
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
return 2;
case nvinfer1::DataType::kINT32:
return 4;
case nvinfer1::DataType::kINT8:
return 1;
case nvinfer1::DataType::kBOOL:
return 1;
default:
return 4;
}
}
inline static float clamp(float val, float min, float max)
{
return val > min ? (val < max ? val : max) : min;
}
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK)
{
return true;
}
return false;
}
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path))
{
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode));
}
inline bool IsFolder(const std::string& path)
{
if (!IsPathExist(path))
{
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
namespace seg
{
struct Binding
{
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object
{
cv::Rect_<float> rect;
int label = 0;
float prob = 0.0;
cv::Mat boxMask;
};
struct PreParam
{
float ratio = 1.0f;
float dw = 0.0f;
float dh = 0.0f;
float height = 0;
float width = 0;
};
}
#endif //SEGMENT_SIMPLE_COMMON_HPP

@@ -1,8 +1,8 @@
//
// Created by ubuntu on 1/24/23.
//
#ifndef SEGMENT_YOLOV8_SEG_HPP
#define SEGMENT_YOLOV8_SEG_HPP
#ifndef SEGMENT_SIMPLE_YOLOV8_SEG_HPP
#define SEGMENT_SIMPLE_YOLOV8_SEG_HPP
#include <fstream>
#include "common.hpp"
#include "NvInferPlugin.h"
@@ -542,4 +542,4 @@ void YOLOv8_seg::draw_objects(const cv::Mat& image,
res
);
}
#endif //SEGMENT_YOLOV8_SEG_HPP
#endif //SEGMENT_SIMPLE_YOLOV8_SEG_HPP

@@ -11,20 +11,22 @@ Usage:
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training)
model = YOLO("yolov8s.pt") # load a pretrained model (recommended for training)
success = model.export(format="engine", device=0) # export the model to engine format
assert success
```
After executing the above script, you will get an engine named `yolov8n.engine`.
After executing the above script, you will get an engine named `yolov8s.engine`.
### 2. CLI tools
Usage:
```shell
yolo export model=yolov8n.pt format=engine device=0
yolo export model=yolov8s.pt format=engine device=0
```
After executing the above command, you will get an engine named `yolov8n.engine` too.
After executing the above command, you will get an engine named `yolov8s.engine` too.
## Inference with c++

@@ -1,15 +1,13 @@
# YOLOv8-seg Model with TensorRT
Instance segmentation models are currently experimental.
Our conversion route is:
The yolov8-seg model conversion route is:
YOLOv8 PyTorch model -> ONNX -> TensorRT Engine
***Notice !!!*** We don't support TensorRT API building !!!
# Export Your Own ONNX model
# Export Modified ONNX model
You can export your onnx model by `ultralytics` API.
You can export your ONNX model by the `ultralytics` API, and the ONNX is also modified by this repo.
``` shell
python3 export_seg.py \
@@ -96,11 +94,11 @@ python3 infer.py \
## Infer with C++
You can infer the segment engine with c++ in [`csrc/segment`](../csrc/segment).
You can infer the segment engine with c++ in [`csrc/segment/simple`](../csrc/segment/simple).
### Build:
Please set your own libraries in [`CMakeLists.txt`](../csrc/segment/CMakeLists.txt) and modify your own config in [`main.cpp`](../csrc/segment/main.cpp) such as `CLASS_NAMES`, `COLORS`, `MASK_COLORS` and the postprocess parameters.
Please set your own libraries in [`CMakeLists.txt`](../csrc/segment/simple/CMakeLists.txt) and modify your own config in [`main.cpp`](../csrc/segment/simple/main.cpp) such as `CLASS_NAMES`, `COLORS`, `MASK_COLORS` and the postprocess parameters.
```c++
int topk = 100;
@@ -113,7 +111,7 @@ float iou_thres = 0.65f;
``` shell
export root=${PWD}
cd csrc/segment
cd csrc/segment/simple
mkdir build
cd build
cmake ..
make
@@ -138,3 +136,74 @@ Usage:
# infer video
./yolov8-seg yolov8s-seg.engine data/test.mp4 # the video path
```
# Export Original ONNX model by ultralytics
You can leave this repo and use the original `ultralytics` repo for onnx export.
### 1. Python script
Usage:
```python
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8s-seg.pt") # load a pretrained model (recommended for training)
success = model.export(format="engine", device=0) # export the model to engine format
assert success
```
After executing the above script, you will get an engine named `yolov8s-seg.engine`.
### 2. CLI tools
Usage:
```shell
yolo export model=yolov8s-seg.pt format=engine device=0
```
After executing the above command, you will get an engine named `yolov8s-seg.engine` too.
## Inference with c++
You can infer with c++ in [`csrc/segment/normal`](../csrc/segment/normal).
### Build:
Please set you own librarys in [`CMakeLists.txt`](../csrc/segment/normal/CMakeLists.txt) and modify `CLASS_NAMES` and `COLORS` in [`main.cpp`](../csrc/segment/normal/main.cpp).
Besides, you can modify the postprocess parameters such as `num_labels`, `score_thres`, `iou_thres` and `topk` in [`main.cpp`](../csrc/segment/normal/main.cpp).
```c++
int topk = 100;
int seg_h = 160; // yolov8 model proto height
int seg_w = 160; // yolov8 model proto width
int seg_channels = 32; // yolov8 model proto channels
float score_thres = 0.25f;
float iou_thres = 0.65f;
```
And build:
``` shell
export root=${PWD}
cd csrc/segment/normal
mkdir build
cd build
cmake ..
make
mv yolov8-seg ${root}
cd ${root}
```
Usage:
``` shell
# infer image
./yolov8-seg yolov8s-seg.engine data/bus.jpg
# infer images
./yolov8-seg yolov8s-seg.engine data
# infer video
./yolov8-seg yolov8s-seg.engine data/test.mp4 # the video path
```

@@ -0,0 +1,254 @@
import argparse
import os
import random
from pathlib import Path
from typing import List, Tuple, Union
import cv2
import numpy as np
from numpy import ndarray
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
random.seed(0)
SUFFIXS = ('.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff',
'.webp', '.pfm')
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush')
COLORS = {
cls: [random.randint(0, 255) for _ in range(3)]
for i, cls in enumerate(CLASSES)
}
# the same as yolov8
MASK_COLORS = np.array([(255, 56, 56), (255, 157, 151), (255, 112, 31),
(255, 178, 29), (207, 210, 49), (72, 249, 10),
(146, 204, 23), (61, 219, 134), (26, 147, 52),
(0, 212, 187), (44, 153, 168), (0, 194, 255),
(52, 69, 147), (100, 115, 255), (0, 24, 236),
(132, 56, 255), (82, 0, 133), (203, 56, 255),
(255, 149, 200), (255, 55, 199)],
dtype=np.float32) / 255.
ALPHA = 0.5
def letterbox(
im: ndarray,
new_shape: Union[Tuple, List] = (640, 640),
color: Union[Tuple, List] = (114, 114, 114)
) -> Tuple[ndarray, float, Tuple[float, float]]:
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[
1] # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im,
top,
bottom,
left,
right,
cv2.BORDER_CONSTANT,
value=color) # add border
return im, r, (dw, dh)
def blob(im: ndarray) -> Tuple[ndarray, ndarray]:
seg = im.astype(np.float32) / 255
im = im.transpose([2, 0, 1])
im = im[np.newaxis, ...]
im = np.ascontiguousarray(im).astype(np.float32) / 255
return im, seg
def main(args):
if args.method == 'cudart':
from models.cudart_api import TRTEngine
elif args.method == 'pycuda':
from models.pycuda_api import TRTEngine
else:
raise NotImplementedError
Engine = TRTEngine(args.engine)
H, W = Engine.inp_info[0].shape[-2:]
images_path = Path(args.imgs)
assert images_path.exists()
save_path = Path(args.out_dir)
if images_path.is_dir():
images = [
i.absolute() for i in images_path.iterdir() if i.suffix in SUFFIXS
]
else:
assert images_path.suffix in SUFFIXS
images = [images_path.absolute()]
if not args.show and not save_path.exists():
save_path.mkdir(parents=True, exist_ok=True)
for image in images:
save_image = save_path / image.name
bgr = cv2.imread(str(image))
draw = bgr.copy()
bgr, ratio, dwdh = letterbox(bgr, (W, H))
dw, dh = int(dwdh[0]), int(dwdh[1])
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor, seg_img = blob(rgb)
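# duplicate (dw, dh) to (dw, dh, dw, dh) so one subtraction shifts both box corners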
dwdh = np.array(dwdh * 2, dtype=np.float32)
tensor = np.ascontiguousarray(tensor)
data = Engine(tensor)
if args.seg:
seg_img = seg_img[dh:H - dh, dw:W - dw, [2, 1, 0]]
bboxes, scores, labels, masks = seg_postprocess(
data, bgr.shape[:2], args.conf_thres, args.iou_thres)
mask, mask_color = [m[:, dh:H - dh, dw:W - dw, :] for m in masks]
inv_alph_masks = (1 - mask * 0.5).cumprod(0)
mcs = (mask_color * inv_alph_masks).sum(0) * 2
seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255
draw = cv2.resize(seg_img.astype(np.uint8), draw.shape[:2][::-1])
else:
bboxes, scores, labels = det_postprocess(data)
bboxes -= dwdh
bboxes /= ratio
for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.round().astype(np.int32).tolist()
cls_id = int(label)
cls = CLASSES[cls_id]
color = COLORS[cls]
cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
cv2.putText(draw,
f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255],
thickness=2)
if args.show:
cv2.imshow('result', draw)
cv2.waitKey(0)
else:
cv2.imwrite(str(save_image), draw)
def crop_mask(masks: ndarray, bboxes: ndarray) -> ndarray:
n, h, w = masks.shape
x1, y1, x2, y2 = np.split(bboxes[:, :, None], [1, 2, 3],
1) # x1 shape(1,1,n)
r = np.arange(w, dtype=x1.dtype)[None, None, :] # rows shape(1,w,1)
c = np.arange(h, dtype=x1.dtype)[None, :, None] # cols shape(h,1,1)
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
def seg_postprocess(
data: Tuple[ndarray],
shape: Union[Tuple, List],
conf_thres: float = 0.25,
iou_thres: float = 0.65) -> Tuple[ndarray, ndarray, ndarray, List]:
assert len(data) == 2
h, w = shape[0] // 4, shape[1] // 4 # 4x downsampling
outputs, proto = (i[0] for i in data)
bboxes, scores, labels, maskconf = np.split(outputs, [4, 5, 6], 1)
scores, labels = scores.squeeze(), labels.squeeze()
select = scores > conf_thres
bboxes, scores, labels, maskconf = bboxes[select], scores[select], labels[
select], maskconf[select]
cvbboxes = np.concatenate([bboxes[:, :2], bboxes[:, 2:] - bboxes[:, :2]],
1)
labels = labels.astype(np.int32)
v0, v1 = map(int, (cv2.__version__).split('.')[:2])
assert v0 == 4, 'OpenCV version is wrong'
if v1 > 6:
idx = cv2.dnn.NMSBoxesBatched(cvbboxes, scores, labels, conf_thres,
iou_thres)
else:
idx = cv2.dnn.NMSBoxes(cvbboxes, scores, conf_thres, iou_thres)
bboxes, scores, labels, maskconf = bboxes[idx], scores[idx], labels[
idx], maskconf[idx]
masks = (maskconf @ proto).reshape(-1, h, w)
masks = crop_mask(masks, bboxes / 4.)
masks = cv2.resize(masks.transpose([1, 2, 0]),
shape,
interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
masks = np.ascontiguousarray((masks > 0.5)[..., None])
cidx = labels % len(MASK_COLORS)
mask_color = MASK_COLORS[cidx].reshape(-1, 1, 1, 3) * ALPHA
out = [masks, masks @ mask_color]
return bboxes, scores, labels, out
def det_postprocess(data: Tuple[ndarray, ndarray, ndarray]):
assert len(data) == 4
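# end-to-end detect engine outputs: num_dets, bboxes, scores, labels (batch dim stripped)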
num_dets, bboxes, scores, labels = (i[0] for i in data)
nums = num_dets.item()
bboxes = bboxes[:nums]
scores = scores[:nums]
labels = labels[:nums]
return bboxes, scores, labels
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--imgs', type=str, help='Images file')
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--seg', action='store_true', help='Seg inference')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--conf-thres',
type=float,
default=0.25,
help='Confidence threshold')
parser.add_argument('--iou-thres',
type=float,
default=0.65,
help='IoU threshold')
parser.add_argument('--method',
type=str,
default='cudart',
help='cudart or pycuda backend')
parser.add_argument('--profile',
action='store_true',
help='Profile TensorRT engine')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
main(args)

@@ -3,7 +3,7 @@ import argparse
import os
import random
from pathlib import Path
from typing import Any, List, Tuple, Union
from typing import List, Tuple, Union
import cv2
import numpy as np
@@ -145,7 +145,7 @@ def main(args):
draw = cv2.resize(seg_img.cpu().numpy().astype(np.uint8),
draw.shape[:2][::-1])
else:
bboxes, scores, labels, masks = det_postprocess(data)
bboxes, scores, labels = det_postprocess(data)
bboxes -= dwdh
bboxes /= ratio
@@ -209,14 +209,14 @@ def seg_postprocess(
return bboxes, scores, labels, out
def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Any], **kwargs):
def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]):
assert len(data) == 4
num_dets, bboxes, scores, labels = (i[0] for i in data)
nums = num_dets.item()
bboxes = bboxes[:nums]
scores = scores[:nums]
labels = labels[:nums]
return bboxes, scores, labels, None
return bboxes, scores, labels
def parse_args():

@@ -0,0 +1,160 @@
import os
import warnings
from collections import namedtuple
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import tensorrt as trt
from cuda import cudart
from numpy import ndarray
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
class TRTEngine:
def __init__(self, weight: Union[str, Path]) -> None:
self.weight = Path(weight) if isinstance(weight, str) else weight
status, self.stream = cudart.cudaStreamCreate()
assert status.value == 0
self.__init_engine()
self.__init_bindings()
self.__warm_up()
def __init_engine(self) -> None:
logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, namespace='')
with trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
context = model.create_execution_context()
names = [model.get_binding_name(i) for i in range(model.num_bindings)]
self.num_bindings = model.num_bindings
self.bindings: List[int] = [0] * self.num_bindings
num_inputs, num_outputs = 0, 0
for i in range(model.num_bindings):
if model.binding_is_input(i):
num_inputs += 1
else:
num_outputs += 1
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.model = model
self.context = context
self.input_names = names[:num_inputs]
self.output_names = names[num_inputs:]
def __init_bindings(self) -> None:
dynamic = False
Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape', 'cpu', 'gpu'))
inp_info = []
out_info = []
out_ptrs = []
for i, name in enumerate(self.input_names):
assert self.model.get_binding_name(i) == name
dtype = trt.nptype(self.model.get_binding_dtype(i))
shape = tuple(self.model.get_binding_shape(i))
if -1 in shape:
dynamic |= True
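# any -1 dimension marks a dynamic engine; device buffers are then allocated per call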
if not dynamic:
cpu = np.empty(shape, dtype)
status, gpu = cudart.cudaMallocAsync(cpu.nbytes, self.stream)
assert status.value == 0
cudart.cudaMemcpyAsync(
gpu, cpu.ctypes.data, cpu.nbytes,
cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
else:
cpu, gpu = np.empty(0), 0
inp_info.append(Tensor(name, dtype, shape, cpu, gpu))
for i, name in enumerate(self.output_names):
i += self.num_inputs
assert self.model.get_binding_name(i) == name
dtype = trt.nptype(self.model.get_binding_dtype(i))
shape = tuple(self.model.get_binding_shape(i))
if not dynamic:
cpu = np.empty(shape, dtype=dtype)
status, gpu = cudart.cudaMallocAsync(cpu.nbytes, self.stream)
assert status.value == 0
cudart.cudaMemcpyAsync(
gpu, cpu.ctypes.data, cpu.nbytes,
cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
out_ptrs.append(gpu)
else:
cpu, gpu = np.empty(0), 0
out_info.append(Tensor(name, dtype, shape, cpu, gpu))
self.is_dynamic = dynamic
self.inp_info = inp_info
self.out_info = out_info
self.out_ptrs = out_ptrs
def __warm_up(self) -> None:
if self.is_dynamic:
print('Your engine has dynamic axes, please warm it up yourself!')
return
for _ in range(10):
inputs = []
for i in self.inp_info:
inputs.append(i.cpu)
self.__call__(*inputs)
def set_profiler(self, profiler: Optional[trt.IProfiler]) -> None:
self.context.profiler = profiler \
if profiler is not None else trt.Profiler()
def __call__(self, *inputs) -> Union[Tuple, ndarray]:
assert len(inputs) == self.num_inputs
contiguous_inputs: List[ndarray] = [
np.ascontiguousarray(i) for i in inputs
]
for i in range(self.num_inputs):
if self.is_dynamic:
self.context.set_binding_shape(
i, tuple(contiguous_inputs[i].shape))
# namedtuple fields are read-only, so rebuild the entry with the new pointer
status, gpu = cudart.cudaMallocAsync(
contiguous_inputs[i].nbytes, self.stream)
assert status.value == 0
self.inp_info[i] = self.inp_info[i]._replace(gpu=gpu)
cudart.cudaMemcpyAsync(
self.inp_info[i].gpu, contiguous_inputs[i].ctypes.data,
contiguous_inputs[i].nbytes,
cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
self.bindings[i] = self.inp_info[i].gpu
output_gpu_ptrs: List[int] = []
outputs: List[ndarray] = []
for i in range(self.num_outputs):
j = i + self.num_inputs
if self.is_dynamic:
shape = tuple(self.context.get_binding_shape(j))
dtype = self.out_info[i].dtype
cpu = np.empty(shape, dtype=dtype)
status, gpu = cudart.cudaMallocAsync(cpu.nbytes, self.stream)
assert status.value == 0
cudart.cudaMemcpyAsync(
gpu, cpu.ctypes.data, cpu.nbytes,
cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
else:
cpu = self.out_info[i].cpu
gpu = self.out_info[i].gpu
outputs.append(cpu)
output_gpu_ptrs.append(gpu)
self.bindings[j] = gpu
self.context.execute_async_v2(self.bindings, self.stream)
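# block until inference finishes before queueing the device-to-host copies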
cudart.cudaStreamSynchronize(self.stream)
for i, o in enumerate(output_gpu_ptrs):
cudart.cudaMemcpyAsync(
outputs[i].ctypes.data, o, outputs[i].nbytes,
cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, self.stream)
return tuple(outputs) if len(outputs) > 1 else outputs[0]

@@ -0,0 +1,147 @@
import os
import warnings
from collections import namedtuple
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import pycuda.autoinit # noqa F401
import pycuda.driver as cuda
import tensorrt as trt
from numpy import ndarray
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
class TRTEngine:
def __init__(self, weight: Union[str, Path]) -> None:
self.weight = Path(weight) if isinstance(weight, str) else weight
self.stream = cuda.Stream(0)
self.__init_engine()
self.__init_bindings()
self.__warm_up()
def __init_engine(self) -> None:
logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, namespace='')
with trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
context = model.create_execution_context()
names = [model.get_binding_name(i) for i in range(model.num_bindings)]
self.num_bindings = model.num_bindings
self.bindings: List[int] = [0] * self.num_bindings
num_inputs, num_outputs = 0, 0
for i in range(model.num_bindings):
if model.binding_is_input(i):
num_inputs += 1
else:
num_outputs += 1
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.model = model
self.context = context
self.input_names = names[:num_inputs]
self.output_names = names[num_inputs:]
def __init_bindings(self) -> None:
dynamic = False
Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape', 'cpu', 'gpu'))
inp_info = []
out_info = []
out_ptrs = []
for i, name in enumerate(self.input_names):
assert self.model.get_binding_name(i) == name
dtype = trt.nptype(self.model.get_binding_dtype(i))
shape = tuple(self.model.get_binding_shape(i))
if -1 in shape:
dynamic |= True
if not dynamic:
cpu = np.empty(shape, dtype)
gpu = cuda.mem_alloc(cpu.nbytes)
cuda.memcpy_htod_async(gpu, cpu, self.stream)
else:
cpu, gpu = np.empty(0), 0
inp_info.append(Tensor(name, dtype, shape, cpu, gpu))
for i, name in enumerate(self.output_names):
i += self.num_inputs
assert self.model.get_binding_name(i) == name
dtype = trt.nptype(self.model.get_binding_dtype(i))
shape = tuple(self.model.get_binding_shape(i))
if not dynamic:
cpu = np.empty(shape, dtype=dtype)
gpu = cuda.mem_alloc(cpu.nbytes)
cuda.memcpy_htod_async(gpu, cpu, self.stream)
out_ptrs.append(gpu)
else:
cpu, gpu = np.empty(0), 0
out_info.append(Tensor(name, dtype, shape, cpu, gpu))
self.is_dynamic = dynamic
self.inp_info = inp_info
self.out_info = out_info
self.out_ptrs = out_ptrs
def __warm_up(self) -> None:
if self.is_dynamic:
print('Your engine has dynamic axes, please warm it up yourself!')
return
for _ in range(10):
inputs = []
for i in self.inp_info:
inputs.append(i.cpu)
self.__call__(*inputs)
def set_profiler(self, profiler: Optional[trt.IProfiler]) -> None:
self.context.profiler = profiler \
if profiler is not None else trt.Profiler()
def __call__(self, *inputs) -> Union[Tuple, ndarray]:
assert len(inputs) == self.num_inputs
contiguous_inputs: List[ndarray] = [
np.ascontiguousarray(i) for i in inputs
]
for i in range(self.num_inputs):
if self.is_dynamic:
self.context.set_binding_shape(
i, tuple(contiguous_inputs[i].shape))
# namedtuple fields are read-only, so rebuild the entry with the new pointer
self.inp_info[i] = self.inp_info[i]._replace(
gpu=cuda.mem_alloc(contiguous_inputs[i].nbytes))
cuda.memcpy_htod_async(self.inp_info[i].gpu, contiguous_inputs[i],
self.stream)
self.bindings[i] = int(self.inp_info[i].gpu)
output_gpu_ptrs: List[int] = []
outputs: List[ndarray] = []
for i in range(self.num_outputs):
j = i + self.num_inputs
if self.is_dynamic:
shape = tuple(self.context.get_binding_shape(j))
dtype = self.out_info[i].dtype
cpu = np.empty(shape, dtype=dtype)
gpu = cuda.mem_alloc(cpu.nbytes)  # size by the output buffer, not the input
cuda.memcpy_htod_async(gpu, cpu, self.stream)
else:
cpu = self.out_info[i].cpu
gpu = self.out_info[i].gpu
outputs.append(cpu)
output_gpu_ptrs.append(gpu)
self.bindings[j] = int(gpu)
self.context.execute_async_v2(self.bindings, self.stream.handle)
self.stream.synchronize()
for i, o in enumerate(output_gpu_ptrs):
cuda.memcpy_dtoh_async(outputs[i], o, self.stream)
return tuple(outputs) if len(outputs) > 1 else outputs[0]

@@ -6,3 +6,5 @@ torch
torchvision
ultralytics
# tensorrt
# cuda-python
# pycuda
