From ca8e47515d506014df4e8dc264b59489dcbbe5a0 Mon Sep 17 00:00:00 2001
From: triple-Mu
Date: Thu, 9 Feb 2023 01:10:04 +0800
Subject: [PATCH] version 0.2.0

---
 README.md                                    |  40 +-
 csrc/detect/normal/include/yolov8.hpp        |   1 -
 csrc/segment/{ => normal}/CMakeLists.txt     |   0
 csrc/segment/{ => normal}/include/common.hpp |   8 +-
 csrc/segment/normal/include/yolov8-seg.hpp   | 573 ++++++++++++++++++
 csrc/segment/normal/main.cpp                 | 178 ++++++
 csrc/segment/simple/CMakeLists.txt           |  60 ++
 csrc/segment/simple/include/common.hpp       | 157 +++++
 .../{ => simple}/include/yolov8-seg.hpp      |   6 +-
 csrc/segment/{ => simple}/main.cpp           |   0
 docs/Normal.md                               |  10 +-
 docs/Segment.md                              |  85 ++-
 infer-no-torch.py                            | 254 ++++++++
 infer.py                                     |   8 +-
 models/cudart_api.py                         | 160 +++++
 models/pycuda_api.py                         | 147 +++++
 requirements.txt                             |   2 +
 17 files changed, 1664 insertions(+), 25 deletions(-)
 rename csrc/segment/{ => normal}/CMakeLists.txt (100%)
 rename csrc/segment/{ => normal}/include/common.hpp (96%)
 create mode 100644 csrc/segment/normal/include/yolov8-seg.hpp
 create mode 100644 csrc/segment/normal/main.cpp
 create mode 100644 csrc/segment/simple/CMakeLists.txt
 create mode 100644 csrc/segment/simple/include/common.hpp
 rename csrc/segment/{ => simple}/include/yolov8-seg.hpp (98%)
 rename csrc/segment/{ => simple}/main.cpp (100%)
 create mode 100644 infer-no-torch.py
 create mode 100644 models/cudart_api.py
 create mode 100644 models/pycuda_api.py

diff --git a/README.md b/README.md
index aec35a8..c93761d 100644
--- a/README.md
+++ b/README.md
@@ -166,6 +166,7 @@ python3 infer.py \
 - `--engine` : The Engine you export.
 - `--imgs` : The images path you want to detect.
 - `--show` : Whether to show detection results.
+- `--seg` : Whether to infer with a segmentation model.
 - `--out-dir` : Where to save the result images. It does not work when the `--show` flag is set.
 - `--device` : The CUDA device you use.
 - `--profile` : Profile the TensorRT engine.
@@ -207,7 +208,6 @@ Please see more information in [`Segment.md`](docs/Segment.md)
 
 See more in [`README.md`](csrc/deepstream/README.md)
 
-
 # Profile your engine
 
 If you want to profile the TensorRT engine:
@@ -217,3 +217,41 @@ Usage:
 ``` shell
 python3 infer.py --engine yolov8s.engine --profile
 ```
+
+# Refuse To Use PyTorch for model inference !!!
+
+If you want to break away from PyTorch and run TensorRT inference on its own,
+you can get more information in [`infer-no-torch.py`](infer-no-torch.py).
+Its usage is the same as the PyTorch version, but its performance is much worse.
+
+You can use `cuda-python` or `pycuda` for inference.
+Install either one with the following command:
+
+```shell
+pip install cuda-python
+# or
+pip install pycuda
+```
+
+Usage:
+
+#### Detection
+
+``` shell
+python3 infer-no-torch.py \
+--engine yolov8s.engine \
+--imgs data \
+--show \
+--out-dir outputs \
+--method cudart
+```
+
+#### Description of all arguments
+
+- `--engine` : The Engine you export.
+- `--imgs` : The images path you want to detect.
+- `--show` : Whether to show detection results.
+- `--seg` : Whether to infer with a segmentation model.
+- `--out-dir` : Where to save the result images. It does not work when the `--show` flag is set.
+- `--method` : Choose `cudart` or `pycuda`, default is `cudart`.
+- `--profile` : Profile the TensorRT engine.
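For orientation, here is a minimal sketch of the torch-free detection flow described above. It is not part of the patch: it uses only names the patch adds (`models/cudart_api.TRTEngine`), assumes a static-shape end2end detection engine, and replaces the letterbox step with a plain resize for brevity.

```python
import cv2
import numpy as np

from models.cudart_api import TRTEngine  # wrapper added by this patch

engine = TRTEngine('yolov8s.engine')
H, W = engine.inp_info[0].shape[-2:]     # static input height/width

bgr = cv2.imread('data/bus.jpg')
rgb = cv2.cvtColor(cv2.resize(bgr, (W, H)), cv2.COLOR_BGR2RGB)
tensor = np.ascontiguousarray(
    rgb.transpose(2, 0, 1)[None].astype(np.float32) / 255)

# end2end detect engines expose four outputs; strip the batch
# dimension the same way det_postprocess() does
num_dets, boxes, scores, labels = (o[0] for o in engine(tensor))
```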
diff --git a/csrc/detect/normal/include/yolov8.hpp b/csrc/detect/normal/include/yolov8.hpp
index f676b8a..b6f57dd 100644
--- a/csrc/detect/normal/include/yolov8.hpp
+++ b/csrc/detect/normal/include/yolov8.hpp
@@ -372,7 +372,6 @@ void YOLOv8::postprocess(
         float score = *max_s_ptr;
         if (score > score_thres)
         {
-            std::cout << score << std::endl;
             float x = *bboxes_ptr++ - dw;
             float y = *bboxes_ptr++ - dh;
             float w = *bboxes_ptr++;
diff --git a/csrc/segment/CMakeLists.txt b/csrc/segment/normal/CMakeLists.txt
similarity index 100%
rename from csrc/segment/CMakeLists.txt
rename to csrc/segment/normal/CMakeLists.txt
diff --git a/csrc/segment/include/common.hpp b/csrc/segment/normal/include/common.hpp
similarity index 96%
rename from csrc/segment/include/common.hpp
rename to csrc/segment/normal/include/common.hpp
index dc83775..f9a40c0 100644
--- a/csrc/segment/include/common.hpp
+++ b/csrc/segment/normal/include/common.hpp
@@ -1,9 +1,9 @@
 //
-// Created by ubuntu on 1/24/23.
+// Created by ubuntu on 2/8/23.
 //
-#ifndef SEGMENT_COMMON_HPP
-#define SEGMENT_COMMON_HPP
+#ifndef SEGMENT_NORMAL_COMMON_HPP
+#define SEGMENT_NORMAL_COMMON_HPP
 #include "opencv2/opencv.hpp"
 #include <sys/stat.h>
 #include <unistd.h>
@@ -154,4 +154,4 @@ namespace seg
         float width = 0;
     };
 }
-#endif //SEGMENT_COMMON_HPP
+#endif //SEGMENT_NORMAL_COMMON_HPP
diff --git a/csrc/segment/normal/include/yolov8-seg.hpp b/csrc/segment/normal/include/yolov8-seg.hpp
new file mode 100644
index 0000000..fad53e6
--- /dev/null
+++ b/csrc/segment/normal/include/yolov8-seg.hpp
@@ -0,0 +1,573 @@
+//
+// Created by ubuntu on 2/8/23.
+//
+#ifndef SEGMENT_NORMAL_YOLOV8_SEG_HPP
+#define SEGMENT_NORMAL_YOLOV8_SEG_HPP
+#include <fstream>
+#include "common.hpp"
+#include "NvInferPlugin.h"
+
+using namespace seg;
+
+class YOLOv8_seg
+{
+public:
+    explicit YOLOv8_seg(const std::string& engine_file_path);
+    ~YOLOv8_seg();
+
+    void make_pipe(bool warmup = true);
+    void copy_from_Mat(const cv::Mat& image);
+    void copy_from_Mat(const cv::Mat& image, cv::Size& size);
+    void letterbox(
+        const cv::Mat& image,
+        cv::Mat& out,
+        cv::Size& size
+    );
+    void infer();
+    void postprocess(
+        std::vector<Object>& objs,
+        float score_thres = 0.25f,
+        float iou_thres = 0.65f,
+        int topk = 100,
+        int seg_channels = 32,
+        int seg_h = 160,
+        int seg_w = 160
+    );
+    static void draw_objects(
+        const cv::Mat& image,
+        cv::Mat& res,
+        const std::vector<Object>& objs,
+        const std::vector<std::string>& CLASS_NAMES,
+        const std::vector<std::vector<unsigned int>>& COLORS,
+        const std::vector<std::vector<unsigned int>>& MASK_COLORS
+    );
+    int num_bindings;
+    int num_inputs = 0;
+    int num_outputs = 0;
+    std::vector<Binding> input_bindings;
+    std::vector<Binding> output_bindings;
+    std::vector<void*> host_ptrs;
+    std::vector<void*> device_ptrs;
+
+    PreParam pparam;
+private:
+    nvinfer1::ICudaEngine* engine = nullptr;
+    nvinfer1::IRuntime* runtime = nullptr;
+    nvinfer1::IExecutionContext* context = nullptr;
+    cudaStream_t stream = nullptr;
+    Logger gLogger{ nvinfer1::ILogger::Severity::kERROR };
+
+};
+
+YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path)
+{
+    std::ifstream file(engine_file_path, std::ios::binary);
+    assert(file.good());
+    file.seekg(0, std::ios::end);
+    auto size = file.tellg();
+    file.seekg(0, std::ios::beg);
+    char* trtModelStream = new char[size];
+    assert(trtModelStream);
+    file.read(trtModelStream, size);
+    file.close();
+    initLibNvInferPlugins(&this->gLogger, "");
+    this->runtime = nvinfer1::createInferRuntime(this->gLogger);
+    assert(this->runtime != nullptr);
+
+    this->engine = this->runtime->deserializeCudaEngine(trtModelStream, size);
+    assert(this->engine != nullptr);
+
+    this->context = this->engine->createExecutionContext();
+
+    assert(this->context != nullptr);
+    cudaStreamCreate(&this->stream);
+    this->num_bindings = this->engine->getNbBindings();
+
+    for (int i = 0; i < this->num_bindings; ++i)
+    {
+        Binding binding;
+        nvinfer1::Dims dims;
+        nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
+        std::string name = this->engine->getBindingName(i);
+        binding.name = name;
+        binding.dsize = type_to_size(dtype);
+
+        bool IsInput = engine->bindingIsInput(i);
+        if (IsInput)
+        {
+            this->num_inputs += 1;
+            dims = this->engine->getProfileDimensions(
+                i,
+                0,
+                nvinfer1::OptProfileSelector::kMAX);
+            binding.size = get_size_by_dims(dims);
+            binding.dims = dims;
+            this->input_bindings.push_back(binding);
+            // set max opt shape
+            this->context->setBindingDimensions(i, dims);
+
+        }
+        else
+        {
+            dims = this->context->getBindingDimensions(i);
+            binding.size = get_size_by_dims(dims);
+            binding.dims = dims;
+            this->output_bindings.push_back(binding);
+            this->num_outputs += 1;
+        }
+    }
+
+}
+
+YOLOv8_seg::~YOLOv8_seg()
+{
+    this->context->destroy();
+    this->engine->destroy();
+    this->runtime->destroy();
+    cudaStreamDestroy(this->stream);
+    for (auto& ptr : this->device_ptrs)
+    {
+        CHECK(cudaFree(ptr));
+    }
+
+    for (auto& ptr : this->host_ptrs)
+    {
+        CHECK(cudaFreeHost(ptr));
+    }
+}
+
+void YOLOv8_seg::make_pipe(bool warmup)
+{
+
+    for (auto& bindings : this->input_bindings)
+    {
+        void* d_ptr;
+        CHECK(cudaMallocAsync(
+            &d_ptr,
+            bindings.size * bindings.dsize,
+            this->stream)
+        );
+        this->device_ptrs.push_back(d_ptr);
+    }
+
+    for (auto& bindings : this->output_bindings)
+    {
+        void* d_ptr, * h_ptr;
+        size_t size = bindings.size * bindings.dsize;
+        CHECK(cudaMallocAsync(
+            &d_ptr,
+            size,
+            this->stream)
+        );
+        CHECK(cudaHostAlloc(
+            &h_ptr,
+            size,
+            0)
+        );
+        this->device_ptrs.push_back(d_ptr);
+        this->host_ptrs.push_back(h_ptr);
+    }
+
+    if (warmup)
+    {
+        for (int i = 0; i < 10; i++)
+        {
+            for (auto& bindings : this->input_bindings)
+            {
+                size_t size = bindings.size * bindings.dsize;
+                void* h_ptr = malloc(size);
+                memset(h_ptr, 0, size);
+                CHECK(cudaMemcpyAsync(
+                    this->device_ptrs[0],
+                    h_ptr,
+                    size,
+                    cudaMemcpyHostToDevice,
+                    this->stream)
+                );
+                free(h_ptr);
+            }
+            this->infer();
+        }
+        printf("model warmup 10 times\n");
+
+    }
+}
+
+void YOLOv8_seg::letterbox(
+    const cv::Mat& image,
+    cv::Mat& out,
+    cv::Size& size
+)
+{
+    const float inp_h = size.height;
+    const float inp_w = size.width;
+    float height = image.rows;
+    float width = image.cols;
+
+    float r = std::min(inp_h / height, inp_w / width);
+    int padw = std::round(width * r);
+    int padh = std::round(height * r);
+
+    cv::Mat tmp;
+    if ((int)width != padw || (int)height != padh)
+    {
+        cv::resize(
+            image,
+            tmp,
+            cv::Size(padw, padh)
+        );
+    }
+    else
+    {
+        tmp = image.clone();
+    }
+
+    float dw = inp_w - padw;
+    float dh = inp_h - padh;
+
+    dw /= 2.0f;
+    dh /= 2.0f;
+    int top = int(std::round(dh - 0.1f));
+    int bottom = int(std::round(dh + 0.1f));
+    int left = int(std::round(dw - 0.1f));
+    int right = int(std::round(dw + 0.1f));
+
+    cv::copyMakeBorder(
+        tmp,
+        tmp,
+        top,
+        bottom,
+        left,
+        right,
+        cv::BORDER_CONSTANT,
+        { 114, 114, 114 }
+    );
+
+    cv::dnn::blobFromImage(tmp,
+        out,
+        1 / 255.f,
+        cv::Size(),
+        cv::Scalar(0, 0, 0),
+        true,
+        false,
+        CV_32F
+    );
+    this->pparam.ratio = 1 / r;
+    this->pparam.dw = dw;
+    this->pparam.dh = dh;
+    this->pparam.height = height;
+    this->pparam.width = width;
+}
+
+void YOLOv8_seg::copy_from_Mat(const cv::Mat& image)
+{
+    cv::Mat nchw;
+    auto& in_binding = this->input_bindings[0];
+    auto width = in_binding.dims.d[3];
+    auto height = in_binding.dims.d[2];
+    cv::Size size{ width, height };
+    this->letterbox(
+        image,
+        nchw,
+        size
+    );
+
+    this->context->setBindingDimensions(
+        0,
+        nvinfer1::Dims
+            {
+                4,
+                { 1, 3, height, width }
+            }
+    );
+
+    CHECK(cudaMemcpyAsync(
+        this->device_ptrs[0],
+        nchw.ptr<float>(),
+        nchw.total() * nchw.elemSize(),
+        cudaMemcpyHostToDevice,
+        this->stream)
+    );
+}
+
+void YOLOv8_seg::copy_from_Mat(const cv::Mat& image, cv::Size& size)
+{
+    cv::Mat nchw;
+    this->letterbox(
+        image,
+        nchw,
+        size
+    );
+    this->context->setBindingDimensions(
+        0,
+        nvinfer1::Dims
+            { 4,
+              { 1, 3, size.height, size.width }
+            }
+    );
+    CHECK(cudaMemcpyAsync(
+        this->device_ptrs[0],
+        nchw.ptr<float>(),
+        nchw.total() * nchw.elemSize(),
+        cudaMemcpyHostToDevice,
+        this->stream)
+    );
+}
+
+void YOLOv8_seg::infer()
+{
+
+    this->context->enqueueV2(
+        this->device_ptrs.data(),
+        this->stream,
+        nullptr
+    );
+    for (int i = 0; i < this->num_outputs; i++)
+    {
+        size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
+        CHECK(cudaMemcpyAsync(this->host_ptrs[i],
+            this->device_ptrs[i + this->num_inputs],
+            osize,
+            cudaMemcpyDeviceToHost,
+            this->stream)
+        );
+
+    }
+    cudaStreamSynchronize(this->stream);
+
+}
+
+void YOLOv8_seg::postprocess(std::vector<Object>& objs,
+    float score_thres,
+    float iou_thres,
+    int topk,
+    int seg_channels,
+    int seg_h,
+    int seg_w
+)
+{
+    objs.clear();
+    auto input_h = this->input_bindings[0].dims.d[2];
+    auto input_w = this->input_bindings[0].dims.d[3];
+    int num_channels, num_anchors, num_classes;
+    bool flag = false;
+    int bid;
+    int bcnt = -1;
+    for (auto& o : this->output_bindings)
+    {
+        bcnt += 1;
+        if (o.dims.nbDims == 3)
+        {
+            num_channels = o.dims.d[1];
+            num_anchors = o.dims.d[2];
+            flag = true;
+            bid = bcnt;
+        }
+    }
+    assert(flag);
+    num_classes = num_channels - seg_channels - 4;
+
+    auto& dw = this->pparam.dw;
+    auto& dh = this->pparam.dh;
+    auto& width = this->pparam.width;
+    auto& height = this->pparam.height;
+    auto& ratio = this->pparam.ratio;
+
+    cv::Mat output = cv::Mat(num_channels, num_anchors, CV_32F,
+        static_cast<float*>(this->host_ptrs[bid]));
+    output = output.t();
+
+    cv::Mat protos = cv::Mat(seg_channels, seg_h * seg_w, CV_32F,
+        static_cast<float*>(this->host_ptrs[1 - bid]));
+
+    std::vector<int> labels;
+    std::vector<float> scores;
+    std::vector<cv::Rect> bboxes;
+    std::vector<cv::Mat> mask_confs;
+    std::vector<int> indices;
+
+    for (int i = 0; i < num_anchors; i++)
+    {
+        auto row_ptr = output.row(i).ptr<float>();
+        auto bboxes_ptr = row_ptr;
+        auto scores_ptr = row_ptr + 4;
+        auto mask_confs_ptr = row_ptr + 4 + num_classes;
+        auto max_s_ptr = std::max_element(scores_ptr, scores_ptr + num_classes);
+        float score = *max_s_ptr;
+        if (score > score_thres)
+        {
+            float x = *bboxes_ptr++ - dw;
+            float y = *bboxes_ptr++ - dh;
+            float w = *bboxes_ptr++;
+            float h = *bboxes_ptr;
+
+            float x0 = clamp((x - 0.5f * w) * ratio, 0.f, width);
+            float y0 = clamp((y - 0.5f * h) * ratio, 0.f, height);
+            float x1 = clamp((x + 0.5f * w) * ratio, 0.f, width);
+            float y1 = clamp((y + 0.5f * h) * ratio, 0.f, height);
+
+            int label = max_s_ptr - scores_ptr;
+            cv::Rect_<float> bbox;
+            bbox.x = x0;
+            bbox.y = y0;
+            bbox.width = x1 - x0;
+            bbox.height = y1 - y0;
+
+            cv::Mat mask_conf = cv::Mat(1, seg_channels, CV_32F, mask_confs_ptr);
+
+            bboxes.push_back(bbox);
+            labels.push_back(label);
+            scores.push_back(score);
+            mask_confs.push_back(mask_conf);
+        }
+    }
+
+#if defined(BATCHED_NMS)
+    cv::dnn::NMSBoxesBatched(
+        bboxes,
+        scores,
+        labels,
+        score_thres,
+        iou_thres,
+        indices
+    );
+#else
+    cv::dnn::NMSBoxes(
+        bboxes,
+        scores,
+        score_thres,
+        iou_thres,
+        indices
+    );
+#endif
+
+    cv::Mat masks;
+    int cnt = 0;
+    for (auto& i : indices)
+    {
+        if (cnt >= topk)
+        {
+            break;
+        }
+        cv::Rect tmp = bboxes[i];
+        Object obj;
+        obj.label = labels[i];
+        obj.rect = tmp;
+        obj.prob = scores[i];
+        masks.push_back(mask_confs[i]);
+        objs.push_back(obj);
+        cnt += 1;
+    }
+
+    cv::Mat matmulRes = (masks * protos).t();
+    cv::Mat maskMat = matmulRes.reshape(indices.size(), { seg_w, seg_h });
+
+    std::vector<cv::Mat> maskChannels;
+    cv::split(maskMat, maskChannels);
+    int scale_dw = dw / input_w * seg_w;
+    int scale_dh = dh / input_h * seg_h;
+
+    cv::Rect roi(
+        scale_dw,
+        scale_dh,
+        seg_w - 2 * scale_dw,
+        seg_h - 2 * scale_dh);
+
+    for (int i = 0; i < indices.size(); i++)
+    {
+        cv::Mat dest, mask;
+        cv::exp(-maskChannels[i], dest);
+        dest = 1.0 / (1.0 + dest);
+        dest = dest(roi);
+        cv::resize(
+            dest,
+            mask,
+            cv::Size((int)width, (int)height),
+            cv::INTER_LINEAR
+        );
+        objs[i].boxMask = mask(objs[i].rect) > 0.5f;
+    }
+
+}
+
+void YOLOv8_seg::draw_objects(const cv::Mat& image,
+    cv::Mat& res,
+    const std::vector<Object>& objs,
+    const std::vector<std::string>& CLASS_NAMES,
+    const std::vector<std::vector<unsigned int>>& COLORS,
+    const std::vector<std::vector<unsigned int>>& MASK_COLORS
+)
+{
+    res = image.clone();
+    cv::Mat mask = image.clone();
+    for (auto& obj : objs)
+    {
+        int idx = obj.label;
+        cv::Scalar color = cv::Scalar(
+            COLORS[idx][0],
+            COLORS[idx][1],
+            COLORS[idx][2]
+        );
+        cv::Scalar mask_color = cv::Scalar(
+            MASK_COLORS[idx % 20][0],
+            MASK_COLORS[idx % 20][1],
+            MASK_COLORS[idx % 20][2]
+        );
+        cv::rectangle(
+            res,
+            obj.rect,
+            color,
+            2
+        );
+
+        char text[256];
+        sprintf(
+            text,
+            "%s %.1f%%",
+            CLASS_NAMES[idx].c_str(),
+            obj.prob * 100
+        );
+        mask(obj.rect).setTo(mask_color, obj.boxMask);
+
+        int baseLine = 0;
+        cv::Size label_size = cv::getTextSize(
+            text,
+            cv::FONT_HERSHEY_SIMPLEX,
+            0.4,
+            1,
+            &baseLine
+        );
+
+        int x = (int)obj.rect.x;
+        int y = (int)obj.rect.y + 1;
+
+        if (y > res.rows)
+            y = res.rows;
+
+        cv::rectangle(
+            res,
+            cv::Rect(x, y, label_size.width, label_size.height + baseLine),
+            { 0, 0, 255 },
+            -1
+        );
+
+        cv::putText(
+            res,
+            text,
+            cv::Point(x, y + label_size.height),
+            cv::FONT_HERSHEY_SIMPLEX,
+            0.4,
+            { 255, 255, 255 },
+            1
+        );
+    }
+    cv::addWeighted(
+        res,
+        0.5,
+        mask,
+        0.8,
+        1,
+        res
+    );
+}
+#endif //SEGMENT_NORMAL_YOLOV8_SEG_HPP
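A note on the mask branch of `postprocess()` above: it is a matrix product between the per-box mask coefficients and the prototype masks, followed by a sigmoid, a crop to the letterboxed region, a resize and a 0.5 threshold. Here is a numpy sketch of the core math, for orientation only (not part of the patch; shapes assume the default 32-channel, 160×160 proto head):

```python
import numpy as np

def decode_masks(mask_confs: np.ndarray,   # (N, 32) per-box coefficients
                 protos: np.ndarray,       # (32, 160*160) prototype masks
                 seg_h: int = 160,
                 seg_w: int = 160) -> np.ndarray:
    masks = mask_confs @ protos            # (N, 160*160)
    masks = 1.0 / (1.0 + np.exp(-masks))   # sigmoid, same as the cv::exp trick above
    return masks.reshape(-1, seg_h, seg_w) > 0.5
```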
diff --git a/csrc/segment/normal/main.cpp b/csrc/segment/normal/main.cpp
new file mode 100644
index 0000000..bc67e57
--- /dev/null
+++ b/csrc/segment/normal/main.cpp
@@ -0,0 +1,178 @@
+//
+// Created by ubuntu on 2/8/23.
+//
+#include "chrono"
+#include "yolov8-seg.hpp"
+#include "opencv2/opencv.hpp"
+
+const std::vector<std::string> CLASS_NAMES = {
+    "person", "bicycle", "car", "motorcycle", "airplane", "bus",
+    "train", "truck", "boat", "traffic light", "fire hydrant",
+    "stop sign", "parking meter", "bench", "bird", "cat",
+    "dog", "horse", "sheep", "cow", "elephant",
+    "bear", "zebra", "giraffe", "backpack", "umbrella",
+    "handbag", "tie", "suitcase", "frisbee", "skis",
+    "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
+    "skateboard", "surfboard", "tennis racket", "bottle", "wine glass",
+    "cup", "fork", "knife", "spoon", "bowl",
+    "banana", "apple", "sandwich", "orange", "broccoli",
+    "carrot", "hot dog", "pizza", "donut", "cake",
+    "chair", "couch", "potted plant", "bed", "dining table",
+    "toilet", "tv", "laptop", "mouse", "remote",
+    "keyboard", "cell phone", "microwave", "oven",
+    "toaster", "sink", "refrigerator", "book", "clock", "vase",
+    "scissors", "teddy bear", "hair drier", "toothbrush" };
+
+const std::vector<std::vector<unsigned int>> COLORS = {
+    { 0, 114, 189 }, { 217, 83, 25 }, { 237, 177, 32 },
+    { 126, 47, 142 }, { 119, 172, 48 }, { 77, 190, 238 },
+    { 162, 20, 47 }, { 76, 76, 76 }, { 153, 153, 153 },
+    { 255, 0, 0 }, { 255, 128, 0 }, { 191, 191, 0 },
+    { 0, 255, 0 }, { 0, 0, 255 }, { 170, 0, 255 },
+    { 85, 85, 0 }, { 85, 170, 0 }, { 85, 255, 0 },
+    { 170, 85, 0 }, { 170, 170, 0 }, { 170, 255, 0 },
+    { 255, 85, 0 }, { 255, 170, 0 }, { 255, 255, 0 },
+    { 0, 85, 128 }, { 0, 170, 128 }, { 0, 255, 128 },
+    { 85, 0, 128 }, { 85, 85, 128 }, { 85, 170, 128 },
+    { 85, 255, 128 }, { 170, 0, 128 }, { 170, 85, 128 },
+    { 170, 170, 128 }, { 170, 255, 128 }, { 255, 0, 128 },
+    { 255, 85, 128 }, { 255, 170, 128 }, { 255, 255, 128 },
+    { 0, 85, 255 }, { 0, 170, 255 }, { 0, 255, 255 },
+    { 85, 0, 255 }, { 85, 85, 255 }, { 85, 170, 255 },
+    { 85, 255, 255 }, { 170, 0, 255 }, { 170, 85, 255 },
+    { 170, 170, 255 }, { 170, 255, 255 }, { 255, 0, 255 },
+    { 255, 85, 255 }, { 255, 170, 255 }, { 85, 0, 0 },
+    { 128, 0, 0 }, { 170, 0, 0 }, { 212, 0, 0 },
+    { 255, 0, 0 }, { 0, 43, 0 }, { 0, 85, 0 },
+    { 0, 128, 0 }, { 0, 170, 0 }, { 0, 212, 0 },
+    { 0, 255, 0 }, { 0, 0, 43 }, { 0, 0, 85 },
+    { 0, 0, 128 }, { 0, 0, 170 }, { 0, 0, 212 },
+    { 0, 0, 255 }, { 0, 0, 0 }, { 36, 36, 36 },
+    { 73, 73, 73 }, { 109, 109, 109 }, { 146, 146, 146 },
+    { 182, 182, 182 }, { 219, 219, 219 }, { 0, 114, 189 },
+    { 80, 183, 189 }, { 128, 128, 0 }
+};
+
+const std::vector<std::vector<unsigned int>> MASK_COLORS = {
+    { 255, 56, 56 }, { 255, 157, 151 }, { 255, 112, 31 },
+    { 255, 178, 29 }, { 207, 210, 49 }, { 72, 249, 10 },
+    { 146, 204, 23 }, { 61, 219, 134 }, { 26, 147, 52 },
+    { 0, 212, 187 }, { 44, 153, 168 }, { 0, 194, 255 },
+    { 52, 69, 147 }, { 100, 115, 255 }, { 0, 24, 236 },
+    { 132, 56, 255 }, { 82, 0, 133 }, { 203, 56, 255 },
+    { 255, 149, 200 }, { 255, 55, 199 }
+};
+
+int main(int argc, char** argv)
+{
+    // cuda:0
+    cudaSetDevice(0);
+
+    const std::string engine_file_path{ argv[1] };
+    const std::string path{ argv[2] };
+
+    std::vector<std::string> imagePathList;
+    bool isVideo{ false };
+
+    assert(argc == 3);
+
+    auto yolov8 = new YOLOv8_seg(engine_file_path);
+    yolov8->make_pipe(true);
+
+    if (IsFile(path))
+    {
+        std::string suffix = path.substr(path.find_last_of('.') + 1);
+        if (
+            suffix == "jpg" ||
+            suffix == "jpeg" ||
+            suffix == "png"
+            )
+        {
+            imagePathList.push_back(path);
+        }
+        else if (
+            suffix == "mp4" ||
+            suffix == "avi" ||
+            suffix == "m4v" ||
+            suffix == "mpeg" ||
+            suffix == "mov" ||
+            suffix == "mkv"
+            )
+        {
+            isVideo = true;
+        }
+        else
+        {
+            printf("suffix %s is wrong !!!\n", suffix.c_str());
+            std::abort();
+        }
+    }
+    else if (IsFolder(path))
+    {
+        cv::glob(path + "/*.jpg", imagePathList);
+    }
+
+    cv::Mat res, image;
+    cv::Size size = cv::Size{ 640, 640 };
+    int topk = 100;
+    int seg_h = 160;
+    int seg_w = 160;
+    int seg_channels = 32;
+    float score_thres = 0.25f;
+    float iou_thres = 0.65f;
+
+    std::vector<Object> objs;
+
+    cv::namedWindow("result", cv::WINDOW_AUTOSIZE);
+
+    if (isVideo)
+    {
+        cv::VideoCapture cap(path);
+
+        if (!cap.isOpened())
+        {
+            printf("can not open %s\n", path.c_str());
+            return -1;
+        }
+        while (cap.read(image))
+        {
+            objs.clear();
+            yolov8->copy_from_Mat(image, size);
+            auto start = std::chrono::system_clock::now();
+            yolov8->infer();
+            auto end = std::chrono::system_clock::now();
+            yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w);
+            yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS);
+            auto tc = (double)
+                std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
+            printf("cost %2.4lf ms\n", tc);
+            cv::imshow("result", res);
+            if (cv::waitKey(10) == 'q')
+            {
+                break;
+            }
+        }
+    }
+    else
+    {
+        for (auto& path : imagePathList)
+        {
+            objs.clear();
+            image = cv::imread(path);
+            yolov8->copy_from_Mat(image, size);
+            auto start = std::chrono::system_clock::now();
+            yolov8->infer();
+            auto end = std::chrono::system_clock::now();
+            yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w);
+            yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS);
+            auto tc = (double)
+                std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
+            printf("cost %2.4lf ms\n", tc);
+            cv::imshow("result", res);
+            cv::waitKey(0);
+        }
+    }
+    cv::destroyAllWindows();
+    delete yolov8;
+    return 0;
+}
diff --git a/csrc/segment/simple/CMakeLists.txt b/csrc/segment/simple/CMakeLists.txt
new file mode 100644
index 0000000..9a02851
--- /dev/null
+++ b/csrc/segment/simple/CMakeLists.txt
@@ -0,0 +1,60 @@
+cmake_minimum_required(VERSION 2.8.12)
+
+set(CMAKE_CUDA_ARCHITECTURES 60 61 62 70 72 75 86)
+set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
+
+project(yolov8-seg LANGUAGES CXX CUDA)
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -O3 -g")
+set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_BUILD_TYPE Release)
+option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
+
+# CUDA
+find_package(CUDA REQUIRED)
+message(STATUS "CUDA Libs: \n${CUDA_LIBRARIES}\n")
+message(STATUS "CUDA Headers: \n${CUDA_INCLUDE_DIRS}\n")
+
+# OpenCV
+find_package(OpenCV REQUIRED)
+message(STATUS "OpenCV Libs: \n${OpenCV_LIBS}\n")
+message(STATUS "OpenCV Libraries: \n${OpenCV_LIBRARIES}\n")
+message(STATUS "OpenCV Headers: \n${OpenCV_INCLUDE_DIRS}\n")
+
+# TensorRT
+set(TensorRT_INCLUDE_DIRS /usr/include/x86_64-linux-gnu)
+set(TensorRT_LIBRARIES /usr/lib/x86_64-linux-gnu)
+
+
+message(STATUS "TensorRT Libs: \n${TensorRT_LIBRARIES}\n")
+message(STATUS "TensorRT Headers: \n${TensorRT_INCLUDE_DIRS}\n")
+
+list(APPEND INCLUDE_DIRS
+        ${CUDA_INCLUDE_DIRS}
+        ${OpenCV_INCLUDE_DIRS}
+        ${TensorRT_INCLUDE_DIRS}
+        ./include
+        )
+
+list(APPEND ALL_LIBS
+        ${CUDA_LIBRARIES}
+        ${OpenCV_LIBRARIES}
+        ${TensorRT_LIBRARIES}
+        )
+
+include_directories(${INCLUDE_DIRS})
+
+add_executable(${PROJECT_NAME}
+        main.cpp
+        include/yolov8-seg.hpp
+        include/common.hpp
+        )
+
+target_link_directories(${PROJECT_NAME} PUBLIC ${ALL_LIBS})
+target_link_libraries(${PROJECT_NAME} PRIVATE nvinfer nvinfer_plugin cudart ${OpenCV_LIBS})
+
+
+if(${OpenCV_VERSION} VERSION_GREATER_EQUAL 4.7.0)
+    message(STATUS "Build with -DBATCHED_NMS")
+    add_definitions(-DBATCHED_NMS)
+endif()
diff --git a/csrc/segment/simple/include/common.hpp b/csrc/segment/simple/include/common.hpp
new file mode 100644
index 0000000..6dc26d9
--- /dev/null
+++ b/csrc/segment/simple/include/common.hpp
@@ -0,0 +1,157 @@
+//
+// Created by ubuntu on 2/9/23.
+//
+
+#ifndef SEGMENT_SIMPLE_COMMON_HPP
+#define SEGMENT_SIMPLE_COMMON_HPP
+#include "opencv2/opencv.hpp"
+#include <sys/stat.h>
+#include <unistd.h>
+#include "NvInfer.h"
+
+#define CHECK(call)                                   \
+do                                                    \
+{                                                     \
+    const cudaError_t error_code = call;              \
+    if (error_code != cudaSuccess)                    \
+    {                                                 \
+        printf("CUDA Error:\n");                      \
+        printf("    File:       %s\n", __FILE__);     \
+        printf("    Line:       %d\n", __LINE__);     \
+        printf("    Error code: %d\n", error_code);   \
+        printf("    Error text: %s\n",                \
+            cudaGetErrorString(error_code));          \
+        exit(1);                                      \
+    }                                                 \
+} while (0)
+
+class Logger : public nvinfer1::ILogger
+{
+public:
+    nvinfer1::ILogger::Severity reportableSeverity;
+
+    explicit Logger(nvinfer1::ILogger::Severity severity = nvinfer1::ILogger::Severity::kINFO) :
+        reportableSeverity(severity)
+    {
+    }
+
+    void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
+    {
+        if (severity > reportableSeverity)
+        {
+            return;
+        }
+        switch (severity)
+        {
+        case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
+            std::cerr << "INTERNAL_ERROR: ";
+            break;
+        case nvinfer1::ILogger::Severity::kERROR:
+            std::cerr << "ERROR: ";
+            break;
+        case nvinfer1::ILogger::Severity::kWARNING:
+            std::cerr << "WARNING: ";
+            break;
+        case nvinfer1::ILogger::Severity::kINFO:
+            std::cerr << "INFO: ";
+            break;
+        default:
+            std::cerr << "VERBOSE: ";
+            break;
+        }
+        std::cerr << msg << std::endl;
+    }
+};
+
+inline int get_size_by_dims(const nvinfer1::Dims& dims)
+{
+    int size = 1;
+    for (int i = 0; i < dims.nbDims; i++)
+    {
+        size *= dims.d[i];
+    }
+    return size;
+}
+
+inline int type_to_size(const nvinfer1::DataType& dataType)
+{
+    switch (dataType)
+    {
+    case nvinfer1::DataType::kFLOAT:
+        return 4;
+    case nvinfer1::DataType::kHALF:
+        return 2;
+    case nvinfer1::DataType::kINT32:
+        return 4;
+    case nvinfer1::DataType::kINT8:
+        return 1;
+    case nvinfer1::DataType::kBOOL:
+        return 1;
+    default:
+        return 4;
+    }
+}
+
+inline static float clamp(float val, float min, float max)
+{
+    return val > min ? (val < max ? val : max) : min;
+}
+
+inline bool IsPathExist(const std::string& path)
+{
+    if (access(path.c_str(), 0) == F_OK)
+    {
+        return true;
+    }
+    return false;
+}
+
+inline bool IsFile(const std::string& path)
+{
+    if (!IsPathExist(path))
+    {
+        printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
+        return false;
+    }
+    struct stat buffer;
+    return (stat(path.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode));
+}
+
+inline bool IsFolder(const std::string& path)
+{
+    if (!IsPathExist(path))
+    {
+        return false;
+    }
+    struct stat buffer;
+    return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
+}
+
+namespace seg
+{
+    struct Binding
+    {
+        size_t size = 1;
+        size_t dsize = 1;
+        nvinfer1::Dims dims;
+        std::string name;
+    };
+
+    struct Object
+    {
+        cv::Rect_<float> rect;
+        int label = 0;
+        float prob = 0.0;
+        cv::Mat boxMask;
+    };
+
+    struct PreParam
+    {
+        float ratio = 1.0f;
+        float dw = 0.0f;
+        float dh = 0.0f;
+        float height = 0;
+        float width = 0;
+    };
+}
+#endif //SEGMENT_SIMPLE_COMMON_HPP
diff --git a/csrc/segment/include/yolov8-seg.hpp b/csrc/segment/simple/include/yolov8-seg.hpp
similarity index 98%
rename from csrc/segment/include/yolov8-seg.hpp
rename to csrc/segment/simple/include/yolov8-seg.hpp
index d9e4ce4..32400b7 100644
--- a/csrc/segment/include/yolov8-seg.hpp
+++ b/csrc/segment/simple/include/yolov8-seg.hpp
@@ -1,8 +1,8 @@
 //
 // Created by ubuntu on 1/24/23.
 //
-#ifndef SEGMENT_YOLOV8_SEG_HPP
-#define SEGMENT_YOLOV8_SEG_HPP
+#ifndef SEGMENT_SIMPLE_YOLOV8_SEG_HPP
+#define SEGMENT_SIMPLE_YOLOV8_SEG_HPP
 #include <fstream>
 #include "common.hpp"
 #include "NvInferPlugin.h"
@@ -542,4 +542,4 @@ void YOLOv8_seg::draw_objects(const cv::Mat& image,
         res
     );
 }
-#endif //SEGMENT_YOLOV8_SEG_HPP
+#endif //SEGMENT_SIMPLE_YOLOV8_SEG_HPP
diff --git a/csrc/segment/main.cpp b/csrc/segment/simple/main.cpp
similarity index 100%
rename from csrc/segment/main.cpp
rename to csrc/segment/simple/main.cpp
diff --git a/docs/Normal.md b/docs/Normal.md
index 2b654ff..1b9eede 100644
--- a/docs/Normal.md
+++ b/docs/Normal.md
@@ -11,20 +11,22 @@ Usage:
 from ultralytics import YOLO
 
 # Load a model
-model = YOLO("yolov8n.pt")  # load a pretrained model (recommended for training)
+model = YOLO("yolov8s.pt")  # load a pretrained model (recommended for training)
 success = model.export(format="engine", device=0)  # export the model to engine format
 assert success
 ```
 
-After executing the above script, you will get an engine named `yolov8n.engine` .
+After executing the above script, you will get an engine named `yolov8s.engine` .
 
 ### 2. CLI tools
 
+Usage:
+
 ```shell
-yolo export model=yolov8n.pt format=engine device=0
+yolo export model=yolov8s.pt format=engine device=0
 ```
 
-After executing the above command, you will get an engine named `yolov8n.engine` too.
+After executing the above command, you will get an engine named `yolov8s.engine` too.
 
 ## Inference with c++
 
diff --git a/docs/Segment.md b/docs/Segment.md
index 8aef1de..7f32415 100644
--- a/docs/Segment.md
+++ b/docs/Segment.md
@@ -1,15 +1,13 @@
 # YOLOv8-seg Model with TensorRT
 
-Instance segmentation models are currently experimental.
-
-Our conversion route is :
+The yolov8-seg model conversion route is:
 YOLOv8 PyTorch model -> ONNX -> TensorRT Engine
 
 ***Notice !!!*** We don't support TensorRT API building !!!
 
-# Export Your Own ONNX model
+# Export Modified ONNX model
 
-You can export your onnx model by `ultralytics` API.
+You can export your ONNX model with the `ultralytics` API; the exported ONNX is then modified by this repo.
 
 ``` shell
 python3 export_seg.py \
@@ -96,11 +94,11 @@ python3 infer.py \
 
 ## Infer with C++
 
-You can infer segment engine with c++ in [`csrc/segment`](../csrc/segment) .
+You can infer segment engine with c++ in [`csrc/segment/simple`](../csrc/segment/simple) .
 
 ### Build:
 
-Please set you own librarys in [`CMakeLists.txt`](../csrc/segment/CMakeLists.txt) and modify you own config in [`main.cpp`](../csrc/segment/main.cpp) such as `CLASS_NAMES`, `COLORS`, `MASK_COLORS` and postprocess parameters .
+Please set your own libraries in [`CMakeLists.txt`](../csrc/segment/simple/CMakeLists.txt) and modify your own config in [`main.cpp`](../csrc/segment/simple/main.cpp), such as `CLASS_NAMES`, `COLORS`, `MASK_COLORS` and the postprocess parameters.
 
 ```c++
 int topk = 100;
@@ -113,7 +111,7 @@ float iou_thres = 0.65f;
 
 ``` shell
 export root=${PWD}
-cd src/segment
+cd csrc/segment/simple
 mkdir build
 cd build
 cmake ..
 make
@@ -138,3 +136,74 @@ Usage:
 # infer video
 ./yolov8-seg yolov8s-seg.engine data/test.mp4 # the video path
 ```
+
+# Export Original ONNX model by ultralytics
+
+You can leave this repo and use the original `ultralytics` repo for onnx export.
+
+### 1. Python script
+
+Usage:
+
+```python
+from ultralytics import YOLO
+
+# Load a model
+model = YOLO("yolov8s-seg.pt")  # load a pretrained model (recommended for training)
+success = model.export(format="engine", device=0)  # export the model to engine format
+assert success
+```
+
+After executing the above script, you will get an engine named `yolov8s-seg.engine` .
+
+### 2. CLI tools
+
+Usage:
+
+```shell
+yolo export model=yolov8s-seg.pt format=engine device=0
+```
+
+After executing the above command, you will get an engine named `yolov8s-seg.engine` too.
+
+## Inference with c++
+
+You can infer with c++ in [`csrc/segment/normal`](../csrc/segment/normal) .
+
+### Build:
+
+Please set your own libraries in [`CMakeLists.txt`](../csrc/segment/normal/CMakeLists.txt) and modify `CLASS_NAMES` and `COLORS` in [`main.cpp`](../csrc/segment/normal/main.cpp).
+
+Besides, you can modify the postprocess parameters such as `num_labels`, `score_thres`, `iou_thres` and `topk` in [`main.cpp`](../csrc/segment/normal/main.cpp).
+
+```c++
+int topk = 100;
+int seg_h = 160; // yolov8 model proto height
+int seg_w = 160; // yolov8 model proto width
+int seg_channels = 32; // yolov8 model proto channels
+float score_thres = 0.25f;
+float iou_thres = 0.65f;
+```
+
+And build:
+
+``` shell
+export root=${PWD}
+cd csrc/segment/normal
+mkdir build
+cd build
+cmake ..
+make
+mv yolov8-seg ${root}
+cd ${root}
+```
+
+Usage:
+
+``` shell
+# infer image
+./yolov8-seg yolov8s-seg.engine data/bus.jpg
+# infer images
+./yolov8-seg yolov8s-seg.engine data
+# infer video
+./yolov8-seg yolov8s-seg.engine data/test.mp4 # the video path
+```
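Before the new script itself, a quick worked example of the letterbox bookkeeping it relies on (not part of the patch; numbers assume a 1080×810 `bus.jpg`-sized input and a 640×640 engine):

```python
# letterbox() picks one uniform scale and pads the short side:
r = min(640 / 1080, 640 / 810)                   # 0.5926
new_unpad = (round(810 * r), round(1080 * r))    # (480, 640) = (w, h)
dw, dh = (640 - 480) / 2, (640 - 640) / 2        # 80.0 and 0.0 padding per side
# main() inverts this on the way out: bboxes -= (dw, dh, dw, dh); bboxes /= r
```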
diff --git a/infer-no-torch.py b/infer-no-torch.py
new file mode 100644
index 0000000..8c0fe77
--- /dev/null
+++ b/infer-no-torch.py
@@ -0,0 +1,254 @@
+import argparse
+import os
+import random
+from pathlib import Path
+from typing import List, Tuple, Union
+
+import cv2
+import numpy as np
+from numpy import ndarray
+
+os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
+
+random.seed(0)
+
+SUFFIXS = ('.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff',
+           '.webp', '.pfm')
+CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+           'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+           'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+           'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+           'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+           'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
+           'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+           'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+           'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+           'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+           'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
+           'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
+           'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+           'scissors', 'teddy bear', 'hair drier', 'toothbrush')
+
+COLORS = {
+    cls: [random.randint(0, 255) for _ in range(3)]
+    for cls in CLASSES
+}
+
+# the same as yolov8
+MASK_COLORS = np.array([(255, 56, 56), (255, 157, 151), (255, 112, 31),
+                        (255, 178, 29), (207, 210, 49), (72, 249, 10),
+                        (146, 204, 23), (61, 219, 134), (26, 147, 52),
+                        (0, 212, 187), (44, 153, 168), (0, 194, 255),
+                        (52, 69, 147), (100, 115, 255), (0, 24, 236),
+                        (132, 56, 255), (82, 0, 133), (203, 56, 255),
+                        (255, 149, 200), (255, 55, 199)],
+                       dtype=np.float32) / 255.
+
+ALPHA = 0.5
+
+
+def letterbox(
+        im: ndarray,
+        new_shape: Union[Tuple, List] = (640, 640),
+        color: Union[Tuple, List] = (114, 114, 114)
+) -> Tuple[ndarray, float, Tuple[float, float]]:
+    # Resize and pad image while meeting stride-multiple constraints
+    shape = im.shape[:2]  # current shape [height, width]
+    if isinstance(new_shape, int):
+        new_shape = (new_shape, new_shape)
+
+    # Scale ratio (new / old)
+    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+
+    # Compute padding
+    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+
+    dw /= 2  # divide padding into 2 sides
+    dh /= 2
+
+    if shape[::-1] != new_unpad:  # resize
+        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
+    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+    im = cv2.copyMakeBorder(im,
+                            top,
+                            bottom,
+                            left,
+                            right,
+                            cv2.BORDER_CONSTANT,
+                            value=color)  # add border
+    return im, r, (dw, dh)
+
+
+def blob(im: ndarray) -> Tuple[ndarray, ndarray]:
+    seg = im.astype(np.float32) / 255
+    im = im.transpose([2, 0, 1])
+    im = im[np.newaxis, ...]
+ im = np.ascontiguousarray(im).astype(np.float32) / 255 + return im, seg + + +def main(args): + if args.method == 'cudart': + from models.cudart_api import TRTEngine + elif args.method == 'pycuda': + from models.pycuda_api import TRTEngine + else: + raise NotImplementedError + + Engine = TRTEngine(args.engine) + H, W = Engine.inp_info[0].shape[-2:] + + images_path = Path(args.imgs) + assert images_path.exists() + save_path = Path(args.out_dir) + + if images_path.is_dir(): + images = [ + i.absolute() for i in images_path.iterdir() if i.suffix in SUFFIXS + ] + else: + assert images_path.suffix in SUFFIXS + images = [images_path.absolute()] + + if not args.show and not save_path.exists(): + save_path.mkdir(parents=True, exist_ok=True) + + for image in images: + save_image = save_path / image.name + bgr = cv2.imread(str(image)) + draw = bgr.copy() + bgr, ratio, dwdh = letterbox(bgr, (W, H)) + dw, dh = int(dwdh[0]), int(dwdh[1]) + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + tensor, seg_img = blob(rgb) + dwdh = np.array(dwdh * 2, dtype=np.float32) + tensor = np.ascontiguousarray(tensor) + data = Engine(tensor) + + if args.seg: + seg_img = seg_img[dh:H - dh, dw:W - dw, [2, 1, 0]] + bboxes, scores, labels, masks = seg_postprocess( + data, bgr.shape[:2], args.conf_thres, args.iou_thres) + mask, mask_color = [m[:, dh:H - dh, dw:W - dw, :] for m in masks] + inv_alph_masks = (1 - mask * 0.5).cumprod(0) + mcs = (mask_color * inv_alph_masks).sum(0) * 2 + seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255 + draw = cv2.resize(seg_img.astype(np.uint8), draw.shape[:2][::-1]) + else: + bboxes, scores, labels = det_postprocess(data) + + bboxes -= dwdh + bboxes /= ratio + + for (bbox, score, label) in zip(bboxes, scores, labels): + bbox = bbox.round().astype(np.int32).tolist() + cls_id = int(label) + cls = CLASSES[cls_id] + color = COLORS[cls] + cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2) + cv2.putText(draw, + f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2), + cv2.FONT_HERSHEY_SIMPLEX, + 0.75, [225, 255, 255], + thickness=2) + if args.show: + cv2.imshow('result', draw) + cv2.waitKey(0) + else: + cv2.imwrite(str(save_image), draw) + + +def crop_mask(masks: ndarray, bboxes: ndarray) -> ndarray: + n, h, w = masks.shape + x1, y1, x2, y2 = np.split(bboxes[:, :, None], [1, 2, 3], + 1) # x1 shape(1,1,n) + r = np.arange(w, dtype=x1.dtype)[None, None, :] # rows shape(1,w,1) + c = np.arange(h, dtype=x1.dtype)[None, :, None] # cols shape(h,1,1) + + return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + +def seg_postprocess( + data: Tuple[ndarray], + shape: Union[Tuple, List], + conf_thres: float = 0.25, + iou_thres: float = 0.65) -> Tuple[ndarray, ndarray, ndarray, List]: + assert len(data) == 2 + h, w = shape[0] // 4, shape[1] // 4 # 4x downsampling + outputs, proto = (i[0] for i in data) + bboxes, scores, labels, maskconf = np.split(outputs, [4, 5, 6], 1) + scores, labels = scores.squeeze(), labels.squeeze() + select = scores > conf_thres + bboxes, scores, labels, maskconf = bboxes[select], scores[select], labels[ + select], maskconf[select] + cvbboxes = np.concatenate([bboxes[:, :2], bboxes[:, 2:] - bboxes[:, :2]], + 1) + labels = labels.astype(np.int32) + v0, v1 = map(int, (cv2.__version__).split('.')[:2]) + assert v0 == 4, 'OpenCV version is wrong' + if v1 > 6: + idx = cv2.dnn.NMSBoxesBatched(cvbboxes, scores, labels, conf_thres, + iou_thres) + else: + idx = cv2.dnn.NMSBoxes(cvbboxes, scores, conf_thres, iou_thres) + bboxes, scores, labels, maskconf = bboxes[idx], scores[idx], labels[ + idx], 
maskconf[idx]
+    masks = (maskconf @ proto).reshape(-1, h, w)
+    masks = crop_mask(masks, bboxes / 4.)
+    masks = cv2.resize(masks.transpose([1, 2, 0]),
+                       shape,
+                       interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
+    masks = np.ascontiguousarray((masks > 0.5)[..., None])
+    cidx = labels % len(MASK_COLORS)
+    mask_color = MASK_COLORS[cidx].reshape(-1, 1, 1, 3) * ALPHA
+    out = [masks, masks @ mask_color]
+    return bboxes, scores, labels, out
+
+
+def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]):
+    assert len(data) == 4
+    num_dets, bboxes, scores, labels = (i[0] for i in data)
+    nums = num_dets.item()
+    bboxes = bboxes[:nums]
+    scores = scores[:nums]
+    labels = labels[:nums]
+    return bboxes, scores, labels
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--engine', type=str, help='Engine file')
+    parser.add_argument('--imgs', type=str, help='Images file')
+    parser.add_argument('--show',
+                        action='store_true',
+                        help='Show the detection results')
+    parser.add_argument('--seg', action='store_true', help='Seg inference')
+    parser.add_argument('--out-dir',
+                        type=str,
+                        default='./output',
+                        help='Path to output file')
+    parser.add_argument('--conf-thres',
+                        type=float,
+                        default=0.25,
+                        help='Confidence threshold')
+    parser.add_argument('--iou-thres',
+                        type=float,
+                        default=0.65,
+                        help='IoU threshold')
+    parser.add_argument('--method',
+                        type=str,
+                        default='cudart',
+                        help='Inference method: cudart or pycuda')
+    parser.add_argument('--profile',
+                        action='store_true',
+                        help='Profile TensorRT engine')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
diff --git a/infer.py b/infer.py
index b281b2c..576df75 100644
--- a/infer.py
+++ b/infer.py
@@ -3,7 +3,7 @@ import argparse
 import os
 import random
 from pathlib import Path
-from typing import Any, List, Tuple, Union
+from typing import List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -145,7 +145,7 @@ def main(args):
             draw = cv2.resize(seg_img.cpu().numpy().astype(np.uint8),
                               draw.shape[:2][::-1])
         else:
-            bboxes, scores, labels, masks = det_postprocess(data)
+            bboxes, scores, labels = det_postprocess(data)
 
         bboxes -= dwdh
         bboxes /= ratio
@@ -209,14 +209,14 @@ def seg_postprocess(
     return bboxes, scores, labels, out
 
 
-def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Any], **kwargs):
+def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]):
     assert len(data) == 4
     num_dets, bboxes, scores, labels = (i[0] for i in data)
     nums = num_dets.item()
     bboxes = bboxes[:nums]
     scores = scores[:nums]
     labels = labels[:nums]
-    return bboxes, scores, labels, None
+    return bboxes, scores, labels
 
 
 def parse_args():
diff --git a/models/cudart_api.py b/models/cudart_api.py
new file mode 100644
index 0000000..a21a36d
--- /dev/null
+++ b/models/cudart_api.py
@@ -0,0 +1,160 @@
+import os
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import tensorrt as trt
+from cuda import cudart
+from numpy import ndarray
+
+os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
+warnings.filterwarnings(action='ignore', category=DeprecationWarning)
+
+
+class TRTEngine:
+
+    def __init__(self, weight: Union[str, Path]) -> None:
+        self.weight = Path(weight) if isinstance(weight, str) else weight
+        status, self.stream = cudart.cudaStreamCreate()
+        assert status.value == 0
+        self.__init_engine()
+        self.__init_bindings()
+        self.__warm_up()
+
+    def __init_engine(self) -> None:
+        logger = trt.Logger(trt.Logger.WARNING)
+        trt.init_libnvinfer_plugins(logger, namespace='')
+        with trt.Runtime(logger) as runtime:
+            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
+
+        context = model.create_execution_context()
+
+        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
+        self.num_bindings = model.num_bindings
+        self.bindings: List[int] = [0] * self.num_bindings
+        num_inputs, num_outputs = 0, 0
+
+        for i in range(model.num_bindings):
+            if model.binding_is_input(i):
+                num_inputs += 1
+            else:
+                num_outputs += 1
+
+        self.num_inputs = num_inputs
+        self.num_outputs = num_outputs
+        self.model = model
+        self.context = context
+        self.input_names = names[:num_inputs]
+        self.output_names = names[num_inputs:]
+
+    def __init_bindings(self) -> None:
+        dynamic = False
+
+        @dataclass
+        class Tensor:
+            # mutable holder: __call__ below reassigns .gpu for dynamic
+            # shapes, which a namedtuple field would not allow
+            name: str
+            dtype: np.dtype
+            shape: Tuple
+            cpu: ndarray
+            gpu: int
+
+        inp_info = []
+        out_info = []
+        out_ptrs = []
+        for i, name in enumerate(self.input_names):
+            assert self.model.get_binding_name(i) == name
+            dtype = trt.nptype(self.model.get_binding_dtype(i))
+            shape = tuple(self.model.get_binding_shape(i))
+            if -1 in shape:
+                dynamic |= True
+            if not dynamic:
+                cpu = np.empty(shape, dtype)
+                status, gpu = cudart.cudaMallocAsync(cpu.nbytes, self.stream)
+                assert status.value == 0
+                cudart.cudaMemcpyAsync(
+                    gpu, cpu.ctypes.data, cpu.nbytes,
+                    cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
+            else:
+                cpu, gpu = np.empty(0), 0
+            inp_info.append(Tensor(name, dtype, shape, cpu, gpu))
+        for i, name in enumerate(self.output_names):
+            i += self.num_inputs
+            assert self.model.get_binding_name(i) == name
+            dtype = trt.nptype(self.model.get_binding_dtype(i))
+            shape = tuple(self.model.get_binding_shape(i))
+            if not dynamic:
+                cpu = np.empty(shape, dtype=dtype)
+                status, gpu = cudart.cudaMallocAsync(cpu.nbytes, self.stream)
+                assert status.value == 0
+                cudart.cudaMemcpyAsync(
+                    gpu, cpu.ctypes.data, cpu.nbytes,
+                    cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
+                out_ptrs.append(gpu)
+            else:
+                cpu, gpu = np.empty(0), 0
+            out_info.append(Tensor(name, dtype, shape, cpu, gpu))
+
+        self.is_dynamic = dynamic
+        self.inp_info = inp_info
+        self.out_info = out_info
+        self.out_ptrs = out_ptrs
+
+    def __warm_up(self) -> None:
+        if self.is_dynamic:
+            print('Your engine has dynamic axes, please warm up by yourself !')
+            return
+        for _ in range(10):
+            inputs = []
+            for i in self.inp_info:
+                inputs.append(i.cpu)
+            self.__call__(inputs)
+
+    def set_profiler(self, profiler: Optional[trt.IProfiler]) -> None:
+        self.context.profiler = profiler \
+            if profiler is not None else trt.Profiler()
+
+    def __call__(self, *inputs) -> Union[Tuple, ndarray]:
+
+        assert len(inputs) == self.num_inputs
+        contiguous_inputs: List[ndarray] = [
+            np.ascontiguousarray(i) for i in inputs
+        ]
+
+        for i in range(self.num_inputs):
+
+            if self.is_dynamic:
+                self.context.set_binding_shape(
+                    i, tuple(contiguous_inputs[i].shape))
+                status, self.inp_info[i].gpu = cudart.cudaMallocAsync(
+                    contiguous_inputs[i].nbytes, self.stream)
+                assert status.value == 0
+            cudart.cudaMemcpyAsync(
+                self.inp_info[i].gpu, contiguous_inputs[i].ctypes.data,
+                contiguous_inputs[i].nbytes,
+                cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
+            self.bindings[i] = self.inp_info[i].gpu
+
+        output_gpu_ptrs: List[int] = []
+        outputs: List[ndarray] = []
+
+        for i in range(self.num_outputs):
+            j = i + self.num_inputs
+            if self.is_dynamic:
+                shape = tuple(self.context.get_binding_shape(j))
+                dtype = self.out_info[i].dtype
+                cpu = np.empty(shape, dtype=dtype)
+                status, gpu = cudart.cudaMallocAsync(cpu.nbytes, self.stream)
+                assert status.value == 0
+                cudart.cudaMemcpyAsync(
+                    gpu, cpu.ctypes.data, cpu.nbytes,
+                    cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
+            else:
+                cpu = self.out_info[i].cpu
+                gpu = self.out_info[i].gpu
+            outputs.append(cpu)
+            output_gpu_ptrs.append(gpu)
+            self.bindings[j] = gpu
+
+        self.context.execute_async_v2(self.bindings, self.stream)
+
+        for i, o in enumerate(output_gpu_ptrs):
+            cudart.cudaMemcpyAsync(
+                outputs[i].ctypes.data, o, outputs[i].nbytes,
+                cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, self.stream)
+        # synchronize after the D2H copies so the host buffers are ready
+        cudart.cudaStreamSynchronize(self.stream)
+
+        return tuple(outputs) if len(outputs) > 1 else outputs[0]
diff --git a/models/pycuda_api.py b/models/pycuda_api.py
new file mode 100644
index 0000000..e340da3
--- /dev/null
+++ b/models/pycuda_api.py
@@ -0,0 +1,147 @@
+import os
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import pycuda.autoinit  # noqa F401
+import pycuda.driver as cuda
+import tensorrt as trt
+from numpy import ndarray
+
+os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
+warnings.filterwarnings(action='ignore', category=DeprecationWarning)
+
+
+class TRTEngine:
+
+    def __init__(self, weight: Union[str, Path]) -> None:
+        self.weight = Path(weight) if isinstance(weight, str) else weight
+        self.stream = cuda.Stream(0)
+        self.__init_engine()
+        self.__init_bindings()
+        self.__warm_up()
+
+    def __init_engine(self) -> None:
+        logger = trt.Logger(trt.Logger.WARNING)
+        trt.init_libnvinfer_plugins(logger, namespace='')
+        with trt.Runtime(logger) as runtime:
+            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
+
+        context = model.create_execution_context()
+
+        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
+        self.num_bindings = model.num_bindings
+        self.bindings: List[int] = [0] * self.num_bindings
+        num_inputs, num_outputs = 0, 0
+
+        for i in range(model.num_bindings):
+            if model.binding_is_input(i):
+                num_inputs += 1
+            else:
+                num_outputs += 1
+
+        self.num_inputs = num_inputs
+        self.num_outputs = num_outputs
+        self.model = model
+        self.context = context
+        self.input_names = names[:num_inputs]
+        self.output_names = names[num_inputs:]
+
+    def __init_bindings(self) -> None:
+        dynamic = False
+
+        @dataclass
+        class Tensor:
+            # mutable holder: __call__ below reassigns .gpu for dynamic
+            # shapes, which a namedtuple field would not allow
+            name: str
+            dtype: np.dtype
+            shape: Tuple
+            cpu: ndarray
+            gpu: int
+
+        inp_info = []
+        out_info = []
+        out_ptrs = []
+        for i, name in enumerate(self.input_names):
+            assert self.model.get_binding_name(i) == name
+            dtype = trt.nptype(self.model.get_binding_dtype(i))
+            shape = tuple(self.model.get_binding_shape(i))
+            if -1 in shape:
+                dynamic |= True
+            if not dynamic:
+                cpu = np.empty(shape, dtype)
+                gpu = cuda.mem_alloc(cpu.nbytes)
+                cuda.memcpy_htod_async(gpu, cpu, self.stream)
+            else:
+                cpu, gpu = np.empty(0), 0
+            inp_info.append(Tensor(name, dtype, shape, cpu, gpu))
+        for i, name in enumerate(self.output_names):
+            i += self.num_inputs
+            assert self.model.get_binding_name(i) == name
+            dtype = trt.nptype(self.model.get_binding_dtype(i))
+            shape = tuple(self.model.get_binding_shape(i))
+            if not dynamic:
+                cpu = np.empty(shape, dtype=dtype)
+                gpu = cuda.mem_alloc(cpu.nbytes)
+                cuda.memcpy_htod_async(gpu, cpu, self.stream)
+                out_ptrs.append(gpu)
+            else:
+                cpu, gpu = np.empty(0), 0
+            out_info.append(Tensor(name, dtype, shape, cpu, gpu))
+
+        self.is_dynamic = dynamic
+        self.inp_info = inp_info
+        self.out_info = out_info
+        self.out_ptrs = out_ptrs
+
+    def __warm_up(self) -> None:
+        if self.is_dynamic:
+            print('Your engine has dynamic axes, please warm up by yourself !')
+            return
+        for _ in range(10):
+            inputs = []
+            for i in self.inp_info:
+                inputs.append(i.cpu)
+            self.__call__(inputs)
+
+    def set_profiler(self, profiler: Optional[trt.IProfiler]) -> None:
+        self.context.profiler = profiler \
+            if profiler is not None else trt.Profiler()
+
+    def __call__(self, *inputs) -> Union[Tuple, ndarray]:
+
+        assert len(inputs) == self.num_inputs
+        contiguous_inputs: List[ndarray] = [
+            np.ascontiguousarray(i) for i in inputs
+        ]
+
+        for i in range(self.num_inputs):
+
+            if self.is_dynamic:
+                self.context.set_binding_shape(
+                    i, tuple(contiguous_inputs[i].shape))
+                self.inp_info[i].gpu = cuda.mem_alloc(
+                    contiguous_inputs[i].nbytes)
+
+            cuda.memcpy_htod_async(self.inp_info[i].gpu, contiguous_inputs[i],
+                                   self.stream)
+            self.bindings[i] = int(self.inp_info[i].gpu)
+
+        output_gpu_ptrs: List[int] = []
+        outputs: List[ndarray] = []
+
+        for i in range(self.num_outputs):
+            j = i + self.num_inputs
+            if self.is_dynamic:
+                shape = tuple(self.context.get_binding_shape(j))
+                dtype = self.out_info[i].dtype
+                cpu = np.empty(shape, dtype=dtype)
+                # allocate by the output buffer size, not an input's
+                gpu = cuda.mem_alloc(cpu.nbytes)
+                cuda.memcpy_htod_async(gpu, cpu, self.stream)
+            else:
+                cpu = self.out_info[i].cpu
+                gpu = self.out_info[i].gpu
+            outputs.append(cpu)
+            output_gpu_ptrs.append(gpu)
+            self.bindings[j] = int(gpu)
+
+        self.context.execute_async_v2(self.bindings, self.stream.handle)
+
+        for i, o in enumerate(output_gpu_ptrs):
+            cuda.memcpy_dtoh_async(outputs[i], o, self.stream)
+        # synchronize after the D2H copies so the host buffers are ready
+        self.stream.synchronize()
+
+        return tuple(outputs) if len(outputs) > 1 else outputs[0]
diff --git a/requirements.txt b/requirements.txt
index 1a2fe35..9f69929 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,5 @@ torch
 torchvision
 ultralytics
 # tensorrt
+# cuda-python
+# pycuda
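Both new wrappers expose `set_profiler()`, which the `--profile` flag builds on. A minimal sketch of profiling an engine directly with the cudart backend (not part of the patch; assumes a static-shape engine, since dynamic ones skip the built-in warm-up):

```python
import numpy as np

from models.cudart_api import TRTEngine

engine = TRTEngine('yolov8s.engine')
engine.set_profiler(None)   # None attaches TensorRT's default trt.Profiler

inp = engine.inp_info[0]
dummy = np.random.rand(*inp.shape).astype(inp.dtype)
engine(dummy)               # per-layer timings are printed on each execution
```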