removed nms=False support

test-quan
Francesco Mattioli 4 weeks ago
parent 4182ddb63d
commit 354ea825e8
  1. 95
      ultralytics/engine/exporter.py

@ -1106,7 +1106,8 @@ class Exporter:
raise ValueError("MCT export is not supported for end2end models.") raise ValueError("MCT export is not supported for end2end models.")
if "C2f" not in self.model.__str__(): if "C2f" not in self.model.__str__():
raise ValueError("MCT export is only supported for YOLOv8 detection models") raise ValueError("MCT export is only supported for YOLOv8 detection models")
check_requirements("model-compression-toolkit==2.1.1") check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers[torch]"))
from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
import model_compression_toolkit as mct import model_compression_toolkit as mct
def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)): def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
@ -1145,54 +1146,50 @@ class Exporter:
)[0] )[0]
) )
if self.args.nms: class NMSWrapper(torch.nn.Module):
check_requirements("sony-custom-layers[torch]") def __init__(
from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms self,
model: torch.nn.Module,
class NMSWrapper(torch.nn.Module): score_threshold: float = 0.001,
def __init__( iou_threshold: float = 0.7,
self, max_detections: int = 300,
model: torch.nn.Module, ):
score_threshold: float = 0.001, """
iou_threshold: float = 0.7, Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
max_detections: int = 300,
): Args:
""" model (nn.Module): Model instance.
Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers. score_threshold (float): Score threshold for non-maximum suppression.
iou_threshold (float): Intersection over union threshold for non-maximum suppression.
Args: max_detections (float): The number of detections to return.
model (nn.Module): Model instance. """
score_threshold (float): Score threshold for non-maximum suppression. super().__init__()
iou_threshold (float): Intersection over union threshold for non-maximum suppression. self.model = model
max_detections (float): The number of detections to return. self.score_threshold = score_threshold
""" self.iou_threshold = iou_threshold
super().__init__() self.max_detections = max_detections
self.model = model
self.score_threshold = score_threshold def forward(self, images):
self.iou_threshold = iou_threshold # model inference
self.max_detections = max_detections outputs = self.model(images)
def forward(self, images): boxes = outputs[0]
# model inference scores = outputs[1]
outputs = self.model(images) nms = multiclass_nms(
boxes=boxes,
boxes = outputs[0] scores=scores,
scores = outputs[1] score_threshold=self.score_threshold,
nms = multiclass_nms( iou_threshold=self.iou_threshold,
boxes=boxes, max_detections=self.max_detections,
scores=scores, )
score_threshold=self.score_threshold, return nms
iou_threshold=self.iou_threshold,
max_detections=self.max_detections, quant_model = NMSWrapper(
) model=quant_model,
return nms score_threshold=self.args.conf or 0.001,
iou_threshold=self.args.iou,
quant_model = NMSWrapper( max_detections=self.args.max_det,
model=quant_model, ).to(self.device)
score_threshold=self.args.conf or 0.001,
iou_threshold=self.args.iou,
max_detections=self.args.max_det,
).to(self.device)
f = Path(str(self.file).replace(self.file.suffix, "_mct_model.onnx")) # js dir f = Path(str(self.file).replace(self.file.suffix, "_mct_model.onnx")) # js dir
mct.exporter.pytorch_export_model(model=quant_model, save_model_path=f, repr_dataset=representative_dataset_gen) mct.exporter.pytorch_export_model(model=quant_model, save_model_path=f, repr_dataset=representative_dataset_gen)

Loading…
Cancel
Save