diff --git a/csrc/segment/CMakeLists.txt b/csrc/segment/CMakeLists.txt new file mode 100644 index 0000000..8ff184e --- /dev/null +++ b/csrc/segment/CMakeLists.txt @@ -0,0 +1,55 @@ +cmake_minimum_required(VERSION 2.8.12) + +set(CMAKE_CUDA_ARCHITECTURES 60 61 62 70 72 75 86) +set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc) + +project(yolov8-seg LANGUAGES CXX CUDA) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -O3 -g") +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_BUILD_TYPE Debug) +option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) + +# CUDA +find_package(CUDA REQUIRED) +message(STATUS "CUDA Libs: \n${CUDA_LIBRARIES}\n") +message(STATUS "CUDA Headers: \n${CUDA_INCLUDE_DIRS}\n") + +# OpenCV +find_package(OpenCV REQUIRED) +message(STATUS "OpenCV Libs: \n${OpenCV_LIBS}\n") +message(STATUS "OpenCV Libraries: \n${OpenCV_LIBRARIES}\n") +message(STATUS "OpenCV Headers: \n${OpenCV_INCLUDE_DIRS}\n") + +# TensorRT +set(TensorRT_INCLUDE_DIRS /usr/include/x86_64-linux-gnu) +set(TensorRT_LIBRARIES /usr/lib/x86_64-linux-gnu) + + +message(STATUS "TensorRT Libs: \n${TensorRT_LIBRARIES}\n") +message(STATUS "TensorRT Headers: \n${TensorRT_INCLUDE_DIRS}\n") + +list(APPEND INCLUDE_DIRS + ${CUDA_INCLUDE_DIRS} + ${OpenCV_INCLUDE_DIRS} + ${TensorRT_INCLUDE_DIRS} + ./include + ) + +list(APPEND ALL_LIBS + ${CUDA_LIBRARIES} + ${OpenCV_LIBRARIES} + ${TensorRT_LIBRARIES} + ) + +include_directories(${INCLUDE_DIRS}) + +add_executable(${PROJECT_NAME} + main.cpp + include/yolov8-seg.hpp + include/config.h + include/utils.h + ) + +target_link_directories(${PROJECT_NAME} PUBLIC ${ALL_LIBS}) +target_link_libraries(${PROJECT_NAME} PRIVATE nvinfer nvinfer_plugin cudart ${OpenCV_LIBS}) diff --git a/csrc/segment/include/config.h b/csrc/segment/include/config.h new file mode 100644 index 0000000..7a4139f --- /dev/null +++ b/csrc/segment/include/config.h @@ -0,0 +1,107 @@ +// +// Created by ubuntu on 1/16/23. +// + +#ifndef YOLOV8_TENSORRT_CSRC_SEGMENT_INCLUDE_CONFIG_H +#define YOLOV8_TENSORRT_CSRC_SEGMENT_INCLUDE_CONFIG_H +#include "opencv2/opencv.hpp" +namespace seg +{ + const int DEVICE = 0; + + const int INPUT_W = 640; + const int INPUT_H = 640; + const int NUM_INPUT = 1; + const int NUM_OUTPUT = 2; + const int NUM_PROPOSAL = 8400; // feature map 20*20+40*40+80*80 + const int NUM_SEG_C = 32; // seg channel + const int NUM_COLS = 6 + NUM_SEG_C; // x0 y0 x1 y1 score label 32 + + const int SEG_W = 160; + const int SEG_H = 160; + + // thresholds + const float CONF_THRES = 0.25; + const float IOU_THRES = 0.65; + const float MASK_THRES = 0.5; + + // distance + const float DIS = 7680.f; + + const int NUM_BINDINGS = NUM_INPUT + NUM_OUTPUT; + const cv::Scalar PAD_COLOR = { 114, 114, 114 }; + const cv::Scalar RECT_COLOR = cv::Scalar(0, 0, 255); + const cv::Scalar TXT_COLOR = cv::Scalar(255, 255, 255); + + const char* INPUT = "images"; + const char* OUTPUT = "outputs"; + const char* PROTO = "proto"; + + const char* CLASS_NAMES[] = { + "person", "bicycle", "car", "motorcycle", "airplane", "bus", + "train", "truck", "boat", "traffic light", "fire hydrant", + "stop sign", "parking meter", "bench", "bird", "cat", + "dog", "horse", "sheep", "cow", "elephant", + "bear", "zebra", "giraffe", "backpack", "umbrella", + "handbag", "tie", "suitcase", "frisbee", "skis", + "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", + "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", + "cup", "fork", "knife", "spoon", "bowl", + "banana", "apple", "sandwich", "orange", "broccoli", + "carrot", "hot dog", "pizza", "donut", "cake", + "chair", "couch", "potted plant", "bed", "dining table", + "toilet", "tv", "laptop", "mouse", "remote", + "keyboard", "cell phone", "microwave", "oven", + "toaster", "sink", "refrigerator", "book", "clock", "vase", + "scissors", "teddy bear", "hair drier", "toothbrush" }; + + const unsigned int COLORS[80][3] = { + { 0, 114, 189 }, { 217, 83, 25 }, { 237, 177, 32 }, + { 126, 47, 142 }, { 119, 172, 48 }, { 77, 190, 238 }, + { 162, 20, 47 }, { 76, 76, 76 }, { 153, 153, 153 }, + { 255, 0, 0 }, { 255, 128, 0 }, { 191, 191, 0 }, + { 0, 255, 0 }, { 0, 0, 255 }, { 170, 0, 255 }, + { 85, 85, 0 }, { 85, 170, 0 }, { 85, 255, 0 }, + { 170, 85, 0 }, { 170, 170, 0 }, { 170, 255, 0 }, + { 255, 85, 0 }, { 255, 170, 0 }, { 255, 255, 0 }, + { 0, 85, 128 }, { 0, 170, 128 }, { 0, 255, 128 }, + { 85, 0, 128 }, { 85, 85, 128 }, { 85, 170, 128 }, + { 85, 255, 128 }, { 170, 0, 128 }, { 170, 85, 128 }, + { 170, 170, 128 }, { 170, 255, 128 }, { 255, 0, 128 }, + { 255, 85, 128 }, { 255, 170, 128 }, { 255, 255, 128 }, + { 0, 85, 255 }, { 0, 170, 255 }, { 0, 255, 255 }, + { 85, 0, 255 }, { 85, 85, 255 }, { 85, 170, 255 }, + { 85, 255, 255 }, { 170, 0, 255 }, { 170, 85, 255 }, + { 170, 170, 255 }, { 170, 255, 255 }, { 255, 0, 255 }, + { 255, 85, 255 }, { 255, 170, 255 }, { 85, 0, 0 }, + { 128, 0, 0 }, { 170, 0, 0 }, { 212, 0, 0 }, + { 255, 0, 0 }, { 0, 43, 0 }, { 0, 85, 0 }, + { 0, 128, 0 }, { 0, 170, 0 }, { 0, 212, 0 }, + { 0, 255, 0 }, { 0, 0, 43 }, { 0, 0, 85 }, + { 0, 0, 128 }, { 0, 0, 170 }, { 0, 0, 212 }, + { 0, 0, 255 }, { 0, 0, 0 }, { 36, 36, 36 }, + { 73, 73, 73 }, { 109, 109, 109 }, { 146, 146, 146 }, + { 182, 182, 182 }, { 219, 219, 219 }, { 0, 114, 189 }, + { 80, 183, 189 }, { 128, 128, 0 } + }; + + const unsigned int MASK_COLORS[20][3] = { + { 255, 56, 56 }, { 255, 157, 151 }, { 255, 112, 31 }, + { 255, 178, 29 }, { 207, 210, 49 }, { 72, 249, 10 }, + { 146, 204, 23 }, { 61, 219, 134 }, { 26, 147, 52 }, + { 0, 212, 187 }, { 44, 153, 168 }, { 0, 194, 255 }, + { 52, 69, 147 }, { 100, 115, 255 }, { 0, 24, 236 }, + { 132, 56, 255 }, { 82, 0, 133 }, { 203, 56, 255 }, + { 255, 149, 200 }, { 255, 55, 199 } + }; + + struct Object + { + cv::Rect_ rect; + int label = 0; + float prob = 0.0; + cv::Mat boxMask; + }; + +} +#endif //YOLOV8_TENSORRT_CSRC_SEGMENT_INCLUDE_CONFIG_H diff --git a/csrc/segment/include/utils.h b/csrc/segment/include/utils.h new file mode 100644 index 0000000..6f5e702 --- /dev/null +++ b/csrc/segment/include/utils.h @@ -0,0 +1,133 @@ +// +// Created by ubuntu on 1/10/23. +// + +#ifndef YOLOV8_CSRC_SEGMENT_INCLUDE_UTILS_H +#define YOLOV8_CSRC_SEGMENT_INCLUDE_UTILS_H +#include +#include +#include +#include +#include +#include "NvInfer.h" + +#define CHECK(call) \ +do \ +{ \ + const cudaError_t error_code = call; \ + if (error_code != cudaSuccess) \ + { \ + printf("CUDA Error:\n"); \ + printf(" File: %s\n", __FILE__); \ + printf(" Line: %d\n", __LINE__); \ + printf(" Error code: %d\n", error_code); \ + printf(" Error text: %s\n", \ + cudaGetErrorString(error_code)); \ + exit(1); \ + } \ +} while (0) + +class Logger : public nvinfer1::ILogger +{ +public: + nvinfer1::ILogger::Severity reportableSeverity; + + explicit Logger(nvinfer1::ILogger::Severity severity = nvinfer1::ILogger::Severity::kINFO) : + reportableSeverity(severity) + { + } + + void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override + { + if (severity > reportableSeverity) + { + return; + } + switch (severity) + { + case nvinfer1::ILogger::Severity::kINTERNAL_ERROR: + std::cerr << "INTERNAL_ERROR: "; + break; + case nvinfer1::ILogger::Severity::kERROR: + std::cerr << "ERROR: "; + break; + case nvinfer1::ILogger::Severity::kWARNING: + std::cerr << "WARNING: "; + break; + case nvinfer1::ILogger::Severity::kINFO: + std::cerr << "INFO: "; + break; + default: + std::cerr << "VERBOSE: "; + break; + } + std::cerr << msg << std::endl; + } +}; + +inline int get_size_by_dims(const nvinfer1::Dims& dims) +{ + int size = 1; + for (int i = 0; i < dims.nbDims; i++) + { + size *= dims.d[i]; + } + return size; +} + +inline int DataTypeToSize(const nvinfer1::DataType& dataType) +{ + switch (dataType) + { + case nvinfer1::DataType::kFLOAT: + return sizeof(float); + case nvinfer1::DataType::kHALF: + return 2; + case nvinfer1::DataType::kINT8: + return sizeof(int8_t); + case nvinfer1::DataType::kINT32: + return sizeof(int32_t); + case nvinfer1::DataType::kBOOL: + return sizeof(bool); + default: + return sizeof(float); + } +} + +inline float clamp(const float val, const float minVal = 0.f, const float maxVal = 1280.f) +{ + assert(minVal <= maxVal); + return std::min(maxVal, std::max(minVal, val)); +} + +inline bool IsPathExist(const std::string& path) +{ + if (access(path.c_str(), 0) == F_OK) + { + return true; + } + return false; +} + +inline bool IsFile(const std::string& path) +{ + if (!IsPathExist(path)) + { + printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str()); + return false; + } + struct stat buffer; + return (stat(path.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode)); +} + +inline bool IsFolder(const std::string& path) +{ + if (!IsPathExist(path)) + { + return false; + } + struct stat buffer; + return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode)); +} + +#endif //YOLOV8_CSRC_SEGMENT_INCLUDE_UTILS_H diff --git a/csrc/segment/include/yolov8-seg.hpp b/csrc/segment/include/yolov8-seg.hpp new file mode 100644 index 0000000..2b941b1 --- /dev/null +++ b/csrc/segment/include/yolov8-seg.hpp @@ -0,0 +1,329 @@ +// +// Created by ubuntu on 1/8/23. +// +#include "config.h" +#include "utils.h" +#include +#include "NvInferPlugin.h" + +using namespace seg; + +class YOLOv8_seg +{ +public: + explicit YOLOv8_seg(const std::string& engine_file_path); + ~YOLOv8_seg(); + + void make_pipe(bool warmup = true); + void copy_from_Mat(const cv::Mat& image); + void infer(); + void postprocess(std::vector& objs); + + size_t in_size = 1 * 3 * INPUT_W * INPUT_H; + float w = INPUT_W; + float h = INPUT_H; + float ratio = 1.0f; + float dw = 0.f; + float dh = 0.f; + std::array, NUM_OUTPUT> out_sizes{}; + std::array outputs{}; +private: + nvinfer1::ICudaEngine* engine = nullptr; + nvinfer1::IRuntime* runtime = nullptr; + nvinfer1::IExecutionContext* context = nullptr; + cudaStream_t stream = nullptr; + std::array buffs{}; + Logger gLogger{ nvinfer1::ILogger::Severity::kERROR }; + +}; + +YOLOv8_seg::YOLOv8_seg(const std::string& engine_file_path) +{ + std::ifstream file(engine_file_path, std::ios::binary); + assert(file.good()); + file.seekg(0, std::ios::end); + auto size = file.tellg(); + std::ostringstream fmt; + + file.seekg(0, std::ios::beg); + char* trtModelStream = new char[size]; + assert(trtModelStream); + file.read(trtModelStream, size); + file.close(); + initLibNvInferPlugins(&this->gLogger, ""); + this->runtime = nvinfer1::createInferRuntime(this->gLogger); + assert(this->runtime != nullptr); + + this->engine = this->runtime->deserializeCudaEngine(trtModelStream, size); + assert(this->engine != nullptr); + + this->context = this->engine->createExecutionContext(); + + assert(this->context != nullptr); + cudaStreamCreate(&this->stream); + +} + +YOLOv8_seg::~YOLOv8_seg() +{ + this->context->destroy(); + this->engine->destroy(); + this->runtime->destroy(); + cudaStreamDestroy(this->stream); + for (auto& ptr : this->buffs) + { + CHECK(cudaFree(ptr)); + } + + for (auto& ptr : this->outputs) + { + CHECK(cudaFreeHost(ptr)); + } + +} +void YOLOv8_seg::make_pipe(bool warmup) +{ + const nvinfer1::Dims input_dims = this->engine->getBindingDimensions( + this->engine->getBindingIndex(INPUT) + ); + this->in_size = get_size_by_dims(input_dims); + CHECK(cudaMalloc(&this->buffs[0], this->in_size * sizeof(float))); + + this->context->setBindingDimensions(0, input_dims); + + const int32_t output_idx = this->engine->getBindingIndex(OUTPUT); + const nvinfer1::Dims output_dims = this->context->getBindingDimensions(output_idx); + this->out_sizes[output_idx - NUM_INPUT].first = get_size_by_dims(output_dims); + this->out_sizes[output_idx - NUM_INPUT].second = DataTypeToSize( + this->engine->getBindingDataType(output_idx)); + + const int32_t proto_idx = this->engine->getBindingIndex(PROTO); + const nvinfer1::Dims proto_dims = this->context->getBindingDimensions(proto_idx); + + this->out_sizes[proto_idx - NUM_INPUT].first = get_size_by_dims(proto_dims); + this->out_sizes[proto_idx - NUM_INPUT].second = DataTypeToSize( + this->engine->getBindingDataType(proto_idx)); + + for (int i = 0; i < NUM_OUTPUT; i++) + { + const int osize = this->out_sizes[i].first * out_sizes[i].second; + CHECK(cudaHostAlloc(&this->outputs[i], osize, 0)); + CHECK(cudaMalloc(&this->buffs[NUM_INPUT + i], osize)); + } + if (warmup) + { + for (int i = 0; i < 10; i++) + { + size_t isize = this->in_size * sizeof(float); + auto* tmp = new float[isize]; + + CHECK(cudaMemcpyAsync(this->buffs[0], + tmp, + isize, + cudaMemcpyHostToDevice, + this->stream)); + this->infer(); + } + printf("model warmup 10 times\n"); + + } +} + +void YOLOv8_seg::copy_from_Mat(const cv::Mat& image) +{ + float height = (float)image.rows; + float width = (float)image.cols; + + float r = std::min(INPUT_H / height, INPUT_W / width); + + int padw = (int)std::round(width * r); + int padh = (int)std::round(height * r); + + cv::Mat tmp; + if ((int)width != padw || (int)height != padh) + { + cv::resize(image, tmp, cv::Size(padw, padh)); + } + else + { + tmp = image.clone(); + } + + float _dw = INPUT_W - padw; + float _dh = INPUT_H - padh; + + _dw /= 2.0f; + _dh /= 2.0f; + int top = int(std::round(_dh - 0.1f)); + int bottom = int(std::round(_dh + 0.1f)); + int left = int(std::round(_dw - 0.1f)); + int right = int(std::round(_dw + 0.1f)); + cv::copyMakeBorder(tmp, tmp, top, bottom, left, right, cv::BORDER_CONSTANT, PAD_COLOR); + cv::dnn::blobFromImage(tmp, + tmp, + 1 / 255.f, + cv::Size(), + cv::Scalar(0, 0, 0), + true, + false, + CV_32F); + CHECK(cudaMemcpyAsync(this->buffs[0], + tmp.ptr(), + this->in_size * sizeof(float), + cudaMemcpyHostToDevice, + this->stream)); + + this->ratio = 1 / r; + this->dw = _dw; + this->dh = _dh; + this->w = width; + this->h = height; +} + +void YOLOv8_seg::infer() +{ + this->context->enqueueV2(buffs.data(), this->stream, nullptr); + for (int i = 0; i < NUM_OUTPUT; i++) + { + const int osize = this->out_sizes[i].first * out_sizes[i].second; + CHECK(cudaMemcpyAsync(this->outputs[i], + this->buffs[NUM_INPUT + i], + osize, + cudaMemcpyDeviceToHost, + this->stream)); + } + cudaStreamSynchronize(this->stream); + +} + +void YOLOv8_seg::postprocess(std::vector& objs) +{ + objs.clear(); + auto* output = static_cast(this->outputs[0]); // x0 y0 x1 y1 s l *32 + cv::Mat protos = cv::Mat(NUM_SEG_C, SEG_W * SEG_H, CV_32F, + static_cast(this->outputs[1])); + + std::vector labels; + std::vector scores; + std::vector bboxes; + std::vector mask_confs; + + for (int i = 0; i < NUM_PROPOSAL; i++) + { + float* ptr = output + i * NUM_COLS; + float score = *(ptr + 4); + if (score > CONF_THRES) + { + float x0 = *ptr++ - this->dw; + float y0 = *ptr++ - this->dh; + float x1 = *ptr++ - this->dw; + float y1 = *ptr++ - this->dh; + + x0 = clamp(x0 * this->ratio, 0.f, this->w); + y0 = clamp(y0 * this->ratio, 0.f, this->h); + x1 = clamp(x1 * this->ratio, 0.f, this->w); + y1 = clamp(y1 * this->ratio, 0.f, this->h); + + int label = *(++ptr); + cv::Mat mask_conf = cv::Mat(1, NUM_SEG_C, CV_32F, ++ptr); + mask_confs.push_back(mask_conf); + labels.push_back(label); + scores.push_back(score); + +#if defined(BATCHED_NMS) + bboxes.push_back(cv::Rect_(x0, y0, x1 - x0, y1 - y0)); +#else + bboxes.push_back(cv::Rect_(x0 + label * DIS, + y0 + label * DIS, + x1 - x0, + y1 - y0)); +#endif + } + } + std::vector indices; +#if defined(BATCHED_NMS) + cv::dnn::NMSBoxesBatched(bboxes, scores, labels, CONF_THRES, IOU_THRES, indices); +#else + cv::dnn::NMSBoxes(bboxes, scores, CONF_THRES, IOU_THRES, indices); +#endif + + cv::Mat masks; + + for (auto& i : indices) + { +#if defined(BATCHED_NMS) + cv::Rect tmp = bboxes[i]; +#else + cv::Rect tmp = { (int)(bboxes[i].x - labels[i] * DIS), + (int)(bboxes[i].y - labels[i] * DIS), + bboxes[i].width, + bboxes[i].height }; +#endif + + Object obj; + obj.label = labels[i]; + obj.rect = tmp; + obj.prob = scores[i]; + masks.push_back(mask_confs[i]); + objs.push_back(obj); + } + + cv::Mat matmulRes = (masks * protos).t(); + cv::Mat maskMat = matmulRes.reshape(indices.size(), { SEG_W, SEG_H }); + + std::vector maskChannels; + cv::split(maskMat, maskChannels); + int scale_dw = this->dw / INPUT_W * SEG_W; + int scale_dh = this->dh / INPUT_H * SEG_H; + + cv::Rect roi( + scale_dw, + scale_dh, + SEG_W - 2 * scale_dw, + SEG_H - 2 * scale_dh); + + for (int i = 0; i < indices.size(); i++) + { + cv::Mat dest, mask; + cv::exp(-maskChannels[i], dest); + dest = 1.0 / (1.0 + dest); + dest = dest(roi); + cv::resize(dest, mask, cv::Size((int)this->w, (int)this->h), cv::INTER_LINEAR); + objs[i].boxMask = mask(objs[i].rect) > MASK_THRES; + } + +} + +static void draw_objects(const cv::Mat& image, cv::Mat& res, const std::vector& objs) +{ + res = image.clone(); + cv::Mat mask = image.clone(); + for (auto& obj : objs) + { + int idx = obj.label; + cv::Scalar color = cv::Scalar(COLORS[idx][0], COLORS[idx][1], COLORS[idx][2]); + cv::Scalar mask_color = cv::Scalar( + MASK_COLORS[idx % 20][0], MASK_COLORS[idx % 20][1], MASK_COLORS[idx % 20][2]); + cv::rectangle(res, obj.rect, color, 2); + + char text[256]; + sprintf(text, "%s %.1f%%", CLASS_NAMES[idx], obj.prob * 100); + mask(obj.rect).setTo(mask_color, obj.boxMask); + + int baseLine = 0; + cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine); + + int x = (int)obj.rect.x; + int y = (int)obj.rect.y + 1; + + if (y > res.rows) + y = res.rows; + + cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), RECT_COLOR, -1); + + cv::putText(res, text, cv::Point(x, y + label_size.height), + cv::FONT_HERSHEY_SIMPLEX, 0.4, TXT_COLOR, 1); + } + cv::addWeighted(res, 0.5, mask, 0.8, 1, res); + +} diff --git a/csrc/segment/main.cpp b/csrc/segment/main.cpp new file mode 100644 index 0000000..00f5b36 --- /dev/null +++ b/csrc/segment/main.cpp @@ -0,0 +1,86 @@ +// +// Created by ubuntu on 1/8/23. +// +#include "include/yolov8-seg.hpp" +int main(int argc, char** argv) +{ + cudaSetDevice(DEVICE); + + const std::string engine_file_path{ argv[1] }; + const std::string path{ argv[2] }; + std::vector imagePathList; + bool isVideo{ false }; + if (IsFile(path)) + { + std::string suffix = path.substr(path.find_last_of('.') + 1); + if (suffix == "jpg") + { + imagePathList.push_back(path); + } + else if (suffix == "mp4") + { + isVideo = true; + } + } + else if (IsFolder(path)) + { + cv::glob(path + "/*.jpg", imagePathList); + } + + auto* yolov8 = new YOLOv8_seg(engine_file_path); + yolov8->make_pipe(true); + cv::Mat res; + cv::namedWindow("result", cv::WINDOW_AUTOSIZE); + if (isVideo) + { + cv::VideoCapture cap(path); + cv::Mat image; + if (!cap.isOpened()) + { + printf("can not open ...\n"); + return -1; + } + double fp_ = cap.get(cv::CAP_PROP_FPS); + int fps = round(1000.0 / fp_); + while (cap.read(image)) + { + auto start = std::chrono::system_clock::now(); + yolov8->copy_from_Mat(image); + yolov8->infer(); + std::vector objs; + yolov8->postprocess(objs); + draw_objects(image, res, objs); + auto end = std::chrono::system_clock::now(); + auto tc = std::chrono::duration_cast(end - start).count() / 1000.f; + cv::imshow("result", res); + printf("cost %2.4f ms\n", tc); + if (cv::waitKey(fps) == 'q') + { + break; + } + } + } + else + { + for (auto path : imagePathList) + { + cv::Mat image = cv::imread(path); + yolov8->copy_from_Mat(image); + auto start = std::chrono::system_clock::now(); + yolov8->infer(); + auto end = std::chrono::system_clock::now(); + auto tc = std::chrono::duration_cast(end - start).count() / 1000.f; + + printf("infer %-20s\tcost %2.4f ms\n", path.c_str(), tc); + + std::vector objs; + yolov8->postprocess(objs); + draw_objects(image, res, objs); + cv::imshow("result", res); + cv::waitKey(0); + } + } + cv::destroyAllWindows(); + delete yolov8; + return 0; +} diff --git a/export_seg.py b/export_seg.py index 5d52dab..b86be43 100644 --- a/export_seg.py +++ b/export_seg.py @@ -53,13 +53,12 @@ def main(args): model(fake_input) save_path = args.weights.replace('.pt', '.onnx') with BytesIO() as f: - torch.onnx.export( - model, - fake_input, - f, - opset_version=args.opset, - input_names=['images'], - output_names=['bboxes', 'scores', 'labels', 'maskconf', 'proto']) + torch.onnx.export(model, + fake_input, + f, + opset_version=args.opset, + input_names=['images'], + output_names=['outputs', 'proto']) f.seek(0) onnx_model = onnx.load(f) onnx.checker.check_model(onnx_model) diff --git a/infer.py b/infer.py index a074af2..b281b2c 100644 --- a/infer.py +++ b/infer.py @@ -102,7 +102,7 @@ def main(args): # set desired output names order if args.seg: - Engine.set_desired(['bboxes', 'scores', 'labels', 'maskconf', 'proto']) + Engine.set_desired(['outputs', 'proto']) else: Engine.set_desired(['num_dets', 'bboxes', 'scores', 'labels']) @@ -180,19 +180,21 @@ def crop_mask(masks: Tensor, bboxes: Tensor) -> Tensor: def seg_postprocess( - data: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor], + data: Tuple[Tensor], shape: Union[Tuple, List], conf_thres: float = 0.25, iou_thres: float = 0.65) -> Tuple[Tensor, Tensor, Tensor, List]: - assert len(data) == 5 + assert len(data) == 2 h, w = shape[0] // 4, shape[1] // 4 # 4x downsampling - bboxes, scores, labels, maskconf, proto = (i[0] for i in data) + outputs, proto = (i[0] for i in data) + bboxes, scores, labels, maskconf = outputs.split([4, 1, 1, 32], 1) + scores, labels = scores.squeeze(), labels.squeeze() select = scores > conf_thres bboxes, scores, labels, maskconf = bboxes[select], scores[select], labels[ select], maskconf[select] idx = batched_nms(bboxes, scores, labels, iou_thres) bboxes, scores, labels, maskconf = bboxes[idx], scores[idx], labels[ - idx], maskconf[idx] + idx].int(), maskconf[idx] masks = (maskconf @ proto).view(-1, h, w) masks = crop_mask(masks, bboxes / 4.) masks = F.interpolate(masks[None], diff --git a/models/common.py b/models/common.py index 9fcdb87..1d563f8 100644 --- a/models/common.py +++ b/models/common.py @@ -140,7 +140,8 @@ class PostSeg(nn.Module): [self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients box, score, cls = self.forward_det(x) - return box, score, cls, mc.transpose(1, 2), p.flatten(2) + out = torch.cat([box, score, cls, mc.transpose(1, 2)], 2) + return out, p.flatten(2) def forward_det(self, x): shape = x[0].shape @@ -160,7 +161,7 @@ class PostSeg(nn.Module): box0, box1 = -box[:, :2, ...], box[:, 2:, ...] box = self.anchors.repeat(b, 2, 1) + torch.cat([box0, box1], 1) box = box * self.strides - score, cls = cls.transpose(1, 2).max(dim=-1) + score, cls = cls.transpose(1, 2).max(dim=-1, keepdim=True) return box.transpose(1, 2), score, cls