Cleanup redundant SAM `forward()` methods (#4591)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/4595/head
Glenn Jocher 1 year ago committed by GitHub
parent 47ab96dab6
commit 2567b288c9
1. tests/test_cuda.py (8 changes)
2. ultralytics/data/loaders.py (2 changes)
3. ultralytics/models/sam/modules/sam.py (112 changes)
4. ultralytics/utils/checks.py (10 changes)

tests/test_cuda.py

@@ -27,12 +27,8 @@ def test_checks():
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
 def test_train():
-    YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, batch=-1, device=0)  # also test AutoBatch, requires imgsz>=64
-
-
-@pytest.mark.skipif(CUDA_DEVICE_COUNT < 2, reason=f'DDP is not available, {CUDA_DEVICE_COUNT} device(s) found')
-def test_train_ddp():
-    YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=[0, 1])  # requires imgsz>=64
+    device = 0 if CUDA_DEVICE_COUNT < 2 else [0, 1]
+    YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, batch=-1, device=device)  # also test AutoBatch, requires imgsz>=64


 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
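
The merged test now scales its device choice with the visible GPU count instead of keeping a separate DDP test. A minimal standalone sketch of the same pattern, assuming only that torch is installed (in Ultralytics, an int device trains on one GPU while a list of IDs triggers DDP):

    import torch

    # Mirror CUDA_DEVICE_COUNT from the test file: fewer than two GPUs means
    # plain single-device training, two or more enables a DDP device list.
    count = torch.cuda.device_count()
    device = 0 if count < 2 else [0, 1]
    print(f'{count} CUDA device(s) visible -> device={device}')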

ultralytics/data/loaders.py

@@ -119,7 +119,7 @@ class LoadStreams:
         # Wait until a frame is available in each buffer
         while not all(self.imgs):
             if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
-                cv2.destroyAllWindows()
+                self.close()
                 raise StopIteration
             time.sleep(1 / min(self.fps))
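
Calling `self.close()` instead of `cv2.destroyAllWindows()` lets the loader release everything it owns, not just display windows. A hedged sketch of what such a method plausibly covers, assuming the loader tracks `self.running`, `self.threads`, and `self.caps` (the actual `close()` implementation lives elsewhere in loaders.py):

    def close(self):
        """Sketch: stop reader threads, release captures, then tear down OpenCV windows."""
        self.running = False  # signal the per-stream reader loops to exit
        for thread in self.threads:
            if thread.is_alive():
                thread.join(timeout=5)  # give each reader a moment to finish
        for cap in self.caps:
            cap.release()  # free the underlying cv2.VideoCapture streams
        cv2.destroyAllWindows()  # the single call this change replaces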

ultralytics/models/sam/modules/sam.py

@@ -6,11 +6,10 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, List, Tuple
+from typing import List

 import torch
 from torch import nn
-from torch.nn import functional as F

 from .decoders import MaskDecoder
 from .encoders import ImageEncoderViT, PromptEncoder

@@ -31,6 +30,9 @@ class Sam(nn.Module):
     """
     SAM predicts object masks from an image and input prompts.

+    Note:
+        All forward() operations moved to SAMPredictor.
+
     Args:
         image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
             efficient mask prediction.
@@ -45,109 +47,3 @@ class Sam(nn.Module):
         self.mask_decoder = mask_decoder
         self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
         self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
-
-    @property
-    def device(self) -> Any:
-        return self.pixel_mean.device
-
-    @torch.no_grad()
-    def forward(
-        self,
-        batched_input: List[Dict[str, Any]],
-        multimask_output: bool,
-    ) -> List[Dict[str, torch.Tensor]]:
-        """
-        Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using
-        SamPredictor is recommended over calling the model directly.
-
-        Args:
-            batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
-                key can be excluded if it is not present.
-                'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
-                'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
-                'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
-                    transformed to the input frame of the model.
-                'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
-                'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
-                    the model.
-                'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
-            multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
-                mask.
-
-        Returns:
-            (list(dict)): A list over input images, where each element is a dictionary with the following keys.
-                'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
-                    input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
-                'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
-                'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed
-                    as mask input to subsequent iterations of prediction.
-        """
-        input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
-        image_embeddings = self.image_encoder(input_images)
-        outputs = []
-        for image_record, curr_embedding in zip(batched_input, image_embeddings):
-            if 'point_coords' in image_record:
-                points = (image_record['point_coords'], image_record['point_labels'])
-            else:
-                points = None
-            sparse_embeddings, dense_embeddings = self.prompt_encoder(
-                points=points,
-                boxes=image_record.get('boxes', None),
-                masks=image_record.get('mask_inputs', None),
-            )
-            low_res_masks, iou_predictions = self.mask_decoder(
-                image_embeddings=curr_embedding.unsqueeze(0),
-                image_pe=self.prompt_encoder.get_dense_pe(),
-                sparse_prompt_embeddings=sparse_embeddings,
-                dense_prompt_embeddings=dense_embeddings,
-                multimask_output=multimask_output,
-            )
-            masks = self.postprocess_masks(
-                low_res_masks,
-                input_size=image_record['image'].shape[-2:],
-                original_size=image_record['original_size'],
-            )
-            masks = masks > self.mask_threshold
-            outputs.append({
-                'masks': masks,
-                'iou_predictions': iou_predictions,
-                'low_res_logits': low_res_masks, })
-        return outputs
-
-    def postprocess_masks(
-        self,
-        masks: torch.Tensor,
-        input_size: Tuple[int, ...],
-        original_size: Tuple[int, ...],
-    ) -> torch.Tensor:
-        """
-        Remove padding and upscale masks to the original image size.
-
-        Args:
-            masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
-            input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
-            original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W).
-
-        Returns:
-            (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
-        """
-        masks = F.interpolate(
-            masks,
-            (self.image_encoder.img_size, self.image_encoder.img_size),
-            mode='bilinear',
-            align_corners=False,
-        )
-        masks = masks[..., :input_size[0], :input_size[1]]
-        return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
-
-    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
-        """Normalize pixel values and pad to a square input."""
-        # Normalize colors
-        x = (x - self.pixel_mean) / self.pixel_std
-
-        # Pad
-        h, w = x.shape[-2:]
-        padh = self.image_encoder.img_size - h
-        padw = self.image_encoder.img_size - w
-        return F.pad(x, (0, padw, 0, padh))
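
With `forward()`, `postprocess_masks()`, and `preprocess()` removed, the module is a plain container and inference routes through the predictor, as the new docstring note says. A hedged usage sketch against the Ultralytics high-level API; the weight filename, image path, and box coordinates here are illustrative only:

    from ultralytics import SAM

    model = SAM('sam_b.pt')  # any supported SAM checkpoint

    # Prompted inference: the predictor now chains the preprocessing, prompt
    # encoding, mask decoding, and postprocessing that Sam.forward() used to do.
    results = model('path/to/image.jpg', bboxes=[439, 437, 524, 709])

    # Unprompted 'segment everything' inference.
    results = model('path/to/image.jpg')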

ultralytics/utils/checks.py

@@ -519,9 +519,13 @@ def cuda_device_count() -> int:
         # Run the nvidia-smi command and capture its output
         output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
                                          encoding='utf-8')
-        return int(output.strip())
-    except (subprocess.CalledProcessError, FileNotFoundError):
-        # If the command fails or nvidia-smi is not found, assume no GPUs are available
+
+        # Take the first line and strip any leading/trailing white space
+        first_line = output.strip().split('\n')[0]
+
+        return int(first_line)
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
+        # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available
         return 0
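
The likely motivation: with `--query-gpu=count`, nvidia-smi prints one row per GPU, so a multi-GPU machine repeats the count and `int(output.strip())` raised ValueError on the embedded newline. A self-contained check of the fixed parsing, using a simulated output string:

    # Simulated `nvidia-smi --query-gpu=count --format=csv,noheader,nounits`
    # output on a 2-GPU machine: the count appears once per GPU row.
    output = '2\n2\n'

    first_line = output.strip().split('\n')[0]  # keep only the first row
    assert int(first_line) == 2  # int(output.strip()) would raise ValueError on '2\n2'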
