diff --git a/README.md b/README.md index 3561532..011a453 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ 5. Prepare your own PyTorch weight such as `yolov8s.pt` or `yolov8s-seg.pt`. -​ ***NOTICE:*** +***NOTICE:*** Please use the latest `CUDA` and `TensorRT`, so that you can achieve the fastest speed ! @@ -212,6 +212,11 @@ Please see more information in [`Segment.md`](docs/Segment.md) See more in [`README.md`](csrc/deepstream/README.md) +# Jetson Deploy + +Only test on `Jetson-NX 4GB`. +See more in [`Jetson.md`](docs/Jetson.md) + # Profile you engine If you want to profile the TensorRT engine: diff --git a/csrc/detect/end2end/CMakeLists.txt b/csrc/detect/end2end/CMakeLists.txt index bcdd8b0..fd3c1ed 100644 --- a/csrc/detect/end2end/CMakeLists.txt +++ b/csrc/detect/end2end/CMakeLists.txt @@ -33,7 +33,7 @@ list(APPEND INCLUDE_DIRS ${CUDA_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS} ${TensorRT_INCLUDE_DIRS} - include + ./include ) list(APPEND ALL_LIBS diff --git a/csrc/detect/normal/CMakeLists.txt b/csrc/detect/normal/CMakeLists.txt index 27cdcee..ab25ceb 100644 --- a/csrc/detect/normal/CMakeLists.txt +++ b/csrc/detect/normal/CMakeLists.txt @@ -33,7 +33,7 @@ list(APPEND INCLUDE_DIRS ${CUDA_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS} ${TensorRT_INCLUDE_DIRS} - include + ./include ) list(APPEND ALL_LIBS diff --git a/csrc/jetson/detect/CMakeLists.txt b/csrc/jetson/detect/CMakeLists.txt new file mode 100644 index 0000000..69eb3cc --- /dev/null +++ b/csrc/jetson/detect/CMakeLists.txt @@ -0,0 +1,54 @@ +cmake_minimum_required(VERSION 2.8.12) + +set(CMAKE_CUDA_ARCHITECTURES 60 61 62 70 72 75 86) +set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc) + +project(yolov8 LANGUAGES CXX CUDA) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -O3 -g") +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_BUILD_TYPE Release) +option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) + +# CUDA +find_package(CUDA REQUIRED) +message(STATUS "CUDA Libs: \n${CUDA_LIBRARIES}\n") +message(STATUS "CUDA Headers: \n${CUDA_INCLUDE_DIRS}\n") + +# OpenCV +find_package(OpenCV REQUIRED) +message(STATUS "OpenCV Libs: \n${OpenCV_LIBS}\n") +message(STATUS "OpenCV Libraries: \n${OpenCV_LIBRARIES}\n") +message(STATUS "OpenCV Headers: \n${OpenCV_INCLUDE_DIRS}\n") + +# TensorRT +set(TensorRT_INCLUDE_DIRS /usr/include/aarch64-linux-gnu) +set(TensorRT_LIBRARIES /usr/lib/aarch64-linux-gnu) + + +message(STATUS "TensorRT Libs: \n${TensorRT_LIBRARIES}\n") +message(STATUS "TensorRT Headers: \n${TensorRT_INCLUDE_DIRS}\n") + +list(APPEND INCLUDE_DIRS + ${CUDA_INCLUDE_DIRS} + ${OpenCV_INCLUDE_DIRS} + ${TensorRT_INCLUDE_DIRS} + ./include + ) + +list(APPEND ALL_LIBS + ${CUDA_LIBRARIES} + ${OpenCV_LIBRARIES} + ${TensorRT_LIBRARIES} + ) + +include_directories(${INCLUDE_DIRS}) + +add_executable(${PROJECT_NAME} + main.cpp + include/yolov8.hpp + include/common.hpp + ) + +link_directories(${ALL_LIBS}) +target_link_libraries(${PROJECT_NAME} PRIVATE nvinfer nvinfer_plugin ${CUDA_LIBRARIES} ${OpenCV_LIBS}) diff --git a/csrc/jetson/detect/include/common.hpp b/csrc/jetson/detect/include/common.hpp new file mode 100644 index 0000000..3b2fbeb --- /dev/null +++ b/csrc/jetson/detect/include/common.hpp @@ -0,0 +1,156 @@ +// +// Created by ubuntu on 3/16/23. +// + +#ifndef JETSON_DETECT_COMMON_HPP +#define JETSON_DETECT_COMMON_HPP +#include "opencv2/opencv.hpp" +#include +#include +#include "NvInfer.h" + +#define CHECK(call) \ +do \ +{ \ + const cudaError_t error_code = call; \ + if (error_code != cudaSuccess) \ + { \ + printf("CUDA Error:\n"); \ + printf(" File: %s\n", __FILE__); \ + printf(" Line: %d\n", __LINE__); \ + printf(" Error code: %d\n", error_code); \ + printf(" Error text: %s\n", \ + cudaGetErrorString(error_code)); \ + exit(1); \ + } \ +} while (0) + +class Logger : public nvinfer1::ILogger +{ +public: + nvinfer1::ILogger::Severity reportableSeverity; + + explicit Logger(nvinfer1::ILogger::Severity severity = nvinfer1::ILogger::Severity::kINFO) : + reportableSeverity(severity) + { + } + + void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override + { + if (severity > reportableSeverity) + { + return; + } + switch (severity) + { + case nvinfer1::ILogger::Severity::kINTERNAL_ERROR: + std::cerr << "INTERNAL_ERROR: "; + break; + case nvinfer1::ILogger::Severity::kERROR: + std::cerr << "ERROR: "; + break; + case nvinfer1::ILogger::Severity::kWARNING: + std::cerr << "WARNING: "; + break; + case nvinfer1::ILogger::Severity::kINFO: + std::cerr << "INFO: "; + break; + default: + std::cerr << "VERBOSE: "; + break; + } + std::cerr << msg << std::endl; + } +}; + +inline int get_size_by_dims(const nvinfer1::Dims& dims) +{ + int size = 1; + for (int i = 0; i < dims.nbDims; i++) + { + size *= dims.d[i]; + } + return size; +} + +inline int type_to_size(const nvinfer1::DataType& dataType) +{ + switch (dataType) + { + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kHALF: + return 2; + case nvinfer1::DataType::kINT32: + return 4; + case nvinfer1::DataType::kINT8: + return 1; + case nvinfer1::DataType::kBOOL: + return 1; + default: + return 4; + } +} + +inline static float clamp(float val, float min, float max) +{ + return val > min ? (val < max ? val : max) : min; +} + +inline bool IsPathExist(const std::string& path) +{ + if (access(path.c_str(), 0) == F_OK) + { + return true; + } + return false; +} + +inline bool IsFile(const std::string& path) +{ + if (!IsPathExist(path)) + { + printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str()); + return false; + } + struct stat buffer; + return (stat(path.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode)); +} + +inline bool IsFolder(const std::string& path) +{ + if (!IsPathExist(path)) + { + return false; + } + struct stat buffer; + return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode)); +} + +namespace det +{ + struct Binding + { + size_t size = 1; + size_t dsize = 1; + nvinfer1::Dims dims; + std::string name; + }; + + struct Object + { + cv::Rect_ rect; + int label = 0; + float prob = 0.0; + }; + + struct PreParam + { + float ratio = 1.0f; + float dw = 0.0f; + float dh = 0.0f; + float height = 0; + float width = 0; + }; +} +#endif //JETSON_DETECT_COMMON_HPP diff --git a/csrc/jetson/detect/include/yolov8.hpp b/csrc/jetson/detect/include/yolov8.hpp new file mode 100644 index 0000000..2c0750a --- /dev/null +++ b/csrc/jetson/detect/include/yolov8.hpp @@ -0,0 +1,424 @@ +// +// Created by ubuntu on 3/16/23. +// +#ifndef JETSON_DETECT_YOLOV8_HPP +#define JETSON_DETECT_YOLOV8_HPP +#include "fstream" +#include "common.hpp" +#include "NvInferPlugin.h" +using namespace det; + +class YOLOv8 +{ +public: + explicit YOLOv8(const std::string& engine_file_path); + ~YOLOv8(); + + void make_pipe(bool warmup = true); + void copy_from_Mat(const cv::Mat& image); + void copy_from_Mat(const cv::Mat& image, cv::Size& size); + void letterbox( + const cv::Mat& image, + cv::Mat& out, + cv::Size& size + ); + void infer(); + void postprocess(std::vector& objs); + static void draw_objects( + const cv::Mat& image, + cv::Mat& res, + const std::vector& objs, + const std::vector& CLASS_NAMES, + const std::vector>& COLORS + ); + int num_bindings; + int num_inputs = 0; + int num_outputs = 0; + std::vector input_bindings; + std::vector output_bindings; + std::vector host_ptrs; + std::vector device_ptrs; + + PreParam pparam; +private: + nvinfer1::ICudaEngine* engine = nullptr; + nvinfer1::IRuntime* runtime = nullptr; + nvinfer1::IExecutionContext* context = nullptr; + cudaStream_t stream = nullptr; + Logger gLogger{ nvinfer1::ILogger::Severity::kERROR }; + +}; + +YOLOv8::YOLOv8(const std::string& engine_file_path) +{ + std::ifstream file(engine_file_path, std::ios::binary); + assert(file.good()); + file.seekg(0, std::ios::end); + auto size = file.tellg(); + file.seekg(0, std::ios::beg); + char* trtModelStream = new char[size]; + assert(trtModelStream); + file.read(trtModelStream, size); + file.close(); + initLibNvInferPlugins(&this->gLogger, ""); + this->runtime = nvinfer1::createInferRuntime(this->gLogger); + assert(this->runtime != nullptr); + + this->engine = this->runtime->deserializeCudaEngine(trtModelStream, size); + assert(this->engine != nullptr); + + this->context = this->engine->createExecutionContext(); + + assert(this->context != nullptr); + cudaStreamCreate(&this->stream); + this->num_bindings = this->engine->getNbBindings(); + + for (int i = 0; i < this->num_bindings; ++i) + { + Binding binding; + nvinfer1::Dims dims; + nvinfer1::DataType dtype = this->engine->getBindingDataType(i); + std::string name = this->engine->getBindingName(i); + binding.name = name; + binding.dsize = type_to_size(dtype); + + bool IsInput = engine->bindingIsInput(i); + if (IsInput) + { + this->num_inputs += 1; + dims = this->engine->getProfileDimensions( + i, + 0, + nvinfer1::OptProfileSelector::kMAX); + binding.size = get_size_by_dims(dims); + binding.dims = dims; + this->input_bindings.push_back(binding); + // set max opt shape + this->context->setBindingDimensions(i, dims); + + } + else + { + dims = this->context->getBindingDimensions(i); + binding.size = get_size_by_dims(dims); + binding.dims = dims; + this->output_bindings.push_back(binding); + this->num_outputs += 1; + } + } + +} + +YOLOv8::~YOLOv8() +{ + this->context->destroy(); + this->engine->destroy(); + this->runtime->destroy(); + cudaStreamDestroy(this->stream); + for (auto& ptr : this->device_ptrs) + { + CHECK(cudaFree(ptr)); + } + + for (auto& ptr : this->host_ptrs) + { + CHECK(cudaFreeHost(ptr)); + } + +} +void YOLOv8::make_pipe(bool warmup) +{ + + for (auto& bindings : this->input_bindings) + { + void* d_ptr; + CHECK(cudaMalloc( + &d_ptr, + bindings.size * bindings.dsize) + ); + this->device_ptrs.push_back(d_ptr); + } + + for (auto& bindings : this->output_bindings) + { + void* d_ptr, * h_ptr; + size_t size = bindings.size * bindings.dsize; + CHECK(cudaMalloc( + &d_ptr, + size) + ); + CHECK(cudaHostAlloc( + &h_ptr, + size, + 0) + ); + this->device_ptrs.push_back(d_ptr); + this->host_ptrs.push_back(h_ptr); + } + + if (warmup) + { + for (int i = 0; i < 10; i++) + { + for (auto& bindings : this->input_bindings) + { + size_t size = bindings.size * bindings.dsize; + void* h_ptr = malloc(size); + memset(h_ptr, 0, size); + CHECK(cudaMemcpyAsync( + this->device_ptrs[0], + h_ptr, + size, + cudaMemcpyHostToDevice, + this->stream) + ); + free(h_ptr); + } + this->infer(); + } + printf("model warmup 10 times\n"); + + } +} + +void YOLOv8::letterbox(const cv::Mat& image, cv::Mat& out, cv::Size& size) +{ + const float inp_h = size.height; + const float inp_w = size.width; + float height = image.rows; + float width = image.cols; + + float r = std::min(inp_h / height, inp_w / width); + int padw = std::round(width * r); + int padh = std::round(height * r); + + cv::Mat tmp; + if ((int)width != padw || (int)height != padh) + { + cv::resize( + image, + tmp, + cv::Size(padw, padh) + ); + } + else + { + tmp = image.clone(); + } + + float dw = inp_w - padw; + float dh = inp_h - padh; + + dw /= 2.0f; + dh /= 2.0f; + int top = int(std::round(dh - 0.1f)); + int bottom = int(std::round(dh + 0.1f)); + int left = int(std::round(dw - 0.1f)); + int right = int(std::round(dw + 0.1f)); + + cv::copyMakeBorder( + tmp, + tmp, + top, + bottom, + left, + right, + cv::BORDER_CONSTANT, + { 114, 114, 114 } + ); + + cv::dnn::blobFromImage(tmp, + out, + 1 / 255.f, + cv::Size(), + cv::Scalar(0, 0, 0), + true, + false, + CV_32F + ); + this->pparam.ratio = 1 / r; + this->pparam.dw = dw; + this->pparam.dh = dh; + this->pparam.height = height; + this->pparam.width = width;; +} + +void YOLOv8::copy_from_Mat(const cv::Mat& image) +{ + cv::Mat nchw; + auto& in_binding = this->input_bindings[0]; + auto width = in_binding.dims.d[3]; + auto height = in_binding.dims.d[2]; + cv::Size size{ width, height }; + this->letterbox( + image, + nchw, + size + ); + + this->context->setBindingDimensions( + 0, + nvinfer1::Dims + { + 4, + { 1, 3, height, width } + } + ); + + CHECK(cudaMemcpyAsync( + this->device_ptrs[0], + nchw.ptr(), + nchw.total() * nchw.elemSize(), + cudaMemcpyHostToDevice, + this->stream) + ); +} + +void YOLOv8::copy_from_Mat(const cv::Mat& image, cv::Size& size) +{ + cv::Mat nchw; + this->letterbox( + image, + nchw, + size + ); + this->context->setBindingDimensions( + 0, + nvinfer1::Dims + { 4, + { 1, 3, size.height, size.width } + } + ); + CHECK(cudaMemcpyAsync( + this->device_ptrs[0], + nchw.ptr(), + nchw.total() * nchw.elemSize(), + cudaMemcpyHostToDevice, + this->stream) + ); +} + +void YOLOv8::infer() +{ + + this->context->enqueueV2( + this->device_ptrs.data(), + this->stream, + nullptr + ); + for (int i = 0; i < this->num_outputs; i++) + { + size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize; + CHECK(cudaMemcpyAsync(this->host_ptrs[i], + this->device_ptrs[i + this->num_inputs], + osize, + cudaMemcpyDeviceToHost, + this->stream) + ); + + } + cudaStreamSynchronize(this->stream); + +} + +void YOLOv8::postprocess(std::vector& objs) +{ + objs.clear(); + int* num_dets = static_cast(this->host_ptrs[0]); + auto* boxes = static_cast(this->host_ptrs[1]); + auto* scores = static_cast(this->host_ptrs[2]); + int* labels = static_cast(this->host_ptrs[3]); + auto& dw = this->pparam.dw; + auto& dh = this->pparam.dh; + auto& width = this->pparam.width; + auto& height = this->pparam.height; + auto& ratio = this->pparam.ratio; + for (int i = 0; i < num_dets[0]; i++) + { + float* ptr = boxes + i * 4; + + float x0 = *ptr++ - dw; + float y0 = *ptr++ - dh; + float x1 = *ptr++ - dw; + float y1 = *ptr - dh; + + x0 = clamp(x0 * ratio, 0.f, width); + y0 = clamp(y0 * ratio, 0.f, height); + x1 = clamp(x1 * ratio, 0.f, width); + y1 = clamp(y1 * ratio, 0.f, height); + Object obj; + obj.rect.x = x0; + obj.rect.y = y0; + obj.rect.width = x1 - x0; + obj.rect.height = y1 - y0; + obj.prob = *(scores + i); + obj.label = *(labels + i); + objs.push_back(obj); + } +} + +void YOLOv8::draw_objects( + const cv::Mat& image, + cv::Mat& res, + const std::vector& objs, + const std::vector& CLASS_NAMES, + const std::vector>& COLORS +) +{ + res = image.clone(); + for (auto& obj : objs) + { + cv::Scalar color = cv::Scalar( + COLORS[obj.label][0], + COLORS[obj.label][1], + COLORS[obj.label][2] + ); + cv::rectangle( + res, + obj.rect, + color, + 2 + ); + + char text[256]; + sprintf( + text, + "%s %.1f%%", + CLASS_NAMES[obj.label].c_str(), + obj.prob * 100 + ); + + int baseLine = 0; + cv::Size label_size = cv::getTextSize( + text, + cv::FONT_HERSHEY_SIMPLEX, + 0.4, + 1, + &baseLine + ); + + int x = (int)obj.rect.x; + int y = (int)obj.rect.y + 1; + + if (y > res.rows) + y = res.rows; + + cv::rectangle( + res, + cv::Rect(x, y, label_size.width, label_size.height + baseLine), + { 0, 0, 255 }, + -1 + ); + + cv::putText( + res, + text, + cv::Point(x, y + label_size.height), + cv::FONT_HERSHEY_SIMPLEX, + 0.4, + { 255, 255, 255 }, + 1 + ); + } +} +#endif //JETSON_DETECT_YOLOV8_HPP diff --git a/csrc/jetson/detect/main.cpp b/csrc/jetson/detect/main.cpp new file mode 100644 index 0000000..c9e5c96 --- /dev/null +++ b/csrc/jetson/detect/main.cpp @@ -0,0 +1,158 @@ +// +// Created by ubuntu on 3/16/23. +// +#include "chrono" +#include "yolov8.hpp" +#include "opencv2/opencv.hpp" + +const std::vector CLASS_NAMES = { + "person", "bicycle", "car", "motorcycle", "airplane", "bus", + "train", "truck", "boat", "traffic light", "fire hydrant", + "stop sign", "parking meter", "bench", "bird", "cat", + "dog", "horse", "sheep", "cow", "elephant", + "bear", "zebra", "giraffe", "backpack", "umbrella", + "handbag", "tie", "suitcase", "frisbee", "skis", + "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", + "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", + "cup", "fork", "knife", "spoon", "bowl", + "banana", "apple", "sandwich", "orange", "broccoli", + "carrot", "hot dog", "pizza", "donut", "cake", + "chair", "couch", "potted plant", "bed", "dining table", + "toilet", "tv", "laptop", "mouse", "remote", + "keyboard", "cell phone", "microwave", "oven", + "toaster", "sink", "refrigerator", "book", "clock", "vase", + "scissors", "teddy bear", "hair drier", "toothbrush" }; + +const std::vector> COLORS = { + { 0, 114, 189 }, { 217, 83, 25 }, { 237, 177, 32 }, + { 126, 47, 142 }, { 119, 172, 48 }, { 77, 190, 238 }, + { 162, 20, 47 }, { 76, 76, 76 }, { 153, 153, 153 }, + { 255, 0, 0 }, { 255, 128, 0 }, { 191, 191, 0 }, + { 0, 255, 0 }, { 0, 0, 255 }, { 170, 0, 255 }, + { 85, 85, 0 }, { 85, 170, 0 }, { 85, 255, 0 }, + { 170, 85, 0 }, { 170, 170, 0 }, { 170, 255, 0 }, + { 255, 85, 0 }, { 255, 170, 0 }, { 255, 255, 0 }, + { 0, 85, 128 }, { 0, 170, 128 }, { 0, 255, 128 }, + { 85, 0, 128 }, { 85, 85, 128 }, { 85, 170, 128 }, + { 85, 255, 128 }, { 170, 0, 128 }, { 170, 85, 128 }, + { 170, 170, 128 }, { 170, 255, 128 }, { 255, 0, 128 }, + { 255, 85, 128 }, { 255, 170, 128 }, { 255, 255, 128 }, + { 0, 85, 255 }, { 0, 170, 255 }, { 0, 255, 255 }, + { 85, 0, 255 }, { 85, 85, 255 }, { 85, 170, 255 }, + { 85, 255, 255 }, { 170, 0, 255 }, { 170, 85, 255 }, + { 170, 170, 255 }, { 170, 255, 255 }, { 255, 0, 255 }, + { 255, 85, 255 }, { 255, 170, 255 }, { 85, 0, 0 }, + { 128, 0, 0 }, { 170, 0, 0 }, { 212, 0, 0 }, + { 255, 0, 0 }, { 0, 43, 0 }, { 0, 85, 0 }, + { 0, 128, 0 }, { 0, 170, 0 }, { 0, 212, 0 }, + { 0, 255, 0 }, { 0, 0, 43 }, { 0, 0, 85 }, + { 0, 0, 128 }, { 0, 0, 170 }, { 0, 0, 212 }, + { 0, 0, 255 }, { 0, 0, 0 }, { 36, 36, 36 }, + { 73, 73, 73 }, { 109, 109, 109 }, { 146, 146, 146 }, + { 182, 182, 182 }, { 219, 219, 219 }, { 0, 114, 189 }, + { 80, 183, 189 }, { 128, 128, 0 } +}; + +int main(int argc, char** argv) +{ + const std::string engine_file_path{ argv[1] }; + const std::string path{ argv[2] }; + + std::vector imagePathList; + bool isVideo{ false }; + + assert(argc == 3); + + auto yolov8 = new YOLOv8(engine_file_path); + yolov8->make_pipe(true); + + if (IsFile(path)) + { + std::string suffix = path.substr(path.find_last_of('.') + 1); + if ( + suffix == "jpg" || + suffix == "jpeg" || + suffix == "png" + ) + { + imagePathList.push_back(path); + } + else if ( + suffix == "mp4" || + suffix == "avi" || + suffix == "m4v" || + suffix == "mpeg" || + suffix == "mov" || + suffix == "mkv" + ) + { + isVideo = true; + } + else + { + printf("suffix %s is wrong !!!\n", suffix.c_str()); + std::abort(); + } + } + else if (IsFolder(path)) + { + cv::glob(path + "/*.jpg", imagePathList); + } + + cv::Mat res, image; + cv::Size size = cv::Size{ 640, 640 }; + std::vector objs; + + cv::namedWindow("result", cv::WINDOW_AUTOSIZE); + + if (isVideo) + { + cv::VideoCapture cap(path); + + if (!cap.isOpened()) + { + printf("can not open %s\n", path.c_str()); + return -1; + } + while (cap.read(image)) + { + objs.clear(); + yolov8->copy_from_Mat(image, size); + auto start = std::chrono::system_clock::now(); + yolov8->infer(); + auto end = std::chrono::system_clock::now(); + yolov8->postprocess(objs); + yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS); + auto tc = (double) + std::chrono::duration_cast(end - start).count() / 1000.; + printf("cost %2.4lf ms\n", tc); + cv::imshow("result", res); + if (cv::waitKey(10) == 'q') + { + break; + } + } + } + else + { + for (auto& path : imagePathList) + { + objs.clear(); + image = cv::imread(path); + yolov8->copy_from_Mat(image, size); + auto start = std::chrono::system_clock::now(); + yolov8->infer(); + auto end = std::chrono::system_clock::now(); + yolov8->postprocess(objs); + yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS); + auto tc = (double) + std::chrono::duration_cast(end - start).count() / 1000.; + printf("cost %2.4lf ms\n", tc); + cv::imshow("result", res); + cv::waitKey(0); + } + } + cv::destroyAllWindows(); + delete yolov8; + return 0; +} diff --git a/csrc/jetson/segment/CMakeLists.txt b/csrc/jetson/segment/CMakeLists.txt new file mode 100644 index 0000000..bcaf1b2 --- /dev/null +++ b/csrc/jetson/segment/CMakeLists.txt @@ -0,0 +1,60 @@ +cmake_minimum_required(VERSION 2.8.12) + +set(CMAKE_CUDA_ARCHITECTURES 60 61 62 70 72 75 86) +set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc) + +project(yolov8-seg LANGUAGES CXX CUDA) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -O3 -g") +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_BUILD_TYPE Release) +option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) + +# CUDA +find_package(CUDA REQUIRED) +message(STATUS "CUDA Libs: \n${CUDA_LIBRARIES}\n") +message(STATUS "CUDA Headers: \n${CUDA_INCLUDE_DIRS}\n") + +# OpenCV +find_package(OpenCV REQUIRED) +message(STATUS "OpenCV Libs: \n${OpenCV_LIBS}\n") +message(STATUS "OpenCV Libraries: \n${OpenCV_LIBRARIES}\n") +message(STATUS "OpenCV Headers: \n${OpenCV_INCLUDE_DIRS}\n") + +# TensorRT +set(TensorRT_INCLUDE_DIRS /usr/include/aarch64-linux-gnu) +set(TensorRT_LIBRARIES /usr/lib/aarch64-linux-gnu) + + +message(STATUS "TensorRT Libs: \n${TensorRT_LIBRARIES}\n") +message(STATUS "TensorRT Headers: \n${TensorRT_INCLUDE_DIRS}\n") + +list(APPEND INCLUDE_DIRS + ${CUDA_INCLUDE_DIRS} + ${OpenCV_INCLUDE_DIRS} + ${TensorRT_INCLUDE_DIRS} + ./include + ) + +list(APPEND ALL_LIBS + ${CUDA_LIBRARIES} + ${OpenCV_LIBRARIES} + ${TensorRT_LIBRARIES} + ) + +include_directories(${INCLUDE_DIRS}) + +add_executable(${PROJECT_NAME} + main.cpp + include/yolov8-seg.hpp + include/common.hpp + ) + +link_directories(${ALL_LIBS}) +target_link_libraries(${PROJECT_NAME} PRIVATE nvinfer nvinfer_plugin ${CUDA_LIBRARIES} ${OpenCV_LIBS}) + + +if(${OpenCV_VERSION} VERSION_GREATER_EQUAL 4.7.0) + message(STATUS "Build with -DBATCHED_NMS") + add_definitions(-DBATCHED_NMS) +endif() diff --git a/csrc/jetson/segment/include/common.hpp b/csrc/jetson/segment/include/common.hpp new file mode 100644 index 0000000..3db19e1 --- /dev/null +++ b/csrc/jetson/segment/include/common.hpp @@ -0,0 +1,157 @@ +// +// Created by ubuntu on 3/16/23. +// + +#ifndef JETSON_SEGMENT_COMMON_HPP +#define JETSON_SEGMENT_COMMON_HPP +#include "opencv2/opencv.hpp" +#include +#include +#include "NvInfer.h" + +#define CHECK(call) \ +do \ +{ \ + const cudaError_t error_code = call; \ + if (error_code != cudaSuccess) \ + { \ + printf("CUDA Error:\n"); \ + printf(" File: %s\n", __FILE__); \ + printf(" Line: %d\n", __LINE__); \ + printf(" Error code: %d\n", error_code); \ + printf(" Error text: %s\n", \ + cudaGetErrorString(error_code)); \ + exit(1); \ + } \ +} while (0) + +class Logger : public nvinfer1::ILogger +{ +public: + nvinfer1::ILogger::Severity reportableSeverity; + + explicit Logger(nvinfer1::ILogger::Severity severity = nvinfer1::ILogger::Severity::kINFO) : + reportableSeverity(severity) + { + } + + void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override + { + if (severity > reportableSeverity) + { + return; + } + switch (severity) + { + case nvinfer1::ILogger::Severity::kINTERNAL_ERROR: + std::cerr << "INTERNAL_ERROR: "; + break; + case nvinfer1::ILogger::Severity::kERROR: + std::cerr << "ERROR: "; + break; + case nvinfer1::ILogger::Severity::kWARNING: + std::cerr << "WARNING: "; + break; + case nvinfer1::ILogger::Severity::kINFO: + std::cerr << "INFO: "; + break; + default: + std::cerr << "VERBOSE: "; + break; + } + std::cerr << msg << std::endl; + } +}; + +inline int get_size_by_dims(const nvinfer1::Dims& dims) +{ + int size = 1; + for (int i = 0; i < dims.nbDims; i++) + { + size *= dims.d[i]; + } + return size; +} + +inline int type_to_size(const nvinfer1::DataType& dataType) +{ + switch (dataType) + { + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kHALF: + return 2; + case nvinfer1::DataType::kINT32: + return 4; + case nvinfer1::DataType::kINT8: + return 1; + case nvinfer1::DataType::kBOOL: + return 1; + default: + return 4; + } +} + +inline static float clamp(float val, float min, float max) +{ + return val > min ? (val < max ? val : max) : min; +} + +inline bool IsPathExist(const std::string& path) +{ + if (access(path.c_str(), 0) == F_OK) + { + return true; + } + return false; +} + +inline bool IsFile(const std::string& path) +{ + if (!IsPathExist(path)) + { + printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str()); + return false; + } + struct stat buffer; + return (stat(path.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode)); +} + +inline bool IsFolder(const std::string& path) +{ + if (!IsPathExist(path)) + { + return false; + } + struct stat buffer; + return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode)); +} + +namespace seg +{ + struct Binding + { + size_t size = 1; + size_t dsize = 1; + nvinfer1::Dims dims; + std::string name; + }; + + struct Object + { + cv::Rect_ rect; + int label = 0; + float prob = 0.0; + cv::Mat boxMask; + }; + + struct PreParam + { + float ratio = 1.0f; + float dw = 0.0f; + float dh = 0.0f; + float height = 0; + float width = 0; + }; +} +#endif //JETSON_SEGMENT_COMMON_HPP diff --git a/csrc/jetson/segment/include/yolov8-seg.hpp b/csrc/jetson/segment/include/yolov8-seg.hpp new file mode 100644 index 0000000..748e8fc --- /dev/null +++ b/csrc/jetson/segment/include/yolov8-seg.hpp @@ -0,0 +1,543 @@ +// +// Created by ubuntu on 3/16/23. +// +#ifndef JETSON_SEGMENT_YOLOV8_SEG_HPP +#define JETSON_SEGMENT_YOLOV8_SEG_HPP +#include +#include "common.hpp" +#include "NvInferPlugin.h" + +using namespace seg; + +class YOLOv8_seg +{ +public: + explicit YOLOv8_seg(const std::string& engine_file_path); + ~YOLOv8_seg(); + + void make_pipe(bool warmup = true); + void copy_from_Mat(const cv::Mat& image); + void copy_from_Mat(const cv::Mat& image, cv::Size& size); + void letterbox( + const cv::Mat& image, + cv::Mat& out, + cv::Size& size + ); + void infer(); + void postprocess( + std::vector& objs, + float score_thres = 0.25f, + float iou_thres = 0.65f, + int topk = 100, + int seg_channels = 32, + int seg_h = 160, + int seg_w = 160 + ); + static void draw_objects( + const cv::Mat& image, + cv::Mat& res, + const std::vector& objs, + const std::vector& CLASS_NAMES, + const std::vector>& COLORS, + const std::vector>& MASK_COLORS + ); + int num_bindings; + int num_inputs = 0; + int num_outputs = 0; + std::vector input_bindings; + std::vector output_bindings; + std::vector host_ptrs; + std::vector device_ptrs; + + PreParam pparam; +private: + nvinfer1::ICudaEngine* engine = nullptr; + nvinfer1::IRuntime* runtime = nullptr; + nvinfer1::IExecutionContext* context = nullptr; + cudaStream_t stream = nullptr; + Logger gLogger{ nvinfer1::ILogger::Severity::kERROR }; + +}; + +YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path) +{ + std::ifstream file(engine_file_path, std::ios::binary); + assert(file.good()); + file.seekg(0, std::ios::end); + auto size = file.tellg(); + file.seekg(0, std::ios::beg); + char* trtModelStream = new char[size]; + assert(trtModelStream); + file.read(trtModelStream, size); + file.close(); + initLibNvInferPlugins(&this->gLogger, ""); + this->runtime = nvinfer1::createInferRuntime(this->gLogger); + assert(this->runtime != nullptr); + + this->engine = this->runtime->deserializeCudaEngine(trtModelStream, size); + assert(this->engine != nullptr); + + this->context = this->engine->createExecutionContext(); + + assert(this->context != nullptr); + cudaStreamCreate(&this->stream); + this->num_bindings = this->engine->getNbBindings(); + + for (int i = 0; i < this->num_bindings; ++i) + { + Binding binding; + nvinfer1::Dims dims; + nvinfer1::DataType dtype = this->engine->getBindingDataType(i); + std::string name = this->engine->getBindingName(i); + binding.name = name; + binding.dsize = type_to_size(dtype); + + bool IsInput = engine->bindingIsInput(i); + if (IsInput) + { + this->num_inputs += 1; + dims = this->engine->getProfileDimensions( + i, + 0, + nvinfer1::OptProfileSelector::kMAX); + binding.size = get_size_by_dims(dims); + binding.dims = dims; + this->input_bindings.push_back(binding); + // set max opt shape + this->context->setBindingDimensions(i, dims); + + } + else + { + dims = this->context->getBindingDimensions(i); + binding.size = get_size_by_dims(dims); + binding.dims = dims; + this->output_bindings.push_back(binding); + this->num_outputs += 1; + } + } + +} + +YOLOv8_seg::~YOLOv8_seg() +{ + this->context->destroy(); + this->engine->destroy(); + this->runtime->destroy(); + cudaStreamDestroy(this->stream); + for (auto& ptr : this->device_ptrs) + { + CHECK(cudaFree(ptr)); + } + + for (auto& ptr : this->host_ptrs) + { + CHECK(cudaFreeHost(ptr)); + } +} + +void YOLOv8_seg::make_pipe(bool warmup) +{ + + for (auto& bindings : this->input_bindings) + { + void* d_ptr; + CHECK(cudaMalloc( + &d_ptr, + bindings.size * bindings.dsize) + ); + this->device_ptrs.push_back(d_ptr); + } + + for (auto& bindings : this->output_bindings) + { + void* d_ptr, * h_ptr; + size_t size = bindings.size * bindings.dsize; + CHECK(cudaMalloc( + &d_ptr, + size) + ); + CHECK(cudaHostAlloc( + &h_ptr, + size, + 0) + ); + this->device_ptrs.push_back(d_ptr); + this->host_ptrs.push_back(h_ptr); + } + + if (warmup) + { + for (int i = 0; i < 10; i++) + { + for (auto& bindings : this->input_bindings) + { + size_t size = bindings.size * bindings.dsize; + void* h_ptr = malloc(size); + memset(h_ptr, 0, size); + CHECK(cudaMemcpyAsync( + this->device_ptrs[0], + h_ptr, + size, + cudaMemcpyHostToDevice, + this->stream) + ); + free(h_ptr); + } + this->infer(); + } + printf("model warmup 10 times\n"); + + } +} + +void YOLOv8_seg::letterbox( + const cv::Mat& image, + cv::Mat& out, + cv::Size& size +) +{ + const float inp_h = size.height; + const float inp_w = size.width; + float height = image.rows; + float width = image.cols; + + float r = std::min(inp_h / height, inp_w / width); + int padw = std::round(width * r); + int padh = std::round(height * r); + + cv::Mat tmp; + if ((int)width != padw || (int)height != padh) + { + cv::resize( + image, + tmp, + cv::Size(padw, padh) + ); + } + else + { + tmp = image.clone(); + } + + float dw = inp_w - padw; + float dh = inp_h - padh; + + dw /= 2.0f; + dh /= 2.0f; + int top = int(std::round(dh - 0.1f)); + int bottom = int(std::round(dh + 0.1f)); + int left = int(std::round(dw - 0.1f)); + int right = int(std::round(dw + 0.1f)); + + cv::copyMakeBorder( + tmp, + tmp, + top, + bottom, + left, + right, + cv::BORDER_CONSTANT, + { 114, 114, 114 } + ); + + cv::dnn::blobFromImage(tmp, + out, + 1 / 255.f, + cv::Size(), + cv::Scalar(0, 0, 0), + true, + false, + CV_32F + ); + this->pparam.ratio = 1 / r; + this->pparam.dw = dw; + this->pparam.dh = dh; + this->pparam.height = height; + this->pparam.width = width;; +} + +void YOLOv8_seg::copy_from_Mat(const cv::Mat& image) +{ + cv::Mat nchw; + auto& in_binding = this->input_bindings[0]; + auto width = in_binding.dims.d[3]; + auto height = in_binding.dims.d[2]; + cv::Size size{ width, height }; + this->letterbox( + image, + nchw, + size + ); + + this->context->setBindingDimensions( + 0, + nvinfer1::Dims + { + 4, + { 1, 3, height, width } + } + ); + + CHECK(cudaMemcpyAsync( + this->device_ptrs[0], + nchw.ptr(), + nchw.total() * nchw.elemSize(), + cudaMemcpyHostToDevice, + this->stream) + ); +} + +void YOLOv8_seg::copy_from_Mat(const cv::Mat& image, cv::Size& size) +{ + cv::Mat nchw; + this->letterbox( + image, + nchw, + size + ); + this->context->setBindingDimensions( + 0, + nvinfer1::Dims + { 4, + { 1, 3, size.height, size.width } + } + ); + CHECK(cudaMemcpyAsync( + this->device_ptrs[0], + nchw.ptr(), + nchw.total() * nchw.elemSize(), + cudaMemcpyHostToDevice, + this->stream) + ); +} + +void YOLOv8_seg::infer() +{ + + this->context->enqueueV2( + this->device_ptrs.data(), + this->stream, + nullptr + ); + for (int i = 0; i < this->num_outputs; i++) + { + size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize; + CHECK(cudaMemcpyAsync(this->host_ptrs[i], + this->device_ptrs[i + this->num_inputs], + osize, + cudaMemcpyDeviceToHost, + this->stream) + ); + + } + cudaStreamSynchronize(this->stream); + +} + +void YOLOv8_seg::postprocess(std::vector& objs, + float score_thres, + float iou_thres, + int topk, + int seg_channels, + int seg_h, + int seg_w +) +{ + objs.clear(); + auto input_h = this->input_bindings[0].dims.d[2]; + auto input_w = this->input_bindings[0].dims.d[3]; + auto num_anchors = this->output_bindings[0].dims.d[1]; + auto num_channels = this->output_bindings[0].dims.d[2]; + + auto& dw = this->pparam.dw; + auto& dh = this->pparam.dh; + auto& width = this->pparam.width; + auto& height = this->pparam.height; + auto& ratio = this->pparam.ratio; + + auto* output = static_cast(this->host_ptrs[0]); + cv::Mat protos = cv::Mat(seg_channels, seg_h * seg_w, CV_32F, + static_cast(this->host_ptrs[1])); + + std::vector labels; + std::vector scores; + std::vector bboxes; + std::vector mask_confs; + std::vector indices; + + for (int i = 0; i < num_anchors; i++) + { + float* ptr = output + i * num_channels; + float score = *(ptr + 4); + if (score > score_thres) + { + float x0 = *ptr++ - dw; + float y0 = *ptr++ - dh; + float x1 = *ptr++ - dw; + float y1 = *ptr++ - dh; + + x0 = clamp(x0 * ratio, 0.f, width); + y0 = clamp(y0 * ratio, 0.f, height); + x1 = clamp(x1 * ratio, 0.f, width); + y1 = clamp(y1 * ratio, 0.f, height); + + int label = *(++ptr); + cv::Mat mask_conf = cv::Mat(1, seg_channels, CV_32F, ++ptr); + mask_confs.push_back(mask_conf); + labels.push_back(label); + scores.push_back(score); + bboxes.push_back(cv::Rect_(x0, y0, x1 - x0, y1 - y0)); + + } + } + +#if defined(BATCHED_NMS) + cv::dnn::NMSBoxesBatched( + bboxes, + scores, + labels, + score_thres, + iou_thres, + indices + ); +#else + cv::dnn::NMSBoxes( + bboxes, + scores, + score_thres, + iou_thres, + indices + ); +#endif + + cv::Mat masks; + int cnt = 0; + for (auto& i : indices) + { + if (cnt >= topk) + { + break; + } + cv::Rect tmp = bboxes[i]; + Object obj; + obj.label = labels[i]; + obj.rect = tmp; + obj.prob = scores[i]; + masks.push_back(mask_confs[i]); + objs.push_back(obj); + cnt += 1; + } + + cv::Mat matmulRes = (masks * protos).t(); + cv::Mat maskMat = matmulRes.reshape(indices.size(), { seg_w, seg_h }); + + std::vector maskChannels; + cv::split(maskMat, maskChannels); + int scale_dw = dw / input_w * seg_w; + int scale_dh = dh / input_h * seg_h; + + cv::Rect roi( + scale_dw, + scale_dh, + seg_w - 2 * scale_dw, + seg_h - 2 * scale_dh); + + for (int i = 0; i < indices.size(); i++) + { + cv::Mat dest, mask; + cv::exp(-maskChannels[i], dest); + dest = 1.0 / (1.0 + dest); + dest = dest(roi); + cv::resize( + dest, + mask, + cv::Size((int)width, (int)height), + cv::INTER_LINEAR + ); + objs[i].boxMask = mask(objs[i].rect) > 0.5f; + } + +} + +void YOLOv8_seg::draw_objects(const cv::Mat& image, + cv::Mat& res, + const std::vector& objs, + const std::vector& CLASS_NAMES, + const std::vector>& COLORS, + const std::vector>& MASK_COLORS +) +{ + res = image.clone(); + cv::Mat mask = image.clone(); + for (auto& obj : objs) + { + int idx = obj.label; + cv::Scalar color = cv::Scalar( + COLORS[idx][0], + COLORS[idx][1], + COLORS[idx][2] + ); + cv::Scalar mask_color = cv::Scalar( + MASK_COLORS[idx % 20][0], + MASK_COLORS[idx % 20][1], + MASK_COLORS[idx % 20][2] + ); + cv::rectangle( + res, + obj.rect, + color, + 2 + ); + + char text[256]; + sprintf( + text, + "%s %.1f%%", + CLASS_NAMES[idx].c_str(), + obj.prob * 100 + ); + mask(obj.rect).setTo(mask_color, obj.boxMask); + + int baseLine = 0; + cv::Size label_size = cv::getTextSize( + text, + cv::FONT_HERSHEY_SIMPLEX, + 0.4, + 1, + &baseLine + ); + + int x = (int)obj.rect.x; + int y = (int)obj.rect.y + 1; + + if (y > res.rows) + y = res.rows; + + cv::rectangle( + res, + cv::Rect(x, y, label_size.width, label_size.height + baseLine), + { 0, 0, 255 }, + -1 + ); + + cv::putText( + res, + text, + cv::Point(x, y + label_size.height), + cv::FONT_HERSHEY_SIMPLEX, + 0.4, + { 255, 255, 255 }, + 1 + ); + } + cv::addWeighted( + res, + 0.5, + mask, + 0.8, + 1, + res + ); +} +#endif //JETSON_SEGMENT_YOLOV8_SEG_HPP diff --git a/csrc/jetson/segment/main.cpp b/csrc/jetson/segment/main.cpp new file mode 100644 index 0000000..2fadb9f --- /dev/null +++ b/csrc/jetson/segment/main.cpp @@ -0,0 +1,178 @@ +// +// Created by ubuntu on 3/16/23. +// +#include "chrono" +#include "yolov8-seg.hpp" +#include "opencv2/opencv.hpp" + +const std::vector CLASS_NAMES = { + "person", "bicycle", "car", "motorcycle", "airplane", "bus", + "train", "truck", "boat", "traffic light", "fire hydrant", + "stop sign", "parking meter", "bench", "bird", "cat", + "dog", "horse", "sheep", "cow", "elephant", + "bear", "zebra", "giraffe", "backpack", "umbrella", + "handbag", "tie", "suitcase", "frisbee", "skis", + "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", + "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", + "cup", "fork", "knife", "spoon", "bowl", + "banana", "apple", "sandwich", "orange", "broccoli", + "carrot", "hot dog", "pizza", "donut", "cake", + "chair", "couch", "potted plant", "bed", "dining table", + "toilet", "tv", "laptop", "mouse", "remote", + "keyboard", "cell phone", "microwave", "oven", + "toaster", "sink", "refrigerator", "book", "clock", "vase", + "scissors", "teddy bear", "hair drier", "toothbrush" }; + +const std::vector> COLORS = { + { 0, 114, 189 }, { 217, 83, 25 }, { 237, 177, 32 }, + { 126, 47, 142 }, { 119, 172, 48 }, { 77, 190, 238 }, + { 162, 20, 47 }, { 76, 76, 76 }, { 153, 153, 153 }, + { 255, 0, 0 }, { 255, 128, 0 }, { 191, 191, 0 }, + { 0, 255, 0 }, { 0, 0, 255 }, { 170, 0, 255 }, + { 85, 85, 0 }, { 85, 170, 0 }, { 85, 255, 0 }, + { 170, 85, 0 }, { 170, 170, 0 }, { 170, 255, 0 }, + { 255, 85, 0 }, { 255, 170, 0 }, { 255, 255, 0 }, + { 0, 85, 128 }, { 0, 170, 128 }, { 0, 255, 128 }, + { 85, 0, 128 }, { 85, 85, 128 }, { 85, 170, 128 }, + { 85, 255, 128 }, { 170, 0, 128 }, { 170, 85, 128 }, + { 170, 170, 128 }, { 170, 255, 128 }, { 255, 0, 128 }, + { 255, 85, 128 }, { 255, 170, 128 }, { 255, 255, 128 }, + { 0, 85, 255 }, { 0, 170, 255 }, { 0, 255, 255 }, + { 85, 0, 255 }, { 85, 85, 255 }, { 85, 170, 255 }, + { 85, 255, 255 }, { 170, 0, 255 }, { 170, 85, 255 }, + { 170, 170, 255 }, { 170, 255, 255 }, { 255, 0, 255 }, + { 255, 85, 255 }, { 255, 170, 255 }, { 85, 0, 0 }, + { 128, 0, 0 }, { 170, 0, 0 }, { 212, 0, 0 }, + { 255, 0, 0 }, { 0, 43, 0 }, { 0, 85, 0 }, + { 0, 128, 0 }, { 0, 170, 0 }, { 0, 212, 0 }, + { 0, 255, 0 }, { 0, 0, 43 }, { 0, 0, 85 }, + { 0, 0, 128 }, { 0, 0, 170 }, { 0, 0, 212 }, + { 0, 0, 255 }, { 0, 0, 0 }, { 36, 36, 36 }, + { 73, 73, 73 }, { 109, 109, 109 }, { 146, 146, 146 }, + { 182, 182, 182 }, { 219, 219, 219 }, { 0, 114, 189 }, + { 80, 183, 189 }, { 128, 128, 0 } +}; + +const std::vector> MASK_COLORS = { + { 255, 56, 56 }, { 255, 157, 151 }, { 255, 112, 31 }, + { 255, 178, 29 }, { 207, 210, 49 }, { 72, 249, 10 }, + { 146, 204, 23 }, { 61, 219, 134 }, { 26, 147, 52 }, + { 0, 212, 187 }, { 44, 153, 168 }, { 0, 194, 255 }, + { 52, 69, 147 }, { 100, 115, 255 }, { 0, 24, 236 }, + { 132, 56, 255 }, { 82, 0, 133 }, { 203, 56, 255 }, + { 255, 149, 200 }, { 255, 55, 199 } +}; + +int main(int argc, char** argv) +{ + // cuda:0 + cudaSetDevice(0); + + const std::string engine_file_path{ argv[1] }; + const std::string path{ argv[2] }; + + std::vector imagePathList; + bool isVideo{ false }; + + assert(argc == 3); + + auto yolov8 = new YOLOv8_seg(engine_file_path); + yolov8->make_pipe(true); + + if (IsFile(path)) + { + std::string suffix = path.substr(path.find_last_of('.') + 1); + if ( + suffix == "jpg" || + suffix == "jpeg" || + suffix == "png" + ) + { + imagePathList.push_back(path); + } + else if ( + suffix == "mp4" || + suffix == "avi" || + suffix == "m4v" || + suffix == "mpeg" || + suffix == "mov" || + suffix == "mkv" + ) + { + isVideo = true; + } + else + { + printf("suffix %s is wrong !!!\n", suffix.c_str()); + std::abort(); + } + } + else if (IsFolder(path)) + { + cv::glob(path + "/*.jpg", imagePathList); + } + + cv::Mat res, image; + cv::Size size = cv::Size{ 640, 640 }; + int topk = 100; + int seg_h = 160; + int seg_w = 160; + int seg_channels = 32; + float score_thres = 0.25f; + float iou_thres = 0.65f; + + std::vector objs; + + cv::namedWindow("result", cv::WINDOW_AUTOSIZE); + + if (isVideo) + { + cv::VideoCapture cap(path); + + if (!cap.isOpened()) + { + printf("can not open %s\n", path.c_str()); + return -1; + } + while (cap.read(image)) + { + objs.clear(); + yolov8->copy_from_Mat(image, size); + auto start = std::chrono::system_clock::now(); + yolov8->infer(); + auto end = std::chrono::system_clock::now(); + yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w); + yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS); + auto tc = (double) + std::chrono::duration_cast(end - start).count() / 1000.; + printf("cost %2.4lf ms\n", tc); + cv::imshow("result", res); + if (cv::waitKey(10) == 'q') + { + break; + } + } + } + else + { + for (auto& path : imagePathList) + { + objs.clear(); + image = cv::imread(path); + yolov8->copy_from_Mat(image, size); + auto start = std::chrono::system_clock::now(); + yolov8->infer(); + auto end = std::chrono::system_clock::now(); + yolov8->postprocess(objs, score_thres, iou_thres, topk, seg_channels, seg_h, seg_w); + yolov8->draw_objects(image, res, objs, CLASS_NAMES, COLORS, MASK_COLORS); + auto tc = (double) + std::chrono::duration_cast(end - start).count() / 1000.; + printf("cost %2.4lf ms\n", tc); + cv::imshow("result", res); + cv::waitKey(0); + } + } + cv::destroyAllWindows(); + delete yolov8; + return 0; +} diff --git a/docs/Jetson.md b/docs/Jetson.md new file mode 100644 index 0000000..8d36536 --- /dev/null +++ b/docs/Jetson.md @@ -0,0 +1,141 @@ +# YOLOv8 on Jetson + +Only test on `Jetson-NX 4GB` + +ENVS: +- Jetpack 4.6.3 +- CUDA-10.2 +- CUDNN-8.2.1 +- TensorRT-8.2.1 +- DeepStream-6.0.1 +- OpenCV-4.1.1 +- CMake-3.10.2 + +If you have other environment-related issues, please discuss in issue. + +## End2End Detection + +### 1. Export Detection End2End ONNX + +***!!! Please use the PC to execute the following script !!!*** + +`yolov8s.pt` is your trained pytorch model, or the official pre-trained model. + +Do not use any model other than pytorch model. +Do not use [`build.py`](../build.py) to export engine if you don't know how to install pytorch and other environments on jetson. + +```shell +# Export yolov8s.pt to yolov8s.onnx +python3 export-det.py --weights yolov8s.pt --sim +``` + +***!!! Please use the Jetson to execute the following script !!!*** + +```shell +# Using trtexec tools for export engine +/usr/src/tensorrt/bin/trtexec \ +--onnx=yolov8s.onnx \ +--saveEngine=yolov8s.engine +``` + +After executing the above command, you will get an engine named `yolov8s.engine` . + +### 2. Inference with c++ + +It is highly recommended to use C++ inference on Jetson. +Here is a demo: [`csrc/jetson/detect`](../csrc/jetson/detect) . + +#### Build: + +Please modify `CLASS_NAMES` and `COLORS` in [`main.cpp`](../csrc/jetson/detect/main.cpp) for yourself. + +And build: + +``` shell +export root=${PWD} +cd src/jetson/detect +mkdir build +cmake .. +make +mv yolov8 ${root} +cd ${root} +``` + +Usage: + +``` shell +# infer image +./yolov8 yolov8s.engine data/bus.jpg +# infer images +./yolov8 yolov8s.engine data +# infer video +./yolov8 yolov8s.engine data/test.mp4 # the video path +``` + +## Speedup Segmention + +### 1. Export Segmention Speedup ONNX + +***!!! Please use the PC to execute the following script !!!*** + +`yolov8s-seg.pt` is your trained pytorch model, or the official pre-trained model. + +Do not use any model other than pytorch model. +Do not use [`build.py`](../build.py) to export engine if you don't know how to install pytorch and other environments on jetson. + +```shell +# Export yolov8s-seg.pt to yolov8s-seg.onnx +python3 export-seg.py --weights yolov8s-seg.pt --sim +``` + +***!!! Please use the Jetson to execute the following script !!!*** + +```shell +# Using trtexec tools for export engine +/usr/src/tensorrt/bin/trtexec \ +--onnx=yolov8s-seg.onnx \ +--saveEngine=yolov8s-seg.engine +``` + +After executing the above command, you will get an engine named `yolov8s-seg.engine` . + +### 2. Inference with c++ + +It is highly recommended to use C++ inference on Jetson. +Here is a demo: [`csrc/jetson/segment`](../csrc/jetson/segment) . + +#### Build: + +Please modify `CLASS_NAMES` and `COLORS` and postprocess parameters in [`main.cpp`](../csrc/jetson/segment/main.cpp) for yourself. + +```c++ +int topk = 100; +int seg_h = 160; // yolov8 model proto height +int seg_w = 160; // yolov8 model proto width +int seg_channels = 32; // yolov8 model proto channels +float score_thres = 0.25f; +float iou_thres = 0.65f; +``` + +And build: + +``` shell +export root=${PWD} +cd src/jetson/segment +mkdir build +cmake .. +make +mv yolov8 ${root} +cd ${root} +``` + +Usage: + +``` shell +# infer image +./yolov8 yolov8s.engine data/bus.jpg +# infer images +./yolov8 yolov8s.engine data +# infer video +./yolov8 yolov8s.engine data/test.mp4 # the video path +```