Merge pull request #213 from triple-Mu/triplemu/cls

support yolov8-cls
pull/192/merge
triple Mu 7 months ago committed by GitHub
commit a9236fc159
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 4
      README.md
  2. 140
      config.py
  3. 56
      csrc/cls/normal/CMakeLists.txt
  4. 132
      csrc/cls/normal/include/common.hpp
  5. 221
      csrc/cls/normal/include/yolov8-cls.hpp
  6. 1092
      csrc/cls/normal/main.cpp
  7. 129
      docs/Cls.md
  8. 79
      infer-cls-without-torch.py
  9. 74
      infer-cls.py
  10. 21
      infer-det-without-torch.py
  11. 21
      infer-det.py
  12. 3
      models/torch_utils.py
  13. 5
      models/utils.py

@ -212,6 +212,10 @@ Please see more information in [`Segment.md`](docs/Segment.md)
Please see more information in [`Pose.md`](docs/Pose.md) Please see more information in [`Pose.md`](docs/Pose.md)
# TensorRT Cls Deploy
Please see more information in [`Cls.md`](docs/Cls.md)
# DeepStream Detection Deploy # DeepStream Detection Deploy
See more in [`README.md`](csrc/deepstream/README.md) See more in [`README.md`](csrc/deepstream/README.md)

@ -4,8 +4,144 @@ import numpy as np
random.seed(0) random.seed(0)
# classification model classes
CLASSES_CLS = (
['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead shark', 'electric ray', 'stingray', 'cock',
'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'American robin', 'bulbul',
'jay', 'magpie', 'chickadee', 'American dipper', 'kite', 'bald eagle', 'vulture', 'great grey owl',
'fire salamander', 'smooth newt', 'newt', 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog',
'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle', 'mud turtle', 'terrapin', 'box turtle',
'banded gecko', 'green iguana', 'Carolina anole', 'desert grassland whiptail lizard', 'agama',
'frilled-necked lizard', 'alligator lizard', 'Gila monster', 'European green lizard', 'chameleon', 'Komodo dragon',
'Nile crocodile', 'American alligator', 'triceratops', 'worm snake', 'ring-necked snake',
'eastern hog-nosed snake', 'smooth green snake', 'kingsnake', 'garter snake', 'water snake', 'vine snake',
'night snake', 'boa constrictor', 'African rock python', 'Indian cobra', 'green mamba', 'sea snake',
'Saharan horned viper', 'eastern diamondback rattlesnake', 'sidewinder', 'trilobite', 'harvestman', 'scorpion',
'yellow garden spider', 'barn spider', 'European garden spider', 'southern black widow', 'tarantula',
'wolf spider', 'tick', 'centipede', 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peacock',
'quail', 'partridge', 'grey parrot', 'macaw', 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater',
'hornbill', 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser', 'goose', 'black swan', 'tusker',
'echidna', 'platypus', 'wallaby', 'koala', 'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm',
'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton', 'chambered nautilus', 'Dungeness crab', 'rock crab',
'fiddler crab', 'red king crab', 'American lobster', 'spiny lobster', 'crayfish', 'hermit crab', 'isopod',
'white stork', 'black stork', 'spoonbill', 'flamingo', 'little blue heron', 'great egret', 'bittern',
'crane (bird)', 'limpkin', 'common gallinule', 'American coot', 'bustard', 'ruddy turnstone', 'dunlin',
'common redshank', 'dowitcher', 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale',
'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', 'Maltese', 'Pekingese', 'Shih Tzu',
'King Charles Spaniel', 'Papillon', 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound', 'Beagle',
'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound', 'Treeing Walker Coonhound', 'English foxhound',
'Redbone Coonhound', 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', 'Ibizan Hound',
'Norwegian Elkhound', 'Otterhound', 'Saluki', 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier',
'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', 'Kerry Blue Terrier', 'Irish Terrier',
'Norfolk Terrier', 'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier', 'Lakeland Terrier',
'Sealyham Terrier', 'Airedale Terrier', 'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier',
'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer', 'Standard Schnauzer', 'Scottish Terrier',
'Tibetan Terrier', 'Australian Silky Terrier', 'Soft-coated Wheaten Terrier', 'West Highland White Terrier',
'Lhasa Apso', 'Flat-Coated Retriever', 'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever',
'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla', 'English Setter', 'Irish Setter',
'Gordon Setter', 'Brittany', 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel',
'Cocker Spaniels', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz', 'Schipperke', 'Groenendael', 'Malinois',
'Briard', 'Australian Kelpie', 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', 'Border Collie',
'Bouvier des Flandres', 'Rottweiler', 'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher',
'Greater Swiss Mountain Dog', 'Bernese Mountain Dog', 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer',
'Bullmastiff', 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. Bernard', 'husky', 'Alaskan Malamute',
'Siberian Husky', 'Dalmatian', 'Affenpinscher', 'Basenji', 'pug', 'Leonberger', 'Newfoundland',
'Pyrenean Mountain Dog', 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'Griffon Bruxellois',
'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', 'Toy Poodle', 'Miniature Poodle', 'Standard Poodle',
'Mexican hairless dog', 'grey wolf', 'Alaskan tundra wolf', 'red wolf', 'coyote', 'dingo', 'dhole',
'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox', 'grey fox', 'tabby cat', 'tiger cat',
'Persian cat', 'Siamese cat', 'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar', 'lion',
'tiger', 'cheetah', 'brown bear', 'American black bear', 'polar bear', 'sloth bear', 'mongoose', 'meerkat',
'tiger beetle', 'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle', 'dung beetle', 'rhinoceros beetle',
'weevil', 'fly', 'bee', 'ant', 'grasshopper', 'cricket', 'stick insect', 'cockroach', 'mantis', 'cicada',
'leafhopper', 'lacewing', 'dragonfly', 'damselfly', 'red admiral', 'ringlet', 'monarch butterfly', 'small white',
'sulphur butterfly', 'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber', 'cottontail rabbit',
'hare', 'Angora rabbit', 'hamster', 'porcupine', 'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel',
'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox', 'water buffalo', 'bison', 'ram', 'bighorn sheep',
'Alpine ibex', 'hartebeest', 'impala', 'gazelle', 'dromedary', 'llama', 'weasel', 'mink', 'European polecat',
'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', 'three-toed sloth', 'orangutan', 'gorilla',
'chimpanzee', 'gibbon', 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur',
'black-and-white colobus', 'proboscis monkey', 'marmoset', 'white-headed capuchin', 'howler monkey', 'titi',
"Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur', 'indri', 'Asian elephant',
'African bush elephant', 'red panda', 'giant panda', 'snoek', 'eel', 'coho salmon', 'rock beauty', 'clownfish',
'sturgeon', 'garfish', 'lionfish', 'pufferfish', 'abacus', 'abaya', 'academic gown', 'accordion',
'acoustic guitar', 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', 'amphibious vehicle',
'analog clock', 'apiary', 'apron', 'waste container', 'assault rifle', 'backpack', 'bakery', 'balance beam',
'balloon', 'ballpoint pen', 'Band-Aid', 'banjo', 'baluster', 'barbell', 'barber chair', 'barbershop', 'barn',
'barometer', 'barrel', 'wheelbarrow', 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap',
'bath towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker', 'military cap', 'beer bottle', 'beer glass',
'bell-cot', 'bib', 'tandem bicycle', 'bikini', 'ring binder', 'binoculars', 'birdhouse', 'boathouse', 'bobsleigh',
'bolo tie', 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'bow', 'bow tie', 'brass', 'bra', 'breakwater',
'breastplate', 'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train', 'butcher shop', 'taxicab',
'cauldron', 'candle', 'cannon', 'canoe', 'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit', 'carton',
'car wheel', 'automated teller machine', 'cassette', 'cassette player', 'castle', 'catamaran', 'CD player',
'cello', 'mobile phone', 'chain', 'chain-link fence', 'chain mail', 'chainsaw', 'chest', 'chiffonier', 'chime',
'china cabinet', 'Christmas stocking', 'church', 'movie theater', 'cleaver', 'cliff dwelling', 'cloak', 'clogs',
'cocktail shaker', 'coffee mug', 'coffeemaker', 'coil', 'combination lock', 'computer keyboard',
'confectionery store', 'container ship', 'convertible', 'corkscrew', 'cornet', 'cowboy boot', 'cowboy hat',
'cradle', 'crane (machine)', 'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball', 'crutch',
'cuirass', 'dam', 'desk', 'desktop computer', 'rotary dial telephone', 'diaper', 'digital clock', 'digital watch',
'dining table', 'dishcloth', 'dishwasher', 'disc brake', 'dock', 'dog sled', 'dome', 'doormat', 'drilling rig',
'drum', 'drumstick', 'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', 'electric locomotive',
'entertainment center', 'envelope', 'espresso machine', 'face powder', 'feather boa', 'filing cabinet', 'fireboat',
'fire engine', 'fire screen sheet', 'flagpole', 'flute', 'folding chair', 'football helmet', 'forklift',
'fountain', 'fountain pen', 'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat',
'garbage truck', 'gas mask', 'gas pump', 'goblet', 'go-kart', 'golf ball', 'golf cart', 'gondola', 'gong', 'gown',
'grand piano', 'greenhouse', 'grille', 'grocery store', 'guillotine', 'barrette', 'hair spray', 'half-track',
'hammer', 'hamper', 'hair dryer', 'hand-held computer', 'handkerchief', 'hard disk drive', 'harmonica', 'harp',
'harvester', 'hatchet', 'holster', 'home theater', 'honeycomb', 'hook', 'hoop skirt', 'horizontal bar',
'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', "jack-o'-lantern", 'jeans', 'jeep', 'T-shirt',
'jigsaw puzzle', 'pulled rickshaw', 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade',
'laptop computer', 'lawn mower', 'lens cap', 'paper knife', 'library', 'lifeboat', 'lighter', 'limousine',
'ocean liner', 'lipstick', 'slip-on shoe', 'lotion', 'speaker', 'loupe', 'sawmill', 'magnetic compass', 'mail bag',
'mailbox', 'tights', 'tank suit', 'manhole cover', 'maraca', 'marimba', 'mask', 'match', 'maypole', 'maze',
'measuring cup', 'medicine chest', 'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can',
'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl', 'mobile home', 'Model T', 'modem',
'monastery', 'monitor', 'moped', 'mortar', 'square academic cap', 'mosque', 'mosquito net', 'scooter',
'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van', 'muzzle', 'nail', 'neck brace', 'necklace',
'nipple', 'notebook computer', 'obelisk', 'oboe', 'ocarina', 'odometer', 'oil filter', 'organ', 'oscilloscope',
'overskirt', 'bullock cart', 'oxygen mask', 'packet', 'paddle', 'paddle wheel', 'padlock', 'paintbrush', 'pajamas',
'palace', 'pan flute', 'paper towel', 'parachute', 'parallel bars', 'park bench', 'parking meter', 'passenger car',
'patio', 'payphone', 'pedestal', 'pencil case', 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier',
'plectrum', 'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank', 'pill bottle', 'pillow',
'ping-pong ball', 'pinwheel', 'pirate ship', 'pitcher', 'hand plane', 'planetarium', 'plastic bag', 'plate rack',
'plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', 'billiard table', 'soda bottle', 'pot',
"potter's wheel", 'power drill', 'prayer rug', 'printer', 'prison', 'projectile', 'projector', 'hockey puck',
'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket', 'radiator', 'radio', 'radio telescope',
'rain barrel', 'recreational vehicle', 'reel', 'reflex camera', 'refrigerator', 'remote control', 'restaurant',
'revolver', 'rifle', 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', 'ruler', 'running shoe', 'safe',
'safety pin', 'salt shaker', 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale', 'school bus',
'schooner', 'scoreboard', 'CRT screen', 'screw', 'screwdriver', 'seat belt', 'sewing machine', 'shield',
'shoe store', 'shoji', 'shopping basket', 'shopping cart', 'shovel', 'shower cap', 'shower curtain', 'ski',
'ski mask', 'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel', 'snowmobile', 'snowplow',
'soap dispenser', 'soccer ball', 'sock', 'solar thermal collector', 'sombrero', 'soup bowl', 'space bar',
'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web', 'spindle', 'sports car', 'spotlight',
'stage', 'steam locomotive', 'through arch bridge', 'steel drum', 'stethoscope', 'scarf', 'stone wall',
'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa', 'submarine', 'suit', 'sundial',
'sunglass', 'sunglasses', 'sunscreen', 'suspension bridge', 'mop', 'sweatshirt', 'swimsuit', 'swing', 'switch',
'syringe', 'table lamp', 'tank', 'tape player', 'teapot', 'teddy bear', 'television', 'tennis ball',
'thatched roof', 'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof', 'toaster', 'tobacco shop',
'toilet seat', 'torch', 'totem pole', 'tow truck', 'toy store', 'tractor', 'semi-trailer truck', 'tray',
'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch', 'trolleybus', 'trombone', 'tub', 'turnstile',
'typewriter keyboard', 'umbrella', 'unicycle', 'upright piano', 'vacuum cleaner', 'vase', 'vault', 'velvet',
'vending machine', 'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock', 'wallet',
'wardrobe', 'military aircraft', 'sink', 'washing machine', 'water bottle', 'water jug', 'water tower',
'whiskey jug', 'whistle', 'wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle', 'wing', 'wok',
'wooden spoon', 'wool', 'split-rail fence', 'shipwreck', 'yawl', 'yurt', 'website', 'comic book', 'crossword',
'traffic sign', 'traffic light', 'dust jacket', 'menu', 'plate', 'guacamole', 'consomme', 'hot pot', 'trifle',
'ice cream', 'ice pop', 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', 'mashed potato', 'cabbage',
'broccoli', 'cauliflower', 'zucchini', 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber',
'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith', 'strawberry', 'orange', 'lemon', 'fig',
'pineapple', 'banana', 'jackfruit', 'custard apple', 'pomegranate', 'hay', 'carbonara', 'chocolate syrup', 'dough',
'meatloaf', 'pizza', 'pot pie', 'burrito', 'red wine', 'espresso', 'cup', 'eggnog', 'alp', 'bubble', 'cliff',
'coral reef', 'geyser', 'lakeshore', 'promontory', 'shoal', 'seashore', 'valley', 'volcano', 'baseball player',
'bridegroom', 'scuba diver', 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn', 'rose hip',
'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra', 'stinkhorn mushroom', 'earth star',
'hen-of-the-woods', 'bolete', 'ear', 'toilet paper']
)
# detection model classes # detection model classes
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', CLASSES_DET = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
@ -23,7 +159,7 @@ CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
# colors for per classes # colors for per classes
COLORS = { COLORS = {
cls: [random.randint(0, 255) for _ in range(3)] cls: [random.randint(0, 255) for _ in range(3)]
for i, cls in enumerate(CLASSES) for i, cls in enumerate(CLASSES_DET)
} }
# colors for segment masks # colors for segment masks

@ -0,0 +1,56 @@
cmake_minimum_required(VERSION 3.1)
set(CMAKE_CUDA_ARCHITECTURES 60 61 62 70 72 75 86 89 90)
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
project(yolov8-cls LANGUAGES CXX CUDA)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -O3")
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_BUILD_TYPE Release)
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
# CUDA
find_package(CUDA REQUIRED)
message(STATUS "CUDA Libs: \n${CUDA_LIBRARIES}\n")
get_filename_component(CUDA_LIB_DIR ${CUDA_LIBRARIES} DIRECTORY)
message(STATUS "CUDA Headers: \n${CUDA_INCLUDE_DIRS}\n")
# OpenCV
find_package(OpenCV REQUIRED)
message(STATUS "OpenCV Libs: \n${OpenCV_LIBS}\n")
message(STATUS "OpenCV Libraries: \n${OpenCV_LIBRARIES}\n")
message(STATUS "OpenCV Headers: \n${OpenCV_INCLUDE_DIRS}\n")
# TensorRT
set(TensorRT_INCLUDE_DIRS /usr/include/x86_64-linux-gnu)
set(TensorRT_LIBRARIES /usr/lib/x86_64-linux-gnu)
message(STATUS "TensorRT Libs: \n${TensorRT_LIBRARIES}\n")
message(STATUS "TensorRT Headers: \n${TensorRT_INCLUDE_DIRS}\n")
list(APPEND INCLUDE_DIRS
${CUDA_INCLUDE_DIRS}
${OpenCV_INCLUDE_DIRS}
${TensorRT_INCLUDE_DIRS}
./include
)
list(APPEND ALL_LIBS
${CUDA_LIBRARIES}
${CUDA_LIB_DIR}
${OpenCV_LIBRARIES}
${TensorRT_LIBRARIES}
)
include_directories(${INCLUDE_DIRS})
add_executable(${PROJECT_NAME}
main.cpp
include/yolov8-cls.hpp
include/common.hpp
)
target_link_directories(${PROJECT_NAME} PUBLIC ${ALL_LIBS})
target_link_libraries(${PROJECT_NAME} PRIVATE nvinfer nvinfer_plugin cudart ${OpenCV_LIBS})

@ -0,0 +1,132 @@
//
// Created by ubuntu on 4/27/24.
//
#ifndef CLS_NORMAL_COMMON_HPP
#define CLS_NORMAL_COMMON_HPP
#include "NvInfer.h"
#include "opencv2/opencv.hpp"
#include <sys/stat.h>
#include <unistd.h>
#define CHECK(call) \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) { \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
class Logger: public nvinfer1::ILogger {
public:
nvinfer1::ILogger::Severity reportableSeverity;
explicit Logger(nvinfer1::ILogger::Severity severity = nvinfer1::ILogger::Severity::kINFO):
reportableSeverity(severity)
{
}
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity > reportableSeverity) {
return;
}
switch (severity) {
case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
std::cerr << "INTERNAL_ERROR: ";
break;
case nvinfer1::ILogger::Severity::kERROR:
std::cerr << "ERROR: ";
break;
case nvinfer1::ILogger::Severity::kWARNING:
std::cerr << "WARNING: ";
break;
case nvinfer1::ILogger::Severity::kINFO:
std::cerr << "INFO: ";
break;
default:
std::cerr << "VERBOSE: ";
break;
}
std::cerr << msg << std::endl;
}
};
inline int get_size_by_dims(const nvinfer1::Dims& dims)
{
int size = 1;
for (int i = 0; i < dims.nbDims; i++) {
size *= dims.d[i];
}
return size;
}
inline int type_to_size(const nvinfer1::DataType& dataType)
{
switch (dataType) {
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
return 2;
case nvinfer1::DataType::kINT32:
return 4;
case nvinfer1::DataType::kINT8:
return 1;
case nvinfer1::DataType::kBOOL:
return 1;
default:
return 4;
}
}
inline static float clamp(float val, float min, float max)
{
return val > min ? (val < max ? val : max) : min;
}
inline bool IsPathExist(const std::string& path)
{
if (access(path.c_str(), 0) == F_OK) {
return true;
}
return false;
}
inline bool IsFile(const std::string& path)
{
if (!IsPathExist(path)) {
printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode));
}
inline bool IsFolder(const std::string& path)
{
if (!IsPathExist(path)) {
return false;
}
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
namespace cls {
struct Binding {
size_t size = 1;
size_t dsize = 1;
nvinfer1::Dims dims;
std::string name;
};
struct Object {
int label = 0;
float prob = 0.0;
};
} // namespace cls
#endif // CLS_NORMAL_COMMON_HPP

@ -0,0 +1,221 @@
//
// Created by ubuntu on 4/27/24.
//
#ifndef CLS_NORMAL_YOLOv8_cls_HPP
#define CLS_NORMAL_YOLOv8_cls_HPP
#include "NvInferPlugin.h"
#include "common.hpp"
#include "fstream"
using namespace cls;
class YOLOv8_cls {
public:
explicit YOLOv8_cls(const std::string& engine_file_path);
~YOLOv8_cls();
void make_pipe(bool warmup = true);
void copy_from_Mat(const cv::Mat& image);
void copy_from_Mat(const cv::Mat& image, cv::Size& size);
void infer();
void postprocess(std::vector<Object>& objs);
static void draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES);
int num_bindings;
int num_inputs = 0;
int num_outputs = 0;
std::vector<Binding> input_bindings;
std::vector<Binding> output_bindings;
std::vector<void*> host_ptrs;
std::vector<void*> device_ptrs;
private:
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
cudaStream_t stream = nullptr;
Logger gLogger{nvinfer1::ILogger::Severity::kERROR};
};
YOLOv8_cls::YOLOv8_cls(const std::string& engine_file_path)
{
std::ifstream file(engine_file_path, std::ios::binary);
assert(file.good());
file.seekg(0, std::ios::end);
auto size = file.tellg();
file.seekg(0, std::ios::beg);
char* trtModelStream = new char[size];
assert(trtModelStream);
file.read(trtModelStream, size);
file.close();
initLibNvInferPlugins(&this->gLogger, "");
this->runtime = nvinfer1::createInferRuntime(this->gLogger);
assert(this->runtime != nullptr);
this->engine = this->runtime->deserializeCudaEngine(trtModelStream, size);
assert(this->engine != nullptr);
delete[] trtModelStream;
this->context = this->engine->createExecutionContext();
assert(this->context != nullptr);
cudaStreamCreate(&this->stream);
this->num_bindings = this->engine->getNbBindings();
for (int i = 0; i < this->num_bindings; ++i) {
Binding binding;
nvinfer1::Dims dims;
nvinfer1::DataType dtype = this->engine->getBindingDataType(i);
std::string name = this->engine->getBindingName(i);
binding.name = name;
binding.dsize = type_to_size(dtype);
bool IsInput = engine->bindingIsInput(i);
if (IsInput) {
this->num_inputs += 1;
dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->input_bindings.push_back(binding);
// set max opt shape
this->context->setBindingDimensions(i, dims);
}
else {
dims = this->context->getBindingDimensions(i);
binding.size = get_size_by_dims(dims);
binding.dims = dims;
this->output_bindings.push_back(binding);
this->num_outputs += 1;
}
}
}
YOLOv8_cls::~YOLOv8_cls()
{
this->context->destroy();
this->engine->destroy();
this->runtime->destroy();
cudaStreamDestroy(this->stream);
for (auto& ptr : this->device_ptrs) {
CHECK(cudaFree(ptr));
}
for (auto& ptr : this->host_ptrs) {
CHECK(cudaFreeHost(ptr));
}
}
void YOLOv8_cls::make_pipe(bool warmup)
{
for (auto& bindings : this->input_bindings) {
void* d_ptr;
CHECK(cudaMallocAsync(&d_ptr, bindings.size * bindings.dsize, this->stream));
this->device_ptrs.push_back(d_ptr);
}
for (auto& bindings : this->output_bindings) {
void * d_ptr, *h_ptr;
size_t size = bindings.size * bindings.dsize;
CHECK(cudaMallocAsync(&d_ptr, size, this->stream));
CHECK(cudaHostAlloc(&h_ptr, size, 0));
this->device_ptrs.push_back(d_ptr);
this->host_ptrs.push_back(h_ptr);
}
if (warmup) {
for (int i = 0; i < 10; i++) {
for (auto& bindings : this->input_bindings) {
size_t size = bindings.size * bindings.dsize;
void* h_ptr = malloc(size);
memset(h_ptr, 0, size);
CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream));
free(h_ptr);
}
this->infer();
}
printf("model warmup 10 times\n");
}
}
void YOLOv8_cls::copy_from_Mat(const cv::Mat& image)
{
cv::Mat nchw;
auto& in_binding = this->input_bindings[0];
auto width = in_binding.dims.d[3];
auto height = in_binding.dims.d[2];
cv::dnn::blobFromImage(image, nchw, 1 / 255.f, cv::Size(width, height), cv::Scalar(0, 0, 0), true, false, CV_32F);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_cls::copy_from_Mat(const cv::Mat& image, cv::Size& size)
{
cv::Mat nchw;
cv::dnn::blobFromImage(image, nchw, 1 / 255.f, size, cv::Scalar(0, 0, 0), true, false, CV_32F);
this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}});
CHECK(cudaMemcpyAsync(
this->device_ptrs[0], nchw.ptr<float>(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream));
}
void YOLOv8_cls::infer()
{
this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr);
for (int i = 0; i < this->num_outputs; i++) {
size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize;
CHECK(cudaMemcpyAsync(
this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream));
}
cudaStreamSynchronize(this->stream);
}
void YOLOv8_cls::postprocess(std::vector<Object>& objs)
{
objs.clear();
auto num_cls = this->output_bindings[0].dims.d[1];
float* max_ptr =
std::max_element(static_cast<float*>(this->host_ptrs[0]), static_cast<float*>(this->host_ptrs[0]) + num_cls);
Object obj;
obj.label = std::distance(static_cast<float*>(this->host_ptrs[0]), max_ptr);
obj.prob = *max_ptr;
objs.push_back(obj);
}
void YOLOv8_cls::draw_objects(const cv::Mat& image,
cv::Mat& res,
const std::vector<Object>& objs,
const std::vector<std::string>& CLASS_NAMES)
{
res = image.clone();
char text[256];
Object obj = objs[0];
sprintf(text, "%s %.1f%%", CLASS_NAMES[obj.label].c_str(), obj.prob * 100);
int baseLine = 0;
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = 10;
int y = 10;
if (y > res.rows)
y = res.rows;
cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1);
cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1);
}
#endif // CLS_NORMAL_YOLOv8_cls_HPP

File diff suppressed because it is too large Load Diff

@ -0,0 +1,129 @@
# YOLOv8-cls Model with TensorRT
The yolov8-cls model conversion route is :
YOLOv8 PyTorch model -> ONNX -> TensorRT Engine
***Notice !!!*** We don't support TensorRT API building !!!
# Export Orin ONNX model by ultralytics
You can leave this repo and use the original `ultralytics` repo for onnx export.
### 1. ONNX -> TensorRT
You can export your onnx model by `ultralytics` API.
``` shell
yolo export model=yolov8s-cls.pt format=onnx opset=11 simplify=True
```
or run this python script:
```python
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8s-cls.pt") # load a pretrained model (recommended for training)
success = model.export(format="onnx", opset=11, simplify=True) # export the model to onnx format
assert success
```
Then build engine by Trtexec Tools.
You can export TensorRT engine by [`trtexec`](https://github.com/NVIDIA/TensorRT/tree/main/samples/trtexec) tools.
Usage:
``` shell
/usr/src/tensorrt/bin/trtexec \
--onnx=yolov8s-cls.onnx \
--saveEngine=yolov8s-cls.engine \
--fp16
```
### 2. Direct to TensorRT (NOT RECOMMAND!!)
Usage:
```shell
yolo export model=yolov8s-cls.pt format=engine device=0
```
or run python script:
```python
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8s-cls.pt") # load a pretrained model (recommended for training)
success = model.export(format="engine", device=0) # export the model to engine format
assert success
```
After executing the above script, you will get an engine named `yolov8s-cls.engine` .
# Inference
## Infer with python script
You can infer images with the engine by [`infer-cls.py`](../infer-cls.py) .
Usage:
``` shell
python3 infer-cls.py \
--engine yolov8s-cls.engine \
--imgs data \
--show \
--out-dir outputs \
--device cuda:0
```
#### Description of all arguments
- `--engine` : The Engine you export.
- `--imgs` : The images path you want to detect.
- `--show` : Whether to show detection results.
- `--out-dir` : Where to save detection results images. It will not work when use `--show` flag.
- `--device` : The CUDA deivce you use.
## Inference with c++
You can infer with c++ in [`csrc/cls/normal`](../csrc/cls/normal) .
### Build:
Please set you own librarys in [`CMakeLists.txt`](../csrc/cls/normal/CMakeLists.txt) and modify `KPS_COLORS`
and `SKELETON` and `LIMB_COLORS` in [`main.cpp`](../csrc/cls/normal/main.cpp).
Besides, you can modify the postprocess parameters such as `score_thres` and `iou_thres` and `topk`
in [`main.cpp`](../csrc/cls/normal/main.cpp).
```c++
int topk = 100;
float score_thres = 0.25f;
float iou_thres = 0.65f;
```
And build:
``` shell
export root=${PWD}
cd src/cls/normal
mkdir build
cmake ..
make
mv yolov8-cls ${root}
cd ${root}
```
Usage:
``` shell
# infer image
./yolov8-cls yolov8s-cls.engine data/bus.jpg
# infer images
./yolov8-cls yolov8s-cls.engine data
# infer video
./yolov8-cls yolov8s-cls.engine data/test.mp4 # the video path
```

@ -0,0 +1,79 @@
import argparse
from pathlib import Path
import cv2
import numpy as np
from config import CLASSES_CLS
from models.utils import blob, path_to_list
def main(args: argparse.Namespace) -> None:
if args.method == 'cudart':
from models.cudart_api import TRTEngine
elif args.method == 'pycuda':
from models.pycuda_api import TRTEngine
else:
raise NotImplementedError
Engine = TRTEngine(args.engine)
H, W = Engine.inp_info[0].shape[-2:]
images = path_to_list(args.imgs)
save_path = Path(args.out_dir)
if not args.show and not save_path.exists():
save_path.mkdir(parents=True, exist_ok=True)
for image in images:
save_image = save_path / image.name
bgr = cv2.imread(str(image))
draw = bgr.copy()
bgr = cv2.resize(bgr, (W, H))
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor = blob(rgb, return_seg=False)
tensor = np.ascontiguousarray(tensor)
# inference
data = Engine(tensor)
data = data[0]
score = data.max().item()
cls_id = data.argmax().item()
cls = CLASSES_CLS[cls_id]
text = f'{cls}:{score:.3f}'
(_w, _h), _bl = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 1)
_y1 = min(10, draw.shape[0])
cv2.rectangle(draw, (10, _y1), (10 + _w, _y1 + _h + _bl), (0, 0, 255), -1)
cv2.putText(draw, text, (10, _y1 + _h), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 255, 255), 2)
if args.show:
cv2.imshow('result', draw)
cv2.waitKey(0)
else:
cv2.imwrite(str(save_image), draw)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--imgs', type=str, help='Images file')
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--method',
type=str,
default='cudart',
help='CUDART pipeline')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
main(args)

@ -0,0 +1,74 @@
from models import TRTModule # isort:skip
import argparse
from pathlib import Path
import cv2
import torch
from config import CLASSES_CLS
from models.utils import blob, path_to_list
def main(args: argparse.Namespace) -> None:
device = torch.device(args.device)
Engine = TRTModule(args.engine, device)
H, W = Engine.inp_info[0].shape[-2:]
images = path_to_list(args.imgs)
save_path = Path(args.out_dir)
if not args.show and not save_path.exists():
save_path.mkdir(parents=True, exist_ok=True)
for image in images:
save_image = save_path / image.name
bgr = cv2.imread(str(image))
draw = bgr.copy()
bgr = cv2.resize(bgr, (W, H))
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor = blob(rgb, return_seg=False)
tensor = torch.asarray(tensor, device=device)
# inference
data = Engine(tensor)
score, cls_id = data[0].max(0)
score = float(score)
cls_id = int(cls_id)
cls = CLASSES_CLS[cls_id]
text = f'{cls}:{score:.3f}'
(_w, _h), _bl = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 1)
_y1 = min(10, draw.shape[0])
cv2.rectangle(draw, (10, _y1), (10 + _w, _y1 + _h + _bl), (0, 0, 255), -1)
cv2.putText(draw, text, (10, _y1 + _h), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 255, 255), 2)
if args.show:
cv2.imshow('result', draw)
cv2.waitKey(0)
else:
cv2.imwrite(str(save_image), draw)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--imgs', type=str, help='Images file')
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--device',
type=str,
default='cuda:0',
help='TensorRT infer device')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
main(args)

@ -4,7 +4,7 @@ from pathlib import Path
import cv2 import cv2
import numpy as np import numpy as np
from config import CLASSES, COLORS from config import CLASSES_DET, COLORS
from models.utils import blob, det_postprocess, letterbox, path_to_list from models.utils import blob, det_postprocess, letterbox, path_to_list
@ -48,14 +48,19 @@ def main(args: argparse.Namespace) -> None:
for (bbox, score, label) in zip(bboxes, scores, labels): for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.round().astype(np.int32).tolist() bbox = bbox.round().astype(np.int32).tolist()
cls_id = int(label) cls_id = int(label)
cls = CLASSES[cls_id] cls = CLASSES_DET[cls_id]
color = COLORS[cls] color = COLORS[cls]
cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
cv2.putText(draw, text = f'{cls}:{score:.3f}'
f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2), x1, y1, x2, y2 = bbox
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255], (_w, _h), _bl = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 1)
thickness=2) _y1 = min(y1 + 1, draw.shape[0])
cv2.rectangle(draw, (x1, y1), (x2, y2), color, 2)
cv2.rectangle(draw, (x1, _y1), (x1 + _w, _y1 + _h + _bl), (0, 0, 255), -1)
cv2.putText(draw, text, (x1, _y1 + _h), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 255, 255), 2)
if args.show: if args.show:
cv2.imshow('result', draw) cv2.imshow('result', draw)
cv2.waitKey(0) cv2.waitKey(0)

@ -5,7 +5,7 @@ from pathlib import Path
import cv2 import cv2
import torch import torch
from config import CLASSES, COLORS from config import CLASSES_DET, COLORS
from models.torch_utils import det_postprocess from models.torch_utils import det_postprocess
from models.utils import blob, letterbox, path_to_list from models.utils import blob, letterbox, path_to_list
@ -47,14 +47,19 @@ def main(args: argparse.Namespace) -> None:
for (bbox, score, label) in zip(bboxes, scores, labels): for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.round().int().tolist() bbox = bbox.round().int().tolist()
cls_id = int(label) cls_id = int(label)
cls = CLASSES[cls_id] cls = CLASSES_DET[cls_id]
color = COLORS[cls] color = COLORS[cls]
cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
cv2.putText(draw, text = f'{cls}:{score:.3f}'
f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2), x1, y1, x2, y2 = bbox
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255], (_w, _h), _bl = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 1)
thickness=2) _y1 = min(y1 + 1, draw.shape[0])
cv2.rectangle(draw, (x1, y1), (x2, y2), color, 2)
cv2.rectangle(draw, (x1, _y1), (x1 + _w, _y1 + _h + _bl), (0, 0, 255), -1)
cv2.putText(draw, text, (x1, _y1 + _h), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 255, 255), 2)
if args.show: if args.show:
cv2.imshow('result', draw) cv2.imshow('result', draw)
cv2.waitKey(0) cv2.waitKey(0)

@ -70,9 +70,6 @@ def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]):
(0, )), labels.new_zeros((0, )) (0, )), labels.new_zeros((0, ))
# check score negative # check score negative
scores[scores < 0] = 1 + scores[scores < 0] scores[scores < 0] = 1 + scores[scores < 0]
# add nms
idx = nms(bboxes, scores, iou_thres)
bboxes, scores, labels = bboxes[idx], scores[idx], labels[idx]
bboxes = bboxes[:nums] bboxes = bboxes[:nums]
scores = scores[:nums] scores = scores[:nums]
labels = labels[:nums] labels = labels[:nums]

@ -199,7 +199,6 @@ def crop_mask(masks: ndarray, bboxes: ndarray) -> ndarray:
def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]): def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]):
assert len(data) == 4 assert len(data) == 4
iou_thres: float = 0.65
num_dets, bboxes, scores, labels = (i[0] for i in data) num_dets, bboxes, scores, labels = (i[0] for i in data)
nums = num_dets.item() nums = num_dets.item()
if nums == 0: if nums == 0:
@ -207,10 +206,6 @@ def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]):
(0, ), dtype=np.float32), np.empty((0, ), dtype=np.int32) (0, ), dtype=np.float32), np.empty((0, ), dtype=np.int32)
# check score negative # check score negative
scores[scores < 0] = 1 + scores[scores < 0] scores[scores < 0] = 1 + scores[scores < 0]
# add nms
idx = nms(bboxes, scores, iou_thres)
bboxes, scores, labels = bboxes[idx], scores[idx], labels[idx]
bboxes = bboxes[:nums] bboxes = bboxes[:nums]
scores = scores[:nums] scores = scores[:nums]
labels = labels[:nums] labels = labels[:nums]

Loading…
Cancel
Save