From 7053169fd0802cba6e1984671f97678131619a4b Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Fri, 30 Aug 2024 20:59:08 +0800 Subject: [PATCH] `ultralytics 8.2.84` new SAM flexible `imgsz` inference (#15882) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- ultralytics/__init__.py | 2 +- ultralytics/models/sam/model.py | 2 +- ultralytics/models/sam/modules/encoders.py | 7 +++++- ultralytics/models/sam/modules/sam.py | 24 +++++++++++++++++++ .../models/sam/modules/tiny_encoder.py | 23 +++++++++++++++++- ultralytics/models/sam/predict.py | 19 ++++++++++++--- 6 files changed, 70 insertions(+), 7 deletions(-) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 4645f8c8f..bfa52de39 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.83" +__version__ = "8.2.84" import os diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py index 126e672a5..e685dc4e4 100644 --- a/ultralytics/models/sam/model.py +++ b/ultralytics/models/sam/model.py @@ -106,7 +106,7 @@ class SAM(Model): ... print(f"Detected {len(r.masks)} masks") """ overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024) - kwargs.update(overrides) + kwargs = {**overrides, **kwargs} prompts = dict(bboxes=bboxes, points=points, labels=labels) return super().predict(source, stream, prompts=prompts, **kwargs) diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py index 9432fec4b..22934222a 100644 --- a/ultralytics/models/sam/modules/encoders.py +++ b/ultralytics/models/sam/modules/encoders.py @@ -151,7 +151,12 @@ class ImageEncoderViT(nn.Module): """Processes input through patch embedding, positional embedding, transformer blocks, and neck module.""" x = self.patch_embed(x) if self.pos_embed is not None: - x = x + self.pos_embed + pos_embed = ( + F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1) + if self.img_size != 1024 + else self.pos_embed + ) + x = x + pos_embed for blk in self.blocks: x = blk(x) return self.neck(x.permute(0, 3, 1, 2)) diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py index d3191366f..b638ddc53 100644 --- a/ultralytics/models/sam/modules/sam.py +++ b/ultralytics/models/sam/modules/sam.py @@ -90,6 +90,19 @@ class SAMModel(nn.Module): self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + def set_imgsz(self, imgsz): + """ + Set image size to make model compatible with different image sizes. + + Args: + imgsz (Tuple[int, int]): The size of the input image. + """ + if hasattr(self.image_encoder, "set_imgsz"): + self.image_encoder.set_imgsz(imgsz) + self.prompt_encoder.input_image_size = imgsz + self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # 16 is fixed as patch size of ViT model + self.image_encoder.img_size = imgsz[0] + class SAM2Model(torch.nn.Module): """ @@ -940,3 +953,14 @@ class SAM2Model(torch.nn.Module): # don't overlap (here sigmoid(-10.0)=4.5398e-05) pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) return pred_masks + + def set_imgsz(self, imgsz): + """ + Set image size to make model compatible with different image sizes. + + Args: + imgsz (Tuple[int, int]): The size of the input image. + """ + self.image_size = imgsz[0] + self.sam_prompt_encoder.input_image_size = imgsz + self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16 diff --git a/ultralytics/models/sam/modules/tiny_encoder.py b/ultralytics/models/sam/modules/tiny_encoder.py index 6ce32824c..d036ab987 100644 --- a/ultralytics/models/sam/modules/tiny_encoder.py +++ b/ultralytics/models/sam/modules/tiny_encoder.py @@ -982,10 +982,31 @@ class TinyViT(nn.Module): layer = self.layers[i] x = layer(x) batch, _, channel = x.shape - x = x.view(batch, 64, 64, channel) + x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel) x = x.permute(0, 3, 1, 2) return self.neck(x) def forward(self, x): """Performs the forward pass through the TinyViT model, extracting features from the input image.""" return self.forward_features(x) + + def set_imgsz(self, imgsz=[1024, 1024]): + """ + Set image size to make model compatible with different image sizes. + + Args: + imgsz (Tuple[int, int]): The size of the input image. + """ + imgsz = [s // 4 for s in imgsz] + self.patches_resolution = imgsz + for i, layer in enumerate(self.layers): + input_resolution = ( + imgsz[0] // (2 ** (i - 1 if i == 3 else i)), + imgsz[1] // (2 ** (i - 1 if i == 3 else i)), + ) + layer.input_resolution = input_resolution + if layer.downsample is not None: + layer.downsample.input_resolution = input_resolution + if isinstance(layer, BasicLayer): + for b in layer.blocks: + b.input_resolution = input_resolution diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py index 1118f82b1..8ecb069ee 100644 --- a/ultralytics/models/sam/predict.py +++ b/ultralytics/models/sam/predict.py @@ -95,7 +95,7 @@ class Predictor(BasePredictor): """ if overrides is None: overrides = {} - overrides.update(dict(task="segment", mode="predict", imgsz=1024)) + overrides.update(dict(task="segment", mode="predict")) super().__init__(cfg, overrides, _callbacks) self.args.retina_masks = True self.im = None @@ -455,8 +455,11 @@ class Predictor(BasePredictor): cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device) pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1) - masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0] - masks = masks > self.model.mask_threshold # to bool + if len(masks) == 0: + masks = None + else: + masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0] + masks = masks > self.model.mask_threshold # to bool results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes)) # Reset segment-all mode. self.segment_all = False @@ -522,6 +525,10 @@ class Predictor(BasePredictor): def get_im_features(self, im): """Extracts image features using the SAM model's image encoder for subsequent mask prediction.""" + assert ( + isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1] + ), f"SAM models only support square image size, but got {self.imgsz}." + self.model.set_imgsz(self.imgsz) return self.model.image_encoder(im) def set_prompts(self, prompts): @@ -761,6 +768,12 @@ class SAM2Predictor(Predictor): def get_im_features(self, im): """Extracts image features from the SAM image encoder for subsequent processing.""" + assert ( + isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1] + ), f"SAM 2 models only support square image size, but got {self.imgsz}." + self.model.set_imgsz(self.imgsz) + self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]] + backbone_out = self.model.forward_image(im) _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) if self.model.directly_add_no_mem_embed: