diff --git a/tests/test_cli.py b/tests/test_cli.py
index b4a09fcf3..f90d0e9e0 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -40,14 +40,6 @@ def test_train(task, model, data):
 
 @pytest.mark.parametrize('task,model,data', TASK_ARGS)
 def test_val(task, model, data):
-    # Download annotations to run pycocotools eval
-    # from ultralytics.utils import SETTINGS, Path
-    # from ultralytics.utils.downloads import download
-    # url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'
-    # download(f'{url}instances_val2017.json', dir=Path(SETTINGS['datasets_dir']) / 'coco8/annotations')
-    # download(f'{url}person_keypoints_val2017.json', dir=Path(SETTINGS['datasets_dir']) / 'coco8-pose/annotations')
-
-    # Validate
     run(f'yolo val {task} model={WEIGHTS_DIR / model}.pt data={data} imgsz=32 save_txt save_json')
diff --git a/tests/test_cuda.py b/tests/test_cuda.py
index 7f8936726..da9cc084e 100644
--- a/tests/test_cuda.py
+++ b/tests/test_cuda.py
@@ -1,16 +1,18 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
+import contextlib
 import subprocess
 from pathlib import Path
 
 import pytest
 import torch
 
-from ultralytics import YOLO
+from ultralytics import YOLO, download
 from ultralytics.utils import ASSETS, SETTINGS
 
 CUDA_IS_AVAILABLE = torch.cuda.is_available()
 CUDA_DEVICE_COUNT = torch.cuda.device_count()
+DATASETS_DIR = Path(SETTINGS['datasets_dir'])
 WEIGHTS_DIR = Path(SETTINGS['weights_dir'])
 MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt'  # test spaces in path
 DATA = 'coco8.yaml'
@@ -37,13 +39,15 @@ def test_train_ddp():
 def test_utils_benchmarks():
     from ultralytics.utils.benchmarks import ProfileModels
 
-    YOLO(MODEL).export(format='engine', imgsz=32, dynamic=True, batch=1)  # pre-export engine model, auto-device
+    # Pre-export a dynamic engine model to use dynamic inference
+    YOLO(MODEL).export(format='engine', imgsz=32, dynamic=True, batch=1)
     ProfileModels([MODEL], imgsz=32, half=False, min_time=1, num_timed_runs=3, num_warmup_runs=1).profile()
 
 
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
 def test_predict_sam():
     from ultralytics import SAM
+    from ultralytics.models.sam import Predictor as SAMPredictor
 
     # Load a model
     model = SAM(WEIGHTS_DIR / 'sam_b.pt')
@@ -60,14 +64,63 @@ def test_predict_sam():
     # Run inference with points prompt
     model(ASSETS / 'zidane.jpg', points=[900, 370], labels=[1], device=0)
 
+    # Create SAMPredictor
+    overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024, model='mobile_sam.pt')
+    predictor = SAMPredictor(overrides=overrides)
+
+    # Set image
+    predictor.set_image('ultralytics/assets/zidane.jpg')  # set with image file
+    # predictor(bboxes=[439, 437, 524, 709])
+    # predictor(points=[900, 370], labels=[1])
+
+    # Reset image
+    predictor.reset_image()
+
 
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
 def test_model_tune():
     subprocess.run('pip install ray[tune]'.split(), check=True)
-    YOLO('yolov8n-cls.yaml').tune(data='imagenet10',
-                                  grace_period=1,
-                                  max_samples=1,
-                                  imgsz=32,
-                                  epochs=1,
-                                  plots=False,
-                                  device='cpu')
+    with contextlib.suppress(RuntimeError):  # RuntimeError may be caused by out-of-memory
+        YOLO('yolov8n-cls.yaml').tune(data='imagenet10',
+                                      grace_period=1,
+                                      max_samples=1,
+                                      imgsz=32,
+                                      epochs=1,
+                                      plots=False,
+                                      device='cpu')
+
+
+@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
+def test_pycocotools():
+    from ultralytics.models.yolo.detect import DetectionValidator
+    from ultralytics.models.yolo.pose import PoseValidator
+    from ultralytics.models.yolo.segment import SegmentationValidator
+
+    # Download annotations after each dataset has been downloaded
+    url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'
+
+    validator = DetectionValidator(args={'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64})
+    validator()
+    validator.is_coco = True
+    download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8/annotations')
+    _ = validator.eval_json(validator.stats)
+
+    validator = SegmentationValidator(args={
+        'model': 'yolov8n-seg.pt',
+        'data': 'coco8-seg.yaml',
+        'save_json': True,
+        'imgsz': 64})
+    validator()
+    validator.is_coco = True
+    download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8-seg/annotations')
+    _ = validator.eval_json(validator.stats)
+
+    validator = PoseValidator(args={
+        'model': 'yolov8n-pose.pt',
+        'data': 'coco8-pose.yaml',
+        'save_json': True,
+        'imgsz': 64})
+    validator()
+    validator.is_coco = True
+    download(f'{url}person_keypoints_val2017.json', dir=DATASETS_DIR / 'coco8-pose/annotations')
+    _ = validator.eval_json(validator.stats)
diff --git a/tests/test_python.py b/tests/test_python.py
index 8236af35d..be7f4f869 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -1,7 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
 import contextlib
-import shutil
 from copy import copy
 from pathlib import Path
 
@@ -15,7 +14,7 @@ from torchvision.transforms import ToTensor
 from ultralytics import RTDETR, YOLO
 from ultralytics.cfg import TASK2DATA
 from ultralytics.data.build import load_inference_source
-from ultralytics.utils import ASSETS, DEFAULT_CFG, LINUX, ONLINE, ROOT, SETTINGS, WINDOWS
+from ultralytics.utils import ASSETS, DEFAULT_CFG, LINUX, MACOS, ONLINE, ROOT, SETTINGS, WINDOWS
 from ultralytics.utils.downloads import download
 from ultralytics.utils.torch_utils import TORCH_1_9
 
@@ -50,14 +49,22 @@ def test_model_methods():
     _ = model.task_map
 
 
+def test_model_profile():
+    # Test profile=True model argument
+    from ultralytics.nn.tasks import DetectionModel
+
+    model = DetectionModel()  # build model
+    im = torch.randn(1, 3, 64, 64)  # requires min imgsz=64
+    _ = model.predict(im, profile=True)
+
+
 def test_predict_txt():
     # Write a list of sources (file, dir, glob, recursive glob) to a txt file
     txt_file = TMP / 'sources.txt'
     with open(txt_file, 'w') as f:
         for x in [ASSETS / 'bus.jpg', ASSETS, ASSETS / '*', ASSETS / '**/*.jpg']:
             f.write(f'{x}\n')
-    model = YOLO(MODEL)
-    model(source=txt_file, imgsz=32)
+    _ = YOLO(MODEL)(source=txt_file, imgsz=32)
 
 
 def test_predict_img():
@@ -143,8 +150,7 @@ def test_track_stream():
 
 
 def test_val():
-    model = YOLO(MODEL)
-    model.val(data='coco8.yaml', imgsz=32, save_hybrid=True)
+    YOLO(MODEL).val(data='coco8.yaml', imgsz=32, save_hybrid=True)
 
 
 def test_train_scratch():
@@ -160,29 +166,25 @@ def test_train_pretrained():
 
 
 def test_export_torchscript():
-    model = YOLO(MODEL)
-    f = model.export(format='torchscript', optimize=True)
+    f = YOLO(MODEL).export(format='torchscript', optimize=True)
     YOLO(f)(SOURCE)  # exported model inference
 
 
 def test_export_onnx():
-    model = YOLO(MODEL)
-    f = model.export(format='onnx', dynamic=True)
+    f = YOLO(MODEL).export(format='onnx', dynamic=True)
     YOLO(f)(SOURCE)  # exported model inference
 
 
 def test_export_openvino():
-    model = YOLO(MODEL)
-    f = model.export(format='openvino')
+    f = YOLO(MODEL).export(format='openvino')
     YOLO(f)(SOURCE)  # exported model inference
 
 
 def test_export_coreml():
     if not WINDOWS:  # RuntimeError: BlobWriter not loaded with coremltools 7.0 on windows
-        model = YOLO(MODEL)
-        model.export(format='coreml', nms=True)
-        # if MACOS:
-        #     YOLO(f)(SOURCE)  # model prediction only supported on macOS
+        f = YOLO(MODEL).export(format='coreml', nms=True)
+        if MACOS:
+            YOLO(f)(SOURCE)  # model prediction only supported on macOS
 
 
 def test_export_tflite(enabled=False):
@@ -204,13 +206,11 @@ def test_export_pb(enabled=False):
 def test_export_paddle(enabled=False):
     # Paddle protobuf requirements conflicting with onnx protobuf requirements
     if enabled:
-        model = YOLO(MODEL)
-        model.export(format='paddle')
+        YOLO(MODEL).export(format='paddle')
 
 
 def test_export_ncnn():
-    model = YOLO(MODEL)
-    f = model.export(format='ncnn')
+    f = YOLO(MODEL).export(format='ncnn')
     YOLO(f)(SOURCE)  # exported model inference
 
 
@@ -218,14 +218,14 @@ def test_all_model_yamls():
     for m in (ROOT / 'cfg' / 'models').rglob('*.yaml'):
         if 'rtdetr' in m.name:
             if TORCH_1_9:  # torch<=1.8 issue - TypeError: __init__() got an unexpected keyword argument 'batch_first'
-                RTDETR(m.name)(SOURCE, imgsz=640)  # must be 640
+                _ = RTDETR(m.name)(SOURCE, imgsz=640)  # must be 640
         else:
             YOLO(m.name)
 
 
 def test_workflow():
     model = YOLO(MODEL)
-    model.train(data='coco8.yaml', epochs=1, imgsz=32)
+    model.train(data='coco8.yaml', epochs=1, imgsz=32, optimizer='SGD')
     model.val(imgsz=32)
     model.predict(SOURCE, imgsz=32)
     model.export(format='onnx')  # export a model to ONNX format
@@ -254,8 +254,7 @@ def test_predict_callback_and_setup():
 
 def test_results():
     for m in 'yolov8n-pose.pt', 'yolov8n-seg.pt', 'yolov8n.pt', 'yolov8n-cls.pt':
-        model = YOLO(m)
-        results = model([SOURCE, SOURCE], imgsz=160)
+        results = YOLO(m)([SOURCE, SOURCE], imgsz=160)
         for r in results:
             r = r.cpu().numpy()
             r = r.to(device='cpu', dtype=torch.float32)
@@ -278,8 +277,7 @@ def test_data_utils():
 
     for task in 'detect', 'segment', 'pose':
         file = Path(TASK2DATA[task]).with_suffix('.zip')  # i.e. coco8.zip
-        download(f'https://github.com/ultralytics/hub/raw/main/example_datasets/{file}', unzip=False)
-        shutil.move(str(file), TMP)  # Python 3.8 requires string input to shutil.move()
+        download(f'https://github.com/ultralytics/hub/raw/main/example_datasets/{file}', unzip=False, dir=TMP)
         stats = HUBDatasetStats(TMP / file, task=task)
         stats.get_json(save=True)
         stats.process_images()
@@ -294,8 +292,7 @@ def test_data_converter():
     from ultralytics.data.converter import coco80_to_coco91_class, convert_coco
 
     file = 'instances_val2017.json'
-    download(f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{file}')
-    shutil.move(file, TMP)
+    download(f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{file}', dir=TMP)
     convert_coco(labels_dir=TMP, use_segments=True, use_keypoints=False, cls91to80=True)
     coco80_to_coco91_class()
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index 22b1294e5..a2424d4b6 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -339,6 +339,8 @@ class Model:
             overrides['batch'] = 1  # default to 1 if not modified
         if 'data' not in kwargs:
             overrides['data'] = None  # default to None if not modified (avoid int8 calibration with coco.yaml)
+        if 'verbose' not in kwargs:
+            overrides['verbose'] = False
         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
         args.task = self.task
         return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py
index 2ccdec607..3bdc9448e 100644
--- a/ultralytics/models/fastsam/prompt.py
+++ b/ultralytics/models/fastsam/prompt.py
@@ -51,32 +51,16 @@ class FastSAMPrompt:
         n = len(result.masks.data)
         for i in range(n):
             mask = result.masks.data[i] == 1.0
-
-            if torch.sum(mask) < filter:
-                continue
-            annotation = {
-                'id': i,
-                'segmentation': mask.cpu().numpy(),
-                'bbox': result.boxes.data[i],
-                'score': result.boxes.conf[i]}
-            annotation['area'] = annotation['segmentation'].sum()
-            annotations.append(annotation)
+            if torch.sum(mask) >= filter:
+                annotation = {
+                    'id': i,
+                    'segmentation': mask.cpu().numpy(),
+                    'bbox': result.boxes.data[i],
+                    'score': result.boxes.conf[i]}
+                annotation['area'] = annotation['segmentation'].sum()
+                annotations.append(annotation)
         return annotations
 
-    @staticmethod
-    def filter_masks(annotations):  # filter the overlap mask
-        annotations.sort(key=lambda x: x['area'], reverse=True)
-        to_remove = set()
-        for i in range(len(annotations)):
-            a = annotations[i]
-            for j in range(i + 1, len(annotations)):
-                b = annotations[j]
-                if i != j and j not in to_remove and b['area'] < a['area'] and \
-                        (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
-                    to_remove.add(j)
-
-        return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
-
     @staticmethod
     def _get_bbox_from_mask(mask):
         mask = mask.astype(np.uint8)
@@ -242,15 +226,12 @@ class FastSAMPrompt:
         cropped_images = []
         not_crop = []
         filter_id = []
-        # annotations, _ = filter_masks(annotations)
-        # filter_id = list(_)
         for _, mask in enumerate(annotations):
             if np.sum(mask['segmentation']) <= 100:
                 filter_id.append(_)
                 continue
             bbox = self._get_bbox_from_mask(mask['segmentation'])  # bbox of the mask
             cropped_boxes.append(self._segment_image(image, bbox))  # save the cropped image
-            # cropped_boxes.append(segment_image(image,mask["segmentation"]))
             cropped_images.append(bbox)  # save the bbox of the cropped image
 
         return cropped_boxes, cropped_images, not_crop, filter_id, annotations
diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py
index 91938af54..eb9352f97 100644
--- a/ultralytics/models/sam/modules/encoders.py
+++ b/ultralytics/models/sam/modules/encoders.py
@@ -267,10 +267,11 @@ class PositionEmbeddingRandom(nn.Module):
         super().__init__()
         if scale is None or scale <= 0.0:
             scale = 1.0
-        self.register_buffer(
-            'positional_encoding_gaussian_matrix',
-            scale * torch.randn((2, num_pos_feats)),
-        )
+        self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
+
+        # Set non-deterministic to avoid the forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
+        torch.use_deterministic_algorithms(False)
+        torch.backends.cudnn.deterministic = False
 
     def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
         """Positionally encode points that are normalized to [0,1]."""
diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py
index 6fb71b33c..d20b5a931 100644
--- a/ultralytics/models/sam/modules/sam.py
+++ b/ultralytics/models/sam/modules/sam.py
@@ -20,12 +20,14 @@ class Sam(nn.Module):
     mask_threshold: float = 0.0
     image_format: str = 'RGB'
 
-    def __init__(self,
-                 image_encoder: ImageEncoderViT,
-                 prompt_encoder: PromptEncoder,
-                 mask_decoder: MaskDecoder,
-                 pixel_mean: List[float] = None,
-                 pixel_std: List[float] = None) -> None:
+    def __init__(
+        self,
+        image_encoder: ImageEncoderViT,
+        prompt_encoder: PromptEncoder,
+        mask_decoder: MaskDecoder,
+        pixel_mean: List[float] = (123.675, 116.28, 103.53),
+        pixel_std: List[float] = (58.395, 57.12, 57.375)
+    ) -> None:
         """
         SAM predicts object masks from an image and input prompts.
 
@@ -37,10 +39,6 @@ class Sam(nn.Module):
             pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
             pixel_std (list(float)): Std values for normalizing pixels in the input image.
         """
-        if pixel_mean is None:
-            pixel_mean = [123.675, 116.28, 103.53]
-        if pixel_std is None:
-            pixel_std = [58.395, 57.12, 57.375]
         super().__init__()
         self.image_encoder = image_encoder
         self.prompt_encoder = prompt_encoder
diff --git a/ultralytics/models/sam/modules/tiny_encoder.py b/ultralytics/models/sam/modules/tiny_encoder.py
index 0a1f8ac8d..ca8de50b7 100644
--- a/ultralytics/models/sam/modules/tiny_encoder.py
+++ b/ultralytics/models/sam/modules/tiny_encoder.py
@@ -30,40 +30,6 @@ class Conv2d_BN(torch.nn.Sequential):
         torch.nn.init.constant_(bn.bias, 0)
         self.add_module('bn', bn)
 
-    @torch.no_grad()
-    def fuse(self):
-        c, bn = self._modules.values()
-        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
-        w = c.weight * w[:, None, None, None]
-        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
-        m = torch.nn.Conv2d(w.size(1) * self.c.groups,
-                            w.size(0),
-                            w.shape[2:],
-                            stride=self.c.stride,
-                            padding=self.c.padding,
-                            dilation=self.c.dilation,
-                            groups=self.c.groups)
-        m.weight.data.copy_(w)
-        m.bias.data.copy_(b)
-        return m
-
-
-# NOTE: This module and timm package is needed only for training.
-# from ultralytics.utils.checks import check_requirements
-# check_requirements('timm')
-# from timm.models.layers import DropPath as TimmDropPath
-# from timm.models.layers import trunc_normal_
-# class DropPath(TimmDropPath):
-#
-#     def __init__(self, drop_prob=None):
-#         super().__init__(drop_prob=drop_prob)
-#         self.drop_prob = drop_prob
-#
-#     def __repr__(self):
-#         msg = super().__repr__()
-#         msg += f'(drop_prob={self.drop_prob})'
-#         return msg
-
 
 
 class PatchEmbed(nn.Module):
diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index 5f0a97894..e1c3b481d 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -153,8 +153,7 @@ class Predictor(BasePredictor):
             bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
             bboxes *= r
         if masks is not None:
-            masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device)
-            masks = masks[:, None, :, :]
+            masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
         points = (points, labels) if points is not None else None
 
         # Embed prompts
@@ -257,9 +256,7 @@ class Predictor(BasePredictor):
                 pred_bbox = batched_mask_to_box(pred_mask).float()
                 keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
                 if not torch.all(keep_mask):
-                    pred_bbox = pred_bbox[keep_mask]
-                    pred_mask = pred_mask[keep_mask]
-                    pred_score = pred_score[keep_mask]
+                    pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]
 
                 crop_masks.append(pred_mask)
                 crop_bboxes.append(pred_bbox)
@@ -288,9 +285,7 @@ class Predictor(BasePredictor):
         if len(crop_regions) > 1:
             scores = 1 / region_areas
             keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
-            pred_masks = pred_masks[keep]
-            pred_bboxes = pred_bboxes[keep]
-            pred_scores = pred_scores[keep]
+            pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]
 
         return pred_masks, pred_scores, pred_bboxes
 
diff --git a/ultralytics/models/utils/loss.py b/ultralytics/models/utils/loss.py
index db6fd6315..9f95e5f0b 100644
--- a/ultralytics/models/utils/loss.py
+++ b/ultralytics/models/utils/loss.py
@@ -82,8 +82,7 @@ class DETRLoss(nn.Module):
         loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
         loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
         loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
-        loss = {k: v.squeeze() for k, v in loss.items()}
-        return loss
+        return {k: v.squeeze() for k, v in loss.items()}
 
     def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
         # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
@@ -105,7 +104,8 @@ class DETRLoss(nn.Module):
         loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
         return loss
 
-    def _dice_loss(self, inputs, targets, num_gts):
+    @staticmethod
+    def _dice_loss(inputs, targets, num_gts):
         inputs = F.sigmoid(inputs)
         inputs = inputs.flatten(1)
         targets = targets.flatten(1)
@@ -163,7 +163,8 @@ class DETRLoss(nn.Module):
         #     loss[f'loss_dice_aux{postfix}'] = loss[4]
         return loss
 
-    def _get_index(self, match_indices):
+    @staticmethod
+    def _get_index(match_indices):
         batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
         src_idx = torch.cat([src for (src, _) in match_indices])
         dst_idx = torch.cat([dst for (_, dst) in match_indices])
@@ -257,10 +258,10 @@ class RTDETRDetectionLoss(DETRLoss):
             dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
             assert len(batch['gt_groups']) == len(dn_pos_idx)
 
-            # denoising match indices
+            # Denoising match indices
             match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
 
-            # compute denoising training loss
+            # Compute denoising training loss
             dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
             total_loss.update(dn_loss)
         else:
@@ -270,7 +271,8 @@ class RTDETRDetectionLoss(DETRLoss):
 
     @staticmethod
    def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
-        """Get the match indices for denoising.
+        """
+        Get the match indices for denoising.
 
         Args:
             dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
@@ -279,7 +281,6 @@ class RTDETRDetectionLoss(DETRLoss):
 
         Returns:
             dn_match_indices (List(tuple)): Matched indices.
-
         """
         dn_match_indices = []
         idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py
index ecbc4c0bf..7a4602bab 100644
--- a/ultralytics/nn/modules/transformer.py
+++ b/ultralytics/nn/modules/transformer.py
@@ -51,8 +51,7 @@ class TransformerEncoderLayer(nn.Module):
         src = self.norm1(src)
         src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
         src = src + self.dropout2(src2)
-        src = self.norm2(src)
-        return src
+        return self.norm2(src)
 
     def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
         src2 = self.norm1(src)
@@ -61,8 +60,7 @@ class TransformerEncoderLayer(nn.Module):
         src = src + self.dropout1(src2)
         src2 = self.norm2(src)
         src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
-        src = src + self.dropout2(src2)
-        return src
+        return src + self.dropout2(src2)
 
     def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
         """Forward propagates the input through the encoder module."""
@@ -116,8 +114,7 @@ class TransformerLayer(nn.Module):
     def forward(self, x):
         """Apply a transformer block to the input x and return the output."""
         x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
-        x = self.fc2(self.fc1(x)) + x
-        return x
+        return self.fc2(self.fc1(x)) + x
 
 
 class TransformerBlock(nn.Module):
@@ -185,8 +182,7 @@ class LayerNorm2d(nn.Module):
         u = x.mean(1, keepdim=True)
         s = (x - u).pow(2).mean(1, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.eps)
-        x = self.weight[:, None, None] * x + self.bias[:, None, None]
-        return x
+        return self.weight[:, None, None] * x + self.bias[:, None, None]
 
 
 class MSDeformAttn(nn.Module):
@@ -271,8 +267,7 @@ class MSDeformAttn(nn.Module):
         else:
             raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
         output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
-        output = self.output_proj(output)
-        return output
+        return self.output_proj(output)
 
 
 class DeformableTransformerDecoderLayer(nn.Module):
@@ -309,8 +304,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
     def forward_ffn(self, tgt):
         tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
         tgt = tgt + self.dropout4(tgt2)
-        tgt = self.norm3(tgt)
-        return tgt
+        return self.norm3(tgt)
 
     def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
         # self attention
@@ -327,9 +321,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
         embed = self.norm2(embed)
 
         # ffn
-        embed = self.forward_ffn(embed)
-
-        return embed
+        return self.forward_ffn(embed)
 
 
 class DeformableTransformerDecoder(nn.Module):
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 25b69979e..24153d244 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -322,31 +322,10 @@ class PoseModel(DetectionModel):
 class ClassificationModel(BaseModel):
     """YOLOv8 classification model."""
 
-    def __init__(self,
-                 cfg='yolov8n-cls.yaml',
-                 model=None,
-                 ch=3,
-                 nc=None,
-                 cutoff=10,
-                 verbose=True):  # YAML, model, channels, number of classes, cutoff index, verbose flag
+    def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True):
+        """Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
         super().__init__()
-        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
-
-    def _from_detection_model(self, model, nc=1000, cutoff=10):
-        """Create a YOLOv5 classification model from a YOLOv5 detection model."""
-        from ultralytics.nn.autobackend import AutoBackend
-        if isinstance(model, AutoBackend):
-            model = model.model  # unwrap DetectMultiBackend
-        model.model = model.model[:cutoff]  # backbone
-        m = model.model[-1]  # last layer
-        ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels  # ch into module
-        c = Classify(ch, nc)  # Classify()
-        c.i, c.f, c.type = m.i, m.f, 'models.common.Classify'  # index, from, type
-        model.model[-1] = c  # replace
-        self.model = model.model
-        self.stride = model.stride
-        self.save = []
-        self.nc = nc
+        self._from_yaml(cfg, ch, nc, verbose)
 
     def _from_yaml(self, cfg, ch, nc, verbose):
         """Set YOLOv8 model configurations and define the model architecture."""
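
For reference, the interactive SAM predictor flow exercised in tests/test_cuda.py above can also be run on its own. A minimal sketch, using the same overrides, model file and bundled asset the test assumes are available locally; the box and point prompts mirror the commented-out calls in the test:

from ultralytics.models.sam import Predictor as SAMPredictor

# Build a promptable predictor with the same overrides as test_predict_sam
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024, model='mobile_sam.pt')
predictor = SAMPredictor(overrides=overrides)

# Cache the image embedding once, then prompt it repeatedly
predictor.set_image('ultralytics/assets/zidane.jpg')
results = predictor(bboxes=[439, 437, 524, 709])    # box prompt
results = predictor(points=[900, 370], labels=[1])  # point prompt

# Release the cached embedding before moving to the next image
predictor.reset_image()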