Commit ac2c2be8f3 (parent 95d54828bb), branch pull/16048/head
Author: Ultralytics Assistant, committed by GitHub
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
12 changed files (changed lines per file in parentheses):
  1. ultralytics/data/converter.py (11)
  2. ultralytics/hub/google/__init__.py (4)
  3. ultralytics/hub/session.py (2)
  4. ultralytics/models/fastsam/predict.py (2)
  5. ultralytics/models/sam/modules/blocks.py (8)
  6. ultralytics/models/sam/modules/decoders.py (9)
  7. ultralytics/models/sam/modules/encoders.py (5)
  8. ultralytics/models/sam/modules/sam.py (36)
  9. ultralytics/models/yolo/classify/predict.py (8)
  10. ultralytics/nn/modules/activation.py (3)
  11. ultralytics/utils/__init__.py (6)
  12. ultralytics/utils/checks.py (13)

ultralytics/data/converter.py
@@ -370,13 +370,10 @@ def convert_segment_masks_to_yolo_seg(masks_dir, output_dir, classes):
                 mask_yolo_03.txt
                 mask_yolo_04.txt
     """
-    import os
-
     pixel_to_class_mapping = {i + 1: i for i in range(classes)}
-    for mask_filename in os.listdir(masks_dir):
-        if mask_filename.endswith(".png"):
-            mask_path = os.path.join(masks_dir, mask_filename)
-            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # Read the mask image in grayscale
+    for mask_path in Path(masks_dir).iterdir():
+        if mask_path.suffix == ".png":
+            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)  # Read the mask image in grayscale
             img_height, img_width = mask.shape  # Get image dimensions
             LOGGER.info(f"Processing {mask_path} imgsz = {img_height} x {img_width}")
@@ -406,7 +403,7 @@ def convert_segment_masks_to_yolo_seg(masks_dir, output_dir, classes):
                         yolo_format.append(round(point[1] / img_height, 6))
                     yolo_format_data.append(yolo_format)
             # Save Ultralytics YOLO format data to file
-            output_path = os.path.join(output_dir, os.path.splitext(mask_filename)[0] + ".txt")
+            output_path = Path(output_dir) / f"{Path(mask_filename).stem}.txt"
             with open(output_path, "w") as file:
                 for item in yolo_format_data:
                     line = " ".join(map(str, item))
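The converter change above replaces os.listdir/os.path string handling with pathlib. Below is a minimal standalone sketch of the new iteration pattern; the directory names are hypothetical, and unlike the hunk above (which still builds the output stem from the old mask_filename name) the sketch derives the stem from mask_path itself:

    from pathlib import Path

    masks_dir, output_dir = Path("masks"), Path("labels")  # hypothetical example directories
    if masks_dir.is_dir():
        for mask_path in masks_dir.iterdir():
            if mask_path.suffix == ".png":
                # Path.stem drops the extension: masks/mask_01.png -> labels/mask_01.txt
                output_path = output_dir / f"{mask_path.stem}.txt"
                print(mask_path, "->", output_path)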

ultralytics/hub/google/__init__.py
@@ -136,12 +136,12 @@ class GCPRegions:
         sorted_results = sorted(results, key=lambda x: x[1])

         if verbose:
-            print(f"{'Region':<25} {'Location':<35} {'Tier':<5} {'Latency (ms)'}")
+            print(f"{'Region':<25} {'Location':<35} {'Tier':<5} Latency (ms)")
             for region, mean, std, min_, max_ in sorted_results:
                 tier, city, country = self.regions[region]
                 location = f"{city}, {country}"
                 if mean == float("inf"):
-                    print(f"{region:<25} {location:<35} {tier:<5} {'Timeout'}")
+                    print(f"{region:<25} {location:<35} {tier:<5} Timeout")
                 else:
                     print(f"{region:<25} {location:<35} {tier:<5} {mean:.0f} ± {std:.0f} ({min_:.0f} - {max_:.0f})")
             print(f"\nLowest latency region{'s' if top > 1 else ''}:")

ultralytics/hub/session.py
@@ -346,7 +346,7 @@ class HUBTrainingSession:
         """
         weights = Path(weights)
         if not weights.is_file():
-            last = weights.with_name("last" + weights.suffix)
+            last = weights.with_name(f"last{weights.suffix}")
             if final and last.is_file():
                 LOGGER.warning(
                     f"{PREFIX} WARNING ⚠ Model 'best.pt' not found, copying 'last.pt' to 'best.pt' and uploading. "

ultralytics/models/fastsam/predict.py
@@ -93,7 +93,7 @@ class FastSAMPredictor(SegmentationPredictor):
                     else torch.zeros(len(result), dtype=torch.bool, device=self.device)
                 )
                 for point, label in zip(points, labels):
-                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = True if label else False
+                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
                 idx |= point_idx
             if texts is not None:
                 if isinstance(texts, str):

ultralytics/models/sam/modules/blocks.py
@@ -736,7 +736,7 @@ class PositionEmbeddingSine(nn.Module):
         self.num_pos_feats = num_pos_feats // 2
         self.temperature = temperature
         self.normalize = normalize
-        if scale is not None and normalize is False:
+        if scale is not None and not normalize:
             raise ValueError("normalize should be True if scale is passed")
         if scale is None:
             scale = 2 * math.pi
@@ -763,8 +763,7 @@ class PositionEmbeddingSine(nn.Module):
     def encode_boxes(self, x, y, w, h):
         """Encodes box coordinates and dimensions into positional embeddings for detection."""
         pos_x, pos_y = self._encode_xy(x, y)
-        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
-        return pos
+        return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)

     encode = encode_boxes  # Backwards compatibility
@@ -775,8 +774,7 @@ class PositionEmbeddingSine(nn.Module):
         assert bx == by and nx == ny and bx == bl and nx == nl
         pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
         pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
-        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
-        return pos
+        return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)

     @torch.no_grad()
     def forward(self, x: torch.Tensor):

ultralytics/models/sam/modules/decoders.py
@@ -435,9 +435,9 @@ class SAM2MaskDecoder(nn.Module):
             upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
             upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

-        hyper_in_list: List[torch.Tensor] = []
-        for i in range(self.num_mask_tokens):
-            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+        hyper_in_list: List[torch.Tensor] = [
+            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
+        ]
         hyper_in = torch.stack(hyper_in_list, dim=1)
         b, c, h, w = upscaled_embedding.shape
         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
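For reference, the shapes behind the final matmul in this hunk: each hypernetwork MLP maps one mask token to a c-dimensional weight vector, and the batched product with the flattened upscaled embedding yields one h x w mask per token. A shape-only sketch with made-up sizes (b, c, h, w and the token count are illustrative, not the model's configured values):

    import torch

    b, c, h, w, num_mask_tokens = 2, 32, 64, 64, 4  # illustrative sizes only
    hyper_in = torch.randn(b, num_mask_tokens, c)  # stacked hypernetwork MLP outputs
    upscaled_embedding = torch.randn(b, c, h, w)  # upscaled decoder feature map
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
    print(masks.shape)  # torch.Size([2, 4, 64, 64])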
@@ -459,8 +459,7 @@ class SAM2MaskDecoder(nn.Module):
         stability_delta = self.dynamic_multimask_stability_delta
         area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
         area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
-        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
-        return stability_scores
+        return torch.where(area_u > 0, area_i / area_u, 1.0)

     def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
         """

ultralytics/models/sam/modules/encoders.py
@@ -491,12 +491,11 @@ class ImageEncoder(nn.Module):
             features, pos = features[: -self.scalp], pos[: -self.scalp]

         src = features[-1]
-        output = {
+        return {
             "vision_features": src,
             "vision_pos_enc": pos,
             "backbone_fpn": features,
         }
-        return output


 class FpnNeck(nn.Module):
@@ -577,7 +576,7 @@ class FpnNeck(nn.Module):
             self.convs.append(current)
         self.fpn_interp_model = fpn_interp_model
-        assert fuse_type in ["sum", "avg"]
+        assert fuse_type in {"sum", "avg"}
         self.fuse_type = fuse_type

         # levels to have top-down features in its outputs

ultralytics/models/sam/modules/sam.py
@@ -671,26 +671,19 @@ class SAM2Model(torch.nn.Module):
                 t_rel = self.num_maskmem - t_pos  # how many frames before current frame
                 if t_rel == 1:
                     # for t_rel == 1, we take the last frame (regardless of r)
-                    if not track_in_reverse:
-                        # the frame immediately before this frame (i.e. frame_idx - 1)
-                        prev_frame_idx = frame_idx - t_rel
-                    else:
-                        # the frame immediately after this frame (i.e. frame_idx + 1)
-                        prev_frame_idx = frame_idx + t_rel
+                    prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
+                elif not track_in_reverse:
+                    # first find the nearest frame among every r-th frames before this frame
+                    # for r=1, this would be (frame_idx - 2)
+                    prev_frame_idx = ((frame_idx - 2) // r) * r
+                    # then seek further among every r-th frames
+                    prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
                 else:
-                    # for t_rel >= 2, we take the memory frame from every r-th frames
-                    if not track_in_reverse:
-                        # first find the nearest frame among every r-th frames before this frame
-                        # for r=1, this would be (frame_idx - 2)
-                        prev_frame_idx = ((frame_idx - 2) // r) * r
-                        # then seek further among every r-th frames
-                        prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
-                    else:
-                        # first find the nearest frame among every r-th frames after this frame
-                        # for r=1, this would be (frame_idx + 2)
-                        prev_frame_idx = -(-(frame_idx + 2) // r) * r
-                        # then seek further among every r-th frames
-                        prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
+                    # first find the nearest frame among every r-th frames after this frame
+                    # for r=1, this would be (frame_idx + 2)
+                    prev_frame_idx = -(-(frame_idx + 2) // r) * r
+                    # then seek further among every r-th frames
+                    prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
                 out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
                 if out is None:
                     # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
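The collapsed elif chain above keeps the original frame-selection arithmetic: for t_rel >= 2 it snaps (frame_idx - 2) down to a multiple of the memory stride r when tracking forward, or snaps (frame_idx + 2) up via the -(-n // r) ceiling-division idiom when tracking in reverse, and then steps a further (t_rel - 2) strides. A tiny worked example with made-up values:

    frame_idx, r, t_rel = 17, 4, 3  # hypothetical current frame, memory stride, relative offset

    # Forward: nearest multiple of r at or before frame_idx - 2, then (t_rel - 2) strides earlier.
    prev_fwd = ((frame_idx - 2) // r) * r - (t_rel - 2) * r
    # Reverse: nearest multiple of r at or after frame_idx + 2, then (t_rel - 2) strides later.
    prev_rev = -(-(frame_idx + 2) // r) * r + (t_rel - 2) * r
    print(prev_fwd, prev_rev)  # 8 24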
@@ -739,7 +732,7 @@ class SAM2Model(torch.nn.Module):
                     if out is not None:
                         pos_and_ptrs.append((t_diff, out["obj_ptr"]))
             # If we have at least one object pointer, add them to the across attention
-            if len(pos_and_ptrs) > 0:
+            if pos_and_ptrs:
                 pos_list, ptrs_list = zip(*pos_and_ptrs)
                 # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                 obj_ptrs = torch.stack(ptrs_list, dim=0)
@@ -930,12 +923,11 @@ class SAM2Model(torch.nn.Module):
     def _use_multimask(self, is_init_cond_frame, point_inputs):
         """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
         num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
-        multimask_output = (
+        return (
             self.multimask_output_in_sam
             and (is_init_cond_frame or self.multimask_output_for_tracking)
             and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
         )
-        return multimask_output

     def _apply_non_overlapping_constraints(self, pred_masks):
         """Applies non-overlapping constraints to masks, keeping highest scoring object per location."""

ultralytics/models/yolo/classify/predict.py
@@ -53,7 +53,7 @@ class ClassificationPredictor(BasePredictor):
         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

-        results = []
-        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
-            results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
-        return results
+        return [
+            Results(orig_img, path=img_path, names=self.model.names, probs=pred)
+            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
+        ]

ultralytics/nn/modules/activation.py
@@ -18,5 +18,4 @@ class AGLU(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Compute the forward pass of the Unified activation function."""
         lam = torch.clamp(self.lambd, min=0.0001)
-        y = torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam)))
-        return y  # for AGLU simply return y * input
+        return torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam)))
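The AGLU forward above now returns the expression directly. Below is a minimal standalone sketch of the same computation; the surrounding module definition is not part of this diff, so the Softplus(beta=-1.0) activation and the scalar lambd/kappa parameters are assumptions about the upstream module, and the initial values are purely illustrative:

    import torch
    import torch.nn as nn


    class AGLUSketch(nn.Module):
        """Hedged sketch of the AGLU (Unified) activation forward pass."""

        def __init__(self):
            super().__init__()
            self.act = nn.Softplus(beta=-1.0)  # assumed activation; not shown in the hunk above
            self.lambd = nn.Parameter(torch.ones(1))  # illustrative initialization
            self.kappa = nn.Parameter(torch.ones(1))  # illustrative initialization

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            lam = torch.clamp(self.lambd, min=0.0001)  # keeps 1 / lam finite
            return torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam)))


    print(AGLUSketch()(torch.tensor([-1.0, 0.0, 1.0])))  # with these illustrative inits this equals sigmoid(x)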

ultralytics/utils/__init__.py
@@ -1160,9 +1160,9 @@ def vscode_msg(ext="ultralytics.ultralytics-snippets") -> str:
     obs_file = path / ".obsolete"  # file tracks uninstalled extensions, while source directory remains
     installed = any(path.glob(f"{ext}*")) and ext not in (obs_file.read_text("utf-8") if obs_file.exists() else "")
     return (
-        f"{colorstr('VS Code:')} view Ultralytics VS Code Extension ⚡ at https://docs.ultralytics.com/integrations/vscode"
-        if not installed
-        else ""
+        ""
+        if installed
+        else f"{colorstr('VS Code:')} view Ultralytics VS Code Extension ⚡ at https://docs.ultralytics.com/integrations/vscode"
     )

ultralytics/utils/checks.py
@@ -226,13 +226,12 @@ def check_version(
     if not required:  # if required is '' or None
         return True
-    if "sys_platform" in required:  # i.e. required='<2.4.0,>=1.8.0; sys_platform == "win32"'
-        if (
-            (WINDOWS and "win32" not in required)
-            or (LINUX and "linux" not in required)
-            or (MACOS and "macos" not in required and "darwin" not in required)
-        ):
-            return True
+    if "sys_platform" in required and (  # i.e. required='<2.4.0,>=1.8.0; sys_platform == "win32"'
+        (WINDOWS and "win32" not in required)
+        or (LINUX and "linux" not in required)
+        or (MACOS and "macos" not in required and "darwin" not in required)
+    ):
+        return True

     op = ""
     version = ""
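The restructured guard above makes check_version return True early (skip the version check) when a PEP 508 sys_platform marker in the requirement string does not match the running OS. A standalone sketch of that short-circuit, with the WINDOWS/LINUX/MACOS flags stubbed from platform.system() rather than imported from ultralytics.utils:

    import platform

    # Stubbed platform flags; ultralytics.utils defines similar WINDOWS/LINUX/MACOS constants.
    WINDOWS = platform.system() == "Windows"
    LINUX = platform.system() == "Linux"
    MACOS = platform.system() == "Darwin"


    def marker_skips_check(required: str) -> bool:
        """Return True when a sys_platform marker means the requirement does not apply here."""
        return "sys_platform" in required and (
            (WINDOWS and "win32" not in required)
            or (LINUX and "linux" not in required)
            or (MACOS and "macos" not in required and "darwin" not in required)
        )


    print(marker_skips_check('<2.4.0,>=1.8.0; sys_platform == "win32"'))  # True on Linux/macOS, False on Windows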
