From 19a1fc8ce614b5cfcab8ea0120f1963dc71b0138 Mon Sep 17 00:00:00 2001 From: triplemu Date: Sat, 27 Apr 2024 22:16:48 +0800 Subject: [PATCH] support yolov8-cls --- README.md | 4 + config.py | 140 ++- csrc/cls/normal/CMakeLists.txt | 56 ++ csrc/cls/normal/include/common.hpp | 132 +++ csrc/cls/normal/include/yolov8-cls.hpp | 221 +++++ csrc/cls/normal/main.cpp | 1092 ++++++++++++++++++++++++ docs/Cls.md | 129 +++ infer-cls-without-torch.py | 79 ++ infer-cls.py | 74 ++ infer-det-without-torch.py | 21 +- infer-det.py | 21 +- models/torch_utils.py | 3 - models/utils.py | 5 - 13 files changed, 1951 insertions(+), 26 deletions(-) create mode 100644 csrc/cls/normal/CMakeLists.txt create mode 100644 csrc/cls/normal/include/common.hpp create mode 100644 csrc/cls/normal/include/yolov8-cls.hpp create mode 100644 csrc/cls/normal/main.cpp create mode 100644 docs/Cls.md create mode 100644 infer-cls-without-torch.py create mode 100644 infer-cls.py diff --git a/README.md b/README.md index dbec429..cce6ceb 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,10 @@ Please see more information in [`Segment.md`](docs/Segment.md) Please see more information in [`Pose.md`](docs/Pose.md) +# TensorRT Cls Deploy + +Please see more information in [`Cls.md`](docs/Cls.md) + # DeepStream Detection Deploy See more in [`README.md`](csrc/deepstream/README.md) diff --git a/config.py b/config.py index 6a0e315..17ca86c 100644 --- a/config.py +++ b/config.py @@ -4,8 +4,144 @@ import numpy as np random.seed(0) +# classification model classes +CLASSES_CLS = ( + ['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead shark', 'electric ray', 'stingray', 'cock', + 'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'American robin', 'bulbul', + 'jay', 'magpie', 'chickadee', 'American dipper', 'kite', 'bald eagle', 'vulture', 'great grey owl', + 'fire salamander', 'smooth newt', 'newt', 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog', + 'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle', 'mud turtle', 'terrapin', 'box turtle', + 'banded gecko', 'green iguana', 'Carolina anole', 'desert grassland whiptail lizard', 'agama', + 'frilled-necked lizard', 'alligator lizard', 'Gila monster', 'European green lizard', 'chameleon', 'Komodo dragon', + 'Nile crocodile', 'American alligator', 'triceratops', 'worm snake', 'ring-necked snake', + 'eastern hog-nosed snake', 'smooth green snake', 'kingsnake', 'garter snake', 'water snake', 'vine snake', + 'night snake', 'boa constrictor', 'African rock python', 'Indian cobra', 'green mamba', 'sea snake', + 'Saharan horned viper', 'eastern diamondback rattlesnake', 'sidewinder', 'trilobite', 'harvestman', 'scorpion', + 'yellow garden spider', 'barn spider', 'European garden spider', 'southern black widow', 'tarantula', + 'wolf spider', 'tick', 'centipede', 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peacock', + 'quail', 'partridge', 'grey parrot', 'macaw', 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater', + 'hornbill', 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser', 'goose', 'black swan', 'tusker', + 'echidna', 'platypus', 'wallaby', 'koala', 'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm', + 'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton', 'chambered nautilus', 'Dungeness crab', 'rock crab', + 'fiddler crab', 'red king crab', 'American lobster', 'spiny lobster', 'crayfish', 'hermit crab', 'isopod', + 'white stork', 'black stork', 'spoonbill', 'flamingo', 'little blue heron', 'great egret', 'bittern', + 'crane (bird)', 'limpkin', 'common gallinule', 'American coot', 'bustard', 'ruddy turnstone', 'dunlin', + 'common redshank', 'dowitcher', 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale', + 'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', 'Maltese', 'Pekingese', 'Shih Tzu', + 'King Charles Spaniel', 'Papillon', 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound', 'Beagle', + 'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound', 'Treeing Walker Coonhound', 'English foxhound', + 'Redbone Coonhound', 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', 'Ibizan Hound', + 'Norwegian Elkhound', 'Otterhound', 'Saluki', 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier', + 'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', 'Kerry Blue Terrier', 'Irish Terrier', + 'Norfolk Terrier', 'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier', 'Lakeland Terrier', + 'Sealyham Terrier', 'Airedale Terrier', 'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier', + 'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer', 'Standard Schnauzer', 'Scottish Terrier', + 'Tibetan Terrier', 'Australian Silky Terrier', 'Soft-coated Wheaten Terrier', 'West Highland White Terrier', + 'Lhasa Apso', 'Flat-Coated Retriever', 'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever', + 'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla', 'English Setter', 'Irish Setter', + 'Gordon Setter', 'Brittany', 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel', + 'Cocker Spaniels', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz', 'Schipperke', 'Groenendael', 'Malinois', + 'Briard', 'Australian Kelpie', 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', 'Border Collie', + 'Bouvier des Flandres', 'Rottweiler', 'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher', + 'Greater Swiss Mountain Dog', 'Bernese Mountain Dog', 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer', + 'Bullmastiff', 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. Bernard', 'husky', 'Alaskan Malamute', + 'Siberian Husky', 'Dalmatian', 'Affenpinscher', 'Basenji', 'pug', 'Leonberger', 'Newfoundland', + 'Pyrenean Mountain Dog', 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'Griffon Bruxellois', + 'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', 'Toy Poodle', 'Miniature Poodle', 'Standard Poodle', + 'Mexican hairless dog', 'grey wolf', 'Alaskan tundra wolf', 'red wolf', 'coyote', 'dingo', 'dhole', + 'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox', 'grey fox', 'tabby cat', 'tiger cat', + 'Persian cat', 'Siamese cat', 'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar', 'lion', + 'tiger', 'cheetah', 'brown bear', 'American black bear', 'polar bear', 'sloth bear', 'mongoose', 'meerkat', + 'tiger beetle', 'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle', 'dung beetle', 'rhinoceros beetle', + 'weevil', 'fly', 'bee', 'ant', 'grasshopper', 'cricket', 'stick insect', 'cockroach', 'mantis', 'cicada', + 'leafhopper', 'lacewing', 'dragonfly', 'damselfly', 'red admiral', 'ringlet', 'monarch butterfly', 'small white', + 'sulphur butterfly', 'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber', 'cottontail rabbit', + 'hare', 'Angora rabbit', 'hamster', 'porcupine', 'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel', + 'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox', 'water buffalo', 'bison', 'ram', 'bighorn sheep', + 'Alpine ibex', 'hartebeest', 'impala', 'gazelle', 'dromedary', 'llama', 'weasel', 'mink', 'European polecat', + 'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', 'three-toed sloth', 'orangutan', 'gorilla', + 'chimpanzee', 'gibbon', 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur', + 'black-and-white colobus', 'proboscis monkey', 'marmoset', 'white-headed capuchin', 'howler monkey', 'titi', + "Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur', 'indri', 'Asian elephant', + 'African bush elephant', 'red panda', 'giant panda', 'snoek', 'eel', 'coho salmon', 'rock beauty', 'clownfish', + 'sturgeon', 'garfish', 'lionfish', 'pufferfish', 'abacus', 'abaya', 'academic gown', 'accordion', + 'acoustic guitar', 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', 'amphibious vehicle', + 'analog clock', 'apiary', 'apron', 'waste container', 'assault rifle', 'backpack', 'bakery', 'balance beam', + 'balloon', 'ballpoint pen', 'Band-Aid', 'banjo', 'baluster', 'barbell', 'barber chair', 'barbershop', 'barn', + 'barometer', 'barrel', 'wheelbarrow', 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap', + 'bath towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker', 'military cap', 'beer bottle', 'beer glass', + 'bell-cot', 'bib', 'tandem bicycle', 'bikini', 'ring binder', 'binoculars', 'birdhouse', 'boathouse', 'bobsleigh', + 'bolo tie', 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'bow', 'bow tie', 'brass', 'bra', 'breakwater', + 'breastplate', 'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train', 'butcher shop', 'taxicab', + 'cauldron', 'candle', 'cannon', 'canoe', 'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit', 'carton', + 'car wheel', 'automated teller machine', 'cassette', 'cassette player', 'castle', 'catamaran', 'CD player', + 'cello', 'mobile phone', 'chain', 'chain-link fence', 'chain mail', 'chainsaw', 'chest', 'chiffonier', 'chime', + 'china cabinet', 'Christmas stocking', 'church', 'movie theater', 'cleaver', 'cliff dwelling', 'cloak', 'clogs', + 'cocktail shaker', 'coffee mug', 'coffeemaker', 'coil', 'combination lock', 'computer keyboard', + 'confectionery store', 'container ship', 'convertible', 'corkscrew', 'cornet', 'cowboy boot', 'cowboy hat', + 'cradle', 'crane (machine)', 'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball', 'crutch', + 'cuirass', 'dam', 'desk', 'desktop computer', 'rotary dial telephone', 'diaper', 'digital clock', 'digital watch', + 'dining table', 'dishcloth', 'dishwasher', 'disc brake', 'dock', 'dog sled', 'dome', 'doormat', 'drilling rig', + 'drum', 'drumstick', 'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', 'electric locomotive', + 'entertainment center', 'envelope', 'espresso machine', 'face powder', 'feather boa', 'filing cabinet', 'fireboat', + 'fire engine', 'fire screen sheet', 'flagpole', 'flute', 'folding chair', 'football helmet', 'forklift', + 'fountain', 'fountain pen', 'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat', + 'garbage truck', 'gas mask', 'gas pump', 'goblet', 'go-kart', 'golf ball', 'golf cart', 'gondola', 'gong', 'gown', + 'grand piano', 'greenhouse', 'grille', 'grocery store', 'guillotine', 'barrette', 'hair spray', 'half-track', + 'hammer', 'hamper', 'hair dryer', 'hand-held computer', 'handkerchief', 'hard disk drive', 'harmonica', 'harp', + 'harvester', 'hatchet', 'holster', 'home theater', 'honeycomb', 'hook', 'hoop skirt', 'horizontal bar', + 'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', "jack-o'-lantern", 'jeans', 'jeep', 'T-shirt', + 'jigsaw puzzle', 'pulled rickshaw', 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade', + 'laptop computer', 'lawn mower', 'lens cap', 'paper knife', 'library', 'lifeboat', 'lighter', 'limousine', + 'ocean liner', 'lipstick', 'slip-on shoe', 'lotion', 'speaker', 'loupe', 'sawmill', 'magnetic compass', 'mail bag', + 'mailbox', 'tights', 'tank suit', 'manhole cover', 'maraca', 'marimba', 'mask', 'match', 'maypole', 'maze', + 'measuring cup', 'medicine chest', 'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can', + 'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl', 'mobile home', 'Model T', 'modem', + 'monastery', 'monitor', 'moped', 'mortar', 'square academic cap', 'mosque', 'mosquito net', 'scooter', + 'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van', 'muzzle', 'nail', 'neck brace', 'necklace', + 'nipple', 'notebook computer', 'obelisk', 'oboe', 'ocarina', 'odometer', 'oil filter', 'organ', 'oscilloscope', + 'overskirt', 'bullock cart', 'oxygen mask', 'packet', 'paddle', 'paddle wheel', 'padlock', 'paintbrush', 'pajamas', + 'palace', 'pan flute', 'paper towel', 'parachute', 'parallel bars', 'park bench', 'parking meter', 'passenger car', + 'patio', 'payphone', 'pedestal', 'pencil case', 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', + 'plectrum', 'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank', 'pill bottle', 'pillow', + 'ping-pong ball', 'pinwheel', 'pirate ship', 'pitcher', 'hand plane', 'planetarium', 'plastic bag', 'plate rack', + 'plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', 'billiard table', 'soda bottle', 'pot', + "potter's wheel", 'power drill', 'prayer rug', 'printer', 'prison', 'projectile', 'projector', 'hockey puck', + 'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket', 'radiator', 'radio', 'radio telescope', + 'rain barrel', 'recreational vehicle', 'reel', 'reflex camera', 'refrigerator', 'remote control', 'restaurant', + 'revolver', 'rifle', 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', 'ruler', 'running shoe', 'safe', + 'safety pin', 'salt shaker', 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale', 'school bus', + 'schooner', 'scoreboard', 'CRT screen', 'screw', 'screwdriver', 'seat belt', 'sewing machine', 'shield', + 'shoe store', 'shoji', 'shopping basket', 'shopping cart', 'shovel', 'shower cap', 'shower curtain', 'ski', + 'ski mask', 'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel', 'snowmobile', 'snowplow', + 'soap dispenser', 'soccer ball', 'sock', 'solar thermal collector', 'sombrero', 'soup bowl', 'space bar', + 'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web', 'spindle', 'sports car', 'spotlight', + 'stage', 'steam locomotive', 'through arch bridge', 'steel drum', 'stethoscope', 'scarf', 'stone wall', + 'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa', 'submarine', 'suit', 'sundial', + 'sunglass', 'sunglasses', 'sunscreen', 'suspension bridge', 'mop', 'sweatshirt', 'swimsuit', 'swing', 'switch', + 'syringe', 'table lamp', 'tank', 'tape player', 'teapot', 'teddy bear', 'television', 'tennis ball', + 'thatched roof', 'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof', 'toaster', 'tobacco shop', + 'toilet seat', 'torch', 'totem pole', 'tow truck', 'toy store', 'tractor', 'semi-trailer truck', 'tray', + 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch', 'trolleybus', 'trombone', 'tub', 'turnstile', + 'typewriter keyboard', 'umbrella', 'unicycle', 'upright piano', 'vacuum cleaner', 'vase', 'vault', 'velvet', + 'vending machine', 'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock', 'wallet', + 'wardrobe', 'military aircraft', 'sink', 'washing machine', 'water bottle', 'water jug', 'water tower', + 'whiskey jug', 'whistle', 'wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle', 'wing', 'wok', + 'wooden spoon', 'wool', 'split-rail fence', 'shipwreck', 'yawl', 'yurt', 'website', 'comic book', 'crossword', + 'traffic sign', 'traffic light', 'dust jacket', 'menu', 'plate', 'guacamole', 'consomme', 'hot pot', 'trifle', + 'ice cream', 'ice pop', 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', 'mashed potato', 'cabbage', + 'broccoli', 'cauliflower', 'zucchini', 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber', + 'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith', 'strawberry', 'orange', 'lemon', 'fig', + 'pineapple', 'banana', 'jackfruit', 'custard apple', 'pomegranate', 'hay', 'carbonara', 'chocolate syrup', 'dough', + 'meatloaf', 'pizza', 'pot pie', 'burrito', 'red wine', 'espresso', 'cup', 'eggnog', 'alp', 'bubble', 'cliff', + 'coral reef', 'geyser', 'lakeshore', 'promontory', 'shoal', 'seashore', 'valley', 'volcano', 'baseball player', + 'bridegroom', 'scuba diver', 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn', 'rose hip', + 'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra', 'stinkhorn mushroom', 'earth star', + 'hen-of-the-woods', 'bolete', 'ear', 'toilet paper'] +) + # detection model classes -CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', +CLASSES_DET = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', @@ -23,7 +159,7 @@ CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', # colors for per classes COLORS = { cls: [random.randint(0, 255) for _ in range(3)] - for i, cls in enumerate(CLASSES) + for i, cls in enumerate(CLASSES_DET) } # colors for segment masks diff --git a/csrc/cls/normal/CMakeLists.txt b/csrc/cls/normal/CMakeLists.txt new file mode 100644 index 0000000..df45d49 --- /dev/null +++ b/csrc/cls/normal/CMakeLists.txt @@ -0,0 +1,56 @@ +cmake_minimum_required(VERSION 3.1) + +set(CMAKE_CUDA_ARCHITECTURES 60 61 62 70 72 75 86 89 90) +set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc) + +project(yolov8-cls LANGUAGES CXX CUDA) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -O3") +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_BUILD_TYPE Release) +option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) + +# CUDA +find_package(CUDA REQUIRED) +message(STATUS "CUDA Libs: \n${CUDA_LIBRARIES}\n") +get_filename_component(CUDA_LIB_DIR ${CUDA_LIBRARIES} DIRECTORY) +message(STATUS "CUDA Headers: \n${CUDA_INCLUDE_DIRS}\n") + +# OpenCV +find_package(OpenCV REQUIRED) +message(STATUS "OpenCV Libs: \n${OpenCV_LIBS}\n") +message(STATUS "OpenCV Libraries: \n${OpenCV_LIBRARIES}\n") +message(STATUS "OpenCV Headers: \n${OpenCV_INCLUDE_DIRS}\n") + +# TensorRT +set(TensorRT_INCLUDE_DIRS /usr/include/x86_64-linux-gnu) +set(TensorRT_LIBRARIES /usr/lib/x86_64-linux-gnu) + + +message(STATUS "TensorRT Libs: \n${TensorRT_LIBRARIES}\n") +message(STATUS "TensorRT Headers: \n${TensorRT_INCLUDE_DIRS}\n") + +list(APPEND INCLUDE_DIRS + ${CUDA_INCLUDE_DIRS} + ${OpenCV_INCLUDE_DIRS} + ${TensorRT_INCLUDE_DIRS} + ./include + ) + +list(APPEND ALL_LIBS + ${CUDA_LIBRARIES} + ${CUDA_LIB_DIR} + ${OpenCV_LIBRARIES} + ${TensorRT_LIBRARIES} + ) + +include_directories(${INCLUDE_DIRS}) + +add_executable(${PROJECT_NAME} + main.cpp + include/yolov8-cls.hpp + include/common.hpp + ) + +target_link_directories(${PROJECT_NAME} PUBLIC ${ALL_LIBS}) +target_link_libraries(${PROJECT_NAME} PRIVATE nvinfer nvinfer_plugin cudart ${OpenCV_LIBS}) diff --git a/csrc/cls/normal/include/common.hpp b/csrc/cls/normal/include/common.hpp new file mode 100644 index 0000000..c78d4dd --- /dev/null +++ b/csrc/cls/normal/include/common.hpp @@ -0,0 +1,132 @@ +// +// Created by ubuntu on 4/27/24. +// + +#ifndef CLS_NORMAL_COMMON_HPP +#define CLS_NORMAL_COMMON_HPP +#include "NvInfer.h" +#include "opencv2/opencv.hpp" +#include +#include + +#define CHECK(call) \ + do { \ + const cudaError_t error_code = call; \ + if (error_code != cudaSuccess) { \ + printf("CUDA Error:\n"); \ + printf(" File: %s\n", __FILE__); \ + printf(" Line: %d\n", __LINE__); \ + printf(" Error code: %d\n", error_code); \ + printf(" Error text: %s\n", cudaGetErrorString(error_code)); \ + exit(1); \ + } \ + } while (0) + +class Logger: public nvinfer1::ILogger { +public: + nvinfer1::ILogger::Severity reportableSeverity; + + explicit Logger(nvinfer1::ILogger::Severity severity = nvinfer1::ILogger::Severity::kINFO): + reportableSeverity(severity) + { + } + + void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override + { + if (severity > reportableSeverity) { + return; + } + switch (severity) { + case nvinfer1::ILogger::Severity::kINTERNAL_ERROR: + std::cerr << "INTERNAL_ERROR: "; + break; + case nvinfer1::ILogger::Severity::kERROR: + std::cerr << "ERROR: "; + break; + case nvinfer1::ILogger::Severity::kWARNING: + std::cerr << "WARNING: "; + break; + case nvinfer1::ILogger::Severity::kINFO: + std::cerr << "INFO: "; + break; + default: + std::cerr << "VERBOSE: "; + break; + } + std::cerr << msg << std::endl; + } +}; + +inline int get_size_by_dims(const nvinfer1::Dims& dims) +{ + int size = 1; + for (int i = 0; i < dims.nbDims; i++) { + size *= dims.d[i]; + } + return size; +} + +inline int type_to_size(const nvinfer1::DataType& dataType) +{ + switch (dataType) { + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kHALF: + return 2; + case nvinfer1::DataType::kINT32: + return 4; + case nvinfer1::DataType::kINT8: + return 1; + case nvinfer1::DataType::kBOOL: + return 1; + default: + return 4; + } +} + +inline static float clamp(float val, float min, float max) +{ + return val > min ? (val < max ? val : max) : min; +} + +inline bool IsPathExist(const std::string& path) +{ + if (access(path.c_str(), 0) == F_OK) { + return true; + } + return false; +} + +inline bool IsFile(const std::string& path) +{ + if (!IsPathExist(path)) { + printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str()); + return false; + } + struct stat buffer; + return (stat(path.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode)); +} + +inline bool IsFolder(const std::string& path) +{ + if (!IsPathExist(path)) { + return false; + } + struct stat buffer; + return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode)); +} + +namespace cls { +struct Binding { + size_t size = 1; + size_t dsize = 1; + nvinfer1::Dims dims; + std::string name; +}; + +struct Object { + int label = 0; + float prob = 0.0; +}; +} // namespace cls +#endif // CLS_NORMAL_COMMON_HPP diff --git a/csrc/cls/normal/include/yolov8-cls.hpp b/csrc/cls/normal/include/yolov8-cls.hpp new file mode 100644 index 0000000..b5c8299 --- /dev/null +++ b/csrc/cls/normal/include/yolov8-cls.hpp @@ -0,0 +1,221 @@ +// +// Created by ubuntu on 4/27/24. +// +#ifndef CLS_NORMAL_YOLOv8_cls_HPP +#define CLS_NORMAL_YOLOv8_cls_HPP + +#include "NvInferPlugin.h" +#include "common.hpp" +#include "fstream" + +using namespace cls; + +class YOLOv8_cls { +public: + explicit YOLOv8_cls(const std::string& engine_file_path); + + ~YOLOv8_cls(); + + void make_pipe(bool warmup = true); + + void copy_from_Mat(const cv::Mat& image); + + void copy_from_Mat(const cv::Mat& image, cv::Size& size); + + void infer(); + + void postprocess(std::vector& objs); + + static void draw_objects(const cv::Mat& image, + cv::Mat& res, + const std::vector& objs, + const std::vector& CLASS_NAMES); + + int num_bindings; + int num_inputs = 0; + int num_outputs = 0; + std::vector input_bindings; + std::vector output_bindings; + std::vector host_ptrs; + std::vector device_ptrs; + +private: + nvinfer1::ICudaEngine* engine = nullptr; + nvinfer1::IRuntime* runtime = nullptr; + nvinfer1::IExecutionContext* context = nullptr; + cudaStream_t stream = nullptr; + Logger gLogger{nvinfer1::ILogger::Severity::kERROR}; +}; + +YOLOv8_cls::YOLOv8_cls(const std::string& engine_file_path) +{ + std::ifstream file(engine_file_path, std::ios::binary); + assert(file.good()); + file.seekg(0, std::ios::end); + auto size = file.tellg(); + file.seekg(0, std::ios::beg); + char* trtModelStream = new char[size]; + assert(trtModelStream); + file.read(trtModelStream, size); + file.close(); + initLibNvInferPlugins(&this->gLogger, ""); + this->runtime = nvinfer1::createInferRuntime(this->gLogger); + assert(this->runtime != nullptr); + + this->engine = this->runtime->deserializeCudaEngine(trtModelStream, size); + assert(this->engine != nullptr); + delete[] trtModelStream; + this->context = this->engine->createExecutionContext(); + + assert(this->context != nullptr); + cudaStreamCreate(&this->stream); + this->num_bindings = this->engine->getNbBindings(); + + for (int i = 0; i < this->num_bindings; ++i) { + Binding binding; + nvinfer1::Dims dims; + nvinfer1::DataType dtype = this->engine->getBindingDataType(i); + std::string name = this->engine->getBindingName(i); + binding.name = name; + binding.dsize = type_to_size(dtype); + + bool IsInput = engine->bindingIsInput(i); + if (IsInput) { + this->num_inputs += 1; + dims = this->engine->getProfileDimensions(i, 0, nvinfer1::OptProfileSelector::kMAX); + binding.size = get_size_by_dims(dims); + binding.dims = dims; + this->input_bindings.push_back(binding); + // set max opt shape + this->context->setBindingDimensions(i, dims); + } + else { + dims = this->context->getBindingDimensions(i); + binding.size = get_size_by_dims(dims); + binding.dims = dims; + this->output_bindings.push_back(binding); + this->num_outputs += 1; + } + } +} + +YOLOv8_cls::~YOLOv8_cls() +{ + this->context->destroy(); + this->engine->destroy(); + this->runtime->destroy(); + cudaStreamDestroy(this->stream); + for (auto& ptr : this->device_ptrs) { + CHECK(cudaFree(ptr)); + } + + for (auto& ptr : this->host_ptrs) { + CHECK(cudaFreeHost(ptr)); + } +} + +void YOLOv8_cls::make_pipe(bool warmup) +{ + + for (auto& bindings : this->input_bindings) { + void* d_ptr; + CHECK(cudaMallocAsync(&d_ptr, bindings.size * bindings.dsize, this->stream)); + this->device_ptrs.push_back(d_ptr); + } + + for (auto& bindings : this->output_bindings) { + void * d_ptr, *h_ptr; + size_t size = bindings.size * bindings.dsize; + CHECK(cudaMallocAsync(&d_ptr, size, this->stream)); + CHECK(cudaHostAlloc(&h_ptr, size, 0)); + this->device_ptrs.push_back(d_ptr); + this->host_ptrs.push_back(h_ptr); + } + + if (warmup) { + for (int i = 0; i < 10; i++) { + for (auto& bindings : this->input_bindings) { + size_t size = bindings.size * bindings.dsize; + void* h_ptr = malloc(size); + memset(h_ptr, 0, size); + CHECK(cudaMemcpyAsync(this->device_ptrs[0], h_ptr, size, cudaMemcpyHostToDevice, this->stream)); + free(h_ptr); + } + this->infer(); + } + printf("model warmup 10 times\n"); + } +} + +void YOLOv8_cls::copy_from_Mat(const cv::Mat& image) +{ + cv::Mat nchw; + auto& in_binding = this->input_bindings[0]; + auto width = in_binding.dims.d[3]; + auto height = in_binding.dims.d[2]; + + cv::dnn::blobFromImage(image, nchw, 1 / 255.f, cv::Size(width, height), cv::Scalar(0, 0, 0), true, false, CV_32F); + + this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, height, width}}); + + CHECK(cudaMemcpyAsync( + this->device_ptrs[0], nchw.ptr(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream)); +} + +void YOLOv8_cls::copy_from_Mat(const cv::Mat& image, cv::Size& size) +{ + cv::Mat nchw; + cv::dnn::blobFromImage(image, nchw, 1 / 255.f, size, cv::Scalar(0, 0, 0), true, false, CV_32F); + this->context->setBindingDimensions(0, nvinfer1::Dims{4, {1, 3, size.height, size.width}}); + CHECK(cudaMemcpyAsync( + this->device_ptrs[0], nchw.ptr(), nchw.total() * nchw.elemSize(), cudaMemcpyHostToDevice, this->stream)); +} + +void YOLOv8_cls::infer() +{ + + this->context->enqueueV2(this->device_ptrs.data(), this->stream, nullptr); + for (int i = 0; i < this->num_outputs; i++) { + size_t osize = this->output_bindings[i].size * this->output_bindings[i].dsize; + CHECK(cudaMemcpyAsync( + this->host_ptrs[i], this->device_ptrs[i + this->num_inputs], osize, cudaMemcpyDeviceToHost, this->stream)); + } + cudaStreamSynchronize(this->stream); +} + +void YOLOv8_cls::postprocess(std::vector& objs) +{ + objs.clear(); + auto num_cls = this->output_bindings[0].dims.d[1]; + + float* max_ptr = + std::max_element(static_cast(this->host_ptrs[0]), static_cast(this->host_ptrs[0]) + num_cls); + Object obj; + obj.label = std::distance(static_cast(this->host_ptrs[0]), max_ptr); + obj.prob = *max_ptr; + objs.push_back(obj); +} + +void YOLOv8_cls::draw_objects(const cv::Mat& image, + cv::Mat& res, + const std::vector& objs, + const std::vector& CLASS_NAMES) +{ + res = image.clone(); + char text[256]; + Object obj = objs[0]; + sprintf(text, "%s %.1f%%", CLASS_NAMES[obj.label].c_str(), obj.prob * 100); + + int baseLine = 0; + cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine); + int x = 10; + int y = 10; + + if (y > res.rows) + y = res.rows; + + cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine), {0, 0, 255}, -1); + cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4, {255, 255, 255}, 1); +} + +#endif // CLS_NORMAL_YOLOv8_cls_HPP diff --git a/csrc/cls/normal/main.cpp b/csrc/cls/normal/main.cpp new file mode 100644 index 0000000..0c8aa0f --- /dev/null +++ b/csrc/cls/normal/main.cpp @@ -0,0 +1,1092 @@ +// +// Created by ubuntu on 4/27/24. +// +#include "chrono" +#include "opencv2/opencv.hpp" +#include "yolov8-cls.hpp" + +const std::vector CLASS_NAMES = {"tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "cock", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peacock", + "quail", + "partridge", + "grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern", + "crane (bird)", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + "Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniels", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + "Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland", + "Pyrenean Mountain Dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "Griffon Bruxellois", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog", + "grey wolf", + "Alaskan tundra wolf", + "red wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket", + "stick insect", + "cockroach", + "mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral", + "ringlet", + "monarch butterfly", + "small white", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala", + "gazelle", + "dromedary", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi", + "Geoffroy\"s spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + "snoek", + "eel", + "coho salmon", + "rock beauty", + "clownfish", + "sturgeon", + "garfish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "waste container", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military cap", + "beer bottle", + "beer glass", + "bell-cot", + "bib", + "tandem bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + "bookstore", + "bottle cap", + "bow", + "bow tie", + "brass", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "chest", + "chiffonier", + "chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "coil", + "combination lock", + "computer keyboard", + "confectionery store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "crane (machine)", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather boa", + "filing cabinet", + "fireboat", + "fire engine", + "fire screen sheet", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "grille", + "grocery store", + "guillotine", + "barrette", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "jack-o\"-lantern", + "jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "pulled rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "paper knife", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "speaker", + "loupe", + "sawmill", + "magnetic compass", + "mail bag", + "mailbox", + "tights", + "tank suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "match", + "maypole", + "maze", + "measuring cup", + "medicine chest", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "Model T", + "modem", + "monastery", + "monitor", + "moped", + "mortar", + "square academic cap", + "mosque", + "mosquito net", + "scooter", + "mountain bike", + "tent", + "computer mouse", + "mousetrap", + "moving van", + "muzzle", + "nail", + "neck brace", + "necklace", + "nipple", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "packet", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "passenger car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "pitcher", + "hand plane", + "planetarium", + "plastic bag", + "plate rack", + "plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "billiard table", + "soda bottle", + "pot", + "potter\"s wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "projectile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler", + "running shoe", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT screen", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglass", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swimsuit", + "swing", + "switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vault", + "velvet", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + "sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "yawl", + "yurt", + "website", + "comic book", + "crossword", + "traffic sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "ice pop", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potato", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "custard apple", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "cup", + "eggnog", + "alp", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "shoal", + "seashore", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady\"s slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star", + "hen-of-the-woods", + "bolete", + "ear", + "toilet paper"}; + +int main(int argc, char** argv) +{ + // cuda:0 + cudaSetDevice(0); + + const std::string engine_file_path{argv[1]}; + const std::string path{argv[2]}; + + std::vector imagePathList; + bool isVideo{false}; + + assert(argc == 3); + + auto yolov8_cls = new YOLOv8_cls(engine_file_path); + yolov8_cls->make_pipe(true); + + if (IsFile(path)) { + std::string suffix = path.substr(path.find_last_of('.') + 1); + if (suffix == "jpg" || suffix == "jpeg" || suffix == "png") { + imagePathList.push_back(path); + } + else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov" + || suffix == "mkv") { + isVideo = true; + } + else { + printf("suffix %s is wrong !!!\n", suffix.c_str()); + std::abort(); + } + } + else if (IsFolder(path)) { + cv::glob(path + "/*.jpg", imagePathList); + } + + cv::Mat res, image; + cv::Size size = cv::Size{224, 224}; + + std::vector objs; + + cv::namedWindow("result", cv::WINDOW_AUTOSIZE); + + if (isVideo) { + cv::VideoCapture cap(path); + + if (!cap.isOpened()) { + printf("can not open %s\n", path.c_str()); + return -1; + } + while (cap.read(image)) { + objs.clear(); + yolov8_cls->copy_from_Mat(image, size); + auto start = std::chrono::system_clock::now(); + yolov8_cls->infer(); + auto end = std::chrono::system_clock::now(); + yolov8_cls->postprocess(objs); + yolov8_cls->draw_objects(image, res, objs, CLASS_NAMES); + auto tc = (double)std::chrono::duration_cast(end - start).count() / 1000.; + printf("cost %2.4lf ms\n", tc); + cv::imshow("result", res); + if (cv::waitKey(10) == 'q') { + break; + } + } + } + else { + for (auto& path : imagePathList) { + objs.clear(); + image = cv::imread(path); + yolov8_cls->copy_from_Mat(image, size); + auto start = std::chrono::system_clock::now(); + yolov8_cls->infer(); + auto end = std::chrono::system_clock::now(); + yolov8_cls->postprocess(objs); + yolov8_cls->draw_objects(image, res, objs, CLASS_NAMES); + auto tc = (double)std::chrono::duration_cast(end - start).count() / 1000.; + printf("cost %2.4lf ms\n", tc); + cv::imshow("result", res); + cv::waitKey(0); + } + } + cv::destroyAllWindows(); + delete yolov8_cls; + return 0; +} diff --git a/docs/Cls.md b/docs/Cls.md new file mode 100644 index 0000000..cba1bfa --- /dev/null +++ b/docs/Cls.md @@ -0,0 +1,129 @@ +# YOLOv8-cls Model with TensorRT + +The yolov8-cls model conversion route is : +YOLOv8 PyTorch model -> ONNX -> TensorRT Engine + +***Notice !!!*** We don't support TensorRT API building !!! + +# Export Orin ONNX model by ultralytics + +You can leave this repo and use the original `ultralytics` repo for onnx export. + +### 1. ONNX -> TensorRT + +You can export your onnx model by `ultralytics` API. + +``` shell +yolo export model=yolov8s-cls.pt format=onnx opset=11 simplify=True +``` + +or run this python script: + +```python +from ultralytics import YOLO + +# Load a model +model = YOLO("yolov8s-cls.pt") # load a pretrained model (recommended for training) +success = model.export(format="onnx", opset=11, simplify=True) # export the model to onnx format +assert success +``` + +Then build engine by Trtexec Tools. + +You can export TensorRT engine by [`trtexec`](https://github.com/NVIDIA/TensorRT/tree/main/samples/trtexec) tools. + +Usage: + +``` shell +/usr/src/tensorrt/bin/trtexec \ +--onnx=yolov8s-cls.onnx \ +--saveEngine=yolov8s-cls.engine \ +--fp16 +``` + +### 2. Direct to TensorRT (NOT RECOMMAND!!) + +Usage: + +```shell +yolo export model=yolov8s-cls.pt format=engine device=0 +``` + +or run python script: + +```python +from ultralytics import YOLO + +# Load a model +model = YOLO("yolov8s-cls.pt") # load a pretrained model (recommended for training) +success = model.export(format="engine", device=0) # export the model to engine format +assert success +``` + +After executing the above script, you will get an engine named `yolov8s-cls.engine` . + +# Inference + +## Infer with python script + +You can infer images with the engine by [`infer-cls.py`](../infer-cls.py) . + +Usage: + +``` shell +python3 infer-cls.py \ +--engine yolov8s-cls.engine \ +--imgs data \ +--show \ +--out-dir outputs \ +--device cuda:0 +``` + +#### Description of all arguments + +- `--engine` : The Engine you export. +- `--imgs` : The images path you want to detect. +- `--show` : Whether to show detection results. +- `--out-dir` : Where to save detection results images. It will not work when use `--show` flag. +- `--device` : The CUDA deivce you use. + +## Inference with c++ + +You can infer with c++ in [`csrc/cls/normal`](../csrc/cls/normal) . + +### Build: + +Please set you own librarys in [`CMakeLists.txt`](../csrc/cls/normal/CMakeLists.txt) and modify `KPS_COLORS` +and `SKELETON` and `LIMB_COLORS` in [`main.cpp`](../csrc/cls/normal/main.cpp). + +Besides, you can modify the postprocess parameters such as `score_thres` and `iou_thres` and `topk` +in [`main.cpp`](../csrc/cls/normal/main.cpp). + +```c++ +int topk = 100; +float score_thres = 0.25f; +float iou_thres = 0.65f; +``` + +And build: + +``` shell +export root=${PWD} +cd src/cls/normal +mkdir build +cmake .. +make +mv yolov8-cls ${root} +cd ${root} +``` + +Usage: + +``` shell +# infer image +./yolov8-cls yolov8s-cls.engine data/bus.jpg +# infer images +./yolov8-cls yolov8s-cls.engine data +# infer video +./yolov8-cls yolov8s-cls.engine data/test.mp4 # the video path +``` diff --git a/infer-cls-without-torch.py b/infer-cls-without-torch.py new file mode 100644 index 0000000..fbba8e3 --- /dev/null +++ b/infer-cls-without-torch.py @@ -0,0 +1,79 @@ +import argparse +from pathlib import Path + +import cv2 +import numpy as np + +from config import CLASSES_CLS +from models.utils import blob, path_to_list + + +def main(args: argparse.Namespace) -> None: + if args.method == 'cudart': + from models.cudart_api import TRTEngine + elif args.method == 'pycuda': + from models.pycuda_api import TRTEngine + else: + raise NotImplementedError + + Engine = TRTEngine(args.engine) + H, W = Engine.inp_info[0].shape[-2:] + + images = path_to_list(args.imgs) + save_path = Path(args.out_dir) + + if not args.show and not save_path.exists(): + save_path.mkdir(parents=True, exist_ok=True) + + for image in images: + save_image = save_path / image.name + bgr = cv2.imread(str(image)) + draw = bgr.copy() + bgr = cv2.resize(bgr, (W, H)) + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + tensor = blob(rgb, return_seg=False) + tensor = np.ascontiguousarray(tensor) + # inference + data = Engine(tensor) + + data = data[0] + score = data.max().item() + cls_id = data.argmax().item() + cls = CLASSES_CLS[cls_id] + + text = f'{cls}:{score:.3f}' + (_w, _h), _bl = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 1) + _y1 = min(10, draw.shape[0]) + + cv2.rectangle(draw, (10, _y1), (10 + _w, _y1 + _h + _bl), (0, 0, 255), -1) + cv2.putText(draw, text, (10, _y1 + _h), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 255, 255), 2) + + if args.show: + cv2.imshow('result', draw) + cv2.waitKey(0) + else: + cv2.imwrite(str(save_image), draw) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--engine', type=str, help='Engine file') + parser.add_argument('--imgs', type=str, help='Images file') + parser.add_argument('--show', + action='store_true', + help='Show the detection results') + parser.add_argument('--out-dir', + type=str, + default='./output', + help='Path to output file') + parser.add_argument('--method', + type=str, + default='cudart', + help='CUDART pipeline') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/infer-cls.py b/infer-cls.py new file mode 100644 index 0000000..188d127 --- /dev/null +++ b/infer-cls.py @@ -0,0 +1,74 @@ +from models import TRTModule # isort:skip +import argparse +from pathlib import Path + +import cv2 +import torch + +from config import CLASSES_CLS +from models.utils import blob, path_to_list + + +def main(args: argparse.Namespace) -> None: + device = torch.device(args.device) + Engine = TRTModule(args.engine, device) + H, W = Engine.inp_info[0].shape[-2:] + + images = path_to_list(args.imgs) + save_path = Path(args.out_dir) + + if not args.show and not save_path.exists(): + save_path.mkdir(parents=True, exist_ok=True) + + for image in images: + save_image = save_path / image.name + bgr = cv2.imread(str(image)) + draw = bgr.copy() + bgr = cv2.resize(bgr, (W, H)) + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + tensor = blob(rgb, return_seg=False) + tensor = torch.asarray(tensor, device=device) + # inference + data = Engine(tensor) + + score, cls_id = data[0].max(0) + score = float(score) + cls_id = int(cls_id) + cls = CLASSES_CLS[cls_id] + + text = f'{cls}:{score:.3f}' + (_w, _h), _bl = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 1) + _y1 = min(10, draw.shape[0]) + + cv2.rectangle(draw, (10, _y1), (10 + _w, _y1 + _h + _bl), (0, 0, 255), -1) + cv2.putText(draw, text, (10, _y1 + _h), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 255, 255), 2) + + if args.show: + cv2.imshow('result', draw) + cv2.waitKey(0) + else: + cv2.imwrite(str(save_image), draw) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument('--engine', type=str, help='Engine file') + parser.add_argument('--imgs', type=str, help='Images file') + parser.add_argument('--show', + action='store_true', + help='Show the detection results') + parser.add_argument('--out-dir', + type=str, + default='./output', + help='Path to output file') + parser.add_argument('--device', + type=str, + default='cuda:0', + help='TensorRT infer device') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/infer-det-without-torch.py b/infer-det-without-torch.py index 81055ce..fe4eaf2 100644 --- a/infer-det-without-torch.py +++ b/infer-det-without-torch.py @@ -4,7 +4,7 @@ from pathlib import Path import cv2 import numpy as np -from config import CLASSES, COLORS +from config import CLASSES_DET, COLORS from models.utils import blob, det_postprocess, letterbox, path_to_list @@ -48,14 +48,19 @@ def main(args: argparse.Namespace) -> None: for (bbox, score, label) in zip(bboxes, scores, labels): bbox = bbox.round().astype(np.int32).tolist() cls_id = int(label) - cls = CLASSES[cls_id] + cls = CLASSES_DET[cls_id] color = COLORS[cls] - cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2) - cv2.putText(draw, - f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2), - cv2.FONT_HERSHEY_SIMPLEX, - 0.75, [225, 255, 255], - thickness=2) + + text = f'{cls}:{score:.3f}' + x1, y1, x2, y2 = bbox + + (_w, _h), _bl = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 1) + _y1 = min(y1 + 1, draw.shape[0]) + + cv2.rectangle(draw, (x1, y1), (x2, y2), color, 2) + cv2.rectangle(draw, (x1, _y1), (x1 + _w, _y1 + _h + _bl), (0, 0, 255), -1) + cv2.putText(draw, text, (x1, _y1 + _h), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 255, 255), 2) + if args.show: cv2.imshow('result', draw) cv2.waitKey(0) diff --git a/infer-det.py b/infer-det.py index ab9586d..cb54def 100644 --- a/infer-det.py +++ b/infer-det.py @@ -5,7 +5,7 @@ from pathlib import Path import cv2 import torch -from config import CLASSES, COLORS +from config import CLASSES_DET, COLORS from models.torch_utils import det_postprocess from models.utils import blob, letterbox, path_to_list @@ -47,14 +47,19 @@ def main(args: argparse.Namespace) -> None: for (bbox, score, label) in zip(bboxes, scores, labels): bbox = bbox.round().int().tolist() cls_id = int(label) - cls = CLASSES[cls_id] + cls = CLASSES_DET[cls_id] color = COLORS[cls] - cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2) - cv2.putText(draw, - f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2), - cv2.FONT_HERSHEY_SIMPLEX, - 0.75, [225, 255, 255], - thickness=2) + + text = f'{cls}:{score:.3f}' + x1, y1, x2, y2 = bbox + + (_w, _h), _bl = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 1) + _y1 = min(y1 + 1, draw.shape[0]) + + cv2.rectangle(draw, (x1, y1), (x2, y2), color, 2) + cv2.rectangle(draw, (x1, _y1), (x1 + _w, _y1 + _h + _bl), (0, 0, 255), -1) + cv2.putText(draw, text, (x1, _y1 + _h), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 255, 255), 2) + if args.show: cv2.imshow('result', draw) cv2.waitKey(0) diff --git a/models/torch_utils.py b/models/torch_utils.py index b25ede9..45541c7 100644 --- a/models/torch_utils.py +++ b/models/torch_utils.py @@ -70,9 +70,6 @@ def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Tensor]): (0, )), labels.new_zeros((0, )) # check score negative scores[scores < 0] = 1 + scores[scores < 0] - # add nms - idx = nms(bboxes, scores, iou_thres) - bboxes, scores, labels = bboxes[idx], scores[idx], labels[idx] bboxes = bboxes[:nums] scores = scores[:nums] labels = labels[:nums] diff --git a/models/utils.py b/models/utils.py index 7a545cf..50b5a70 100644 --- a/models/utils.py +++ b/models/utils.py @@ -199,7 +199,6 @@ def crop_mask(masks: ndarray, bboxes: ndarray) -> ndarray: def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]): assert len(data) == 4 - iou_thres: float = 0.65 num_dets, bboxes, scores, labels = (i[0] for i in data) nums = num_dets.item() if nums == 0: @@ -207,10 +206,6 @@ def det_postprocess(data: Tuple[ndarray, ndarray, ndarray, ndarray]): (0, ), dtype=np.float32), np.empty((0, ), dtype=np.int32) # check score negative scores[scores < 0] = 1 + scores[scores < 0] - # add nms - idx = nms(bboxes, scores, iou_thres) - bboxes, scores, labels = bboxes[idx], scores[idx], labels[idx] - bboxes = bboxes[:nums] scores = scores[:nums] labels = labels[:nums]