|
|
@@ -1106,7 +1106,8 @@ class Exporter:
             raise ValueError("MCT export is not supported for end2end models.")
         if "C2f" not in self.model.__str__():
             raise ValueError("MCT export is only supported for YOLOv8 detection models")
-        check_requirements("model-compression-toolkit==2.1.1")
+        check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers[torch]"))
+        from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
         import model_compression_toolkit as mct

         def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
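Note (not part of the diff): representative_dataset_gen above is the calibration feed that Model Compression Toolkit consumes during post-training quantization; only its signature is visible in this hunk. A minimal sketch of how such a generator is commonly written, assuming (not shown in the excerpt) that the Ultralytics calibration dataloader yields dict batches with an "img" key holding uint8 image tensors:

def representative_dataset_gen(dataloader=None):
    """Yield one list of model inputs per step, the layout MCT representative data generators commonly use."""
    for batch in dataloader:
        img = batch["img"]  # assumed Ultralytics batch layout
        yield [img.float() / 255.0]  # scale to [0, 1] to match YOLO preprocessing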
|
|
@@ -1145,54 +1146,50 @@ class Exporter:
             )[0]
         )

-        if self.args.nms:
-            check_requirements("sony-custom-layers[torch]")
-            from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
-
-            class NMSWrapper(torch.nn.Module):
-                def __init__(
-                    self,
-                    model: torch.nn.Module,
-                    score_threshold: float = 0.001,
-                    iou_threshold: float = 0.7,
-                    max_detections: int = 300,
-                ):
-                    """
-                    Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
-
-                    Args:
-                        model (nn.Module): Model instance.
-                        score_threshold (float): Score threshold for non-maximum suppression.
-                        iou_threshold (float): Intersection over union threshold for non-maximum suppression.
-                        max_detections (float): The number of detections to return.
-                    """
-                    super().__init__()
-                    self.model = model
-                    self.score_threshold = score_threshold
-                    self.iou_threshold = iou_threshold
-                    self.max_detections = max_detections
-
-                def forward(self, images):
-                    # model inference
-                    outputs = self.model(images)
-
-                    boxes = outputs[0]
-                    scores = outputs[1]
-                    nms = multiclass_nms(
-                        boxes=boxes,
-                        scores=scores,
-                        score_threshold=self.score_threshold,
-                        iou_threshold=self.iou_threshold,
-                        max_detections=self.max_detections,
-                    )
-                    return nms
-
-            quant_model = NMSWrapper(
-                model=quant_model,
-                score_threshold=self.args.conf or 0.001,
-                iou_threshold=self.args.iou,
-                max_detections=self.args.max_det,
-            ).to(self.device)
+        class NMSWrapper(torch.nn.Module):
+            def __init__(
+                self,
+                model: torch.nn.Module,
+                score_threshold: float = 0.001,
+                iou_threshold: float = 0.7,
+                max_detections: int = 300,
+            ):
+                """
+                Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
+
+                Args:
+                    model (nn.Module): Model instance.
+                    score_threshold (float): Score threshold for non-maximum suppression.
+                    iou_threshold (float): Intersection over union threshold for non-maximum suppression.
+                    max_detections (float): The number of detections to return.
+                """
+                super().__init__()
+                self.model = model
+                self.score_threshold = score_threshold
+                self.iou_threshold = iou_threshold
+                self.max_detections = max_detections
+
+            def forward(self, images):
+                # model inference
+                outputs = self.model(images)
+
+                boxes = outputs[0]
+                scores = outputs[1]
+                nms = multiclass_nms(
+                    boxes=boxes,
+                    scores=scores,
+                    score_threshold=self.score_threshold,
+                    iou_threshold=self.iou_threshold,
+                    max_detections=self.max_detections,
+                )
+                return nms
+
+        quant_model = NMSWrapper(
+            model=quant_model,
+            score_threshold=self.args.conf or 0.001,
+            iou_threshold=self.args.iou,
+            max_detections=self.args.max_det,
+        ).to(self.device)

         f = Path(str(self.file).replace(self.file.suffix, "_mct_model.onnx"))  # js dir
         mct.exporter.pytorch_export_model(model=quant_model, save_model_path=f, repr_dataset=representative_dataset_gen)
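Note (not part of the diff): NMSWrapper.forward splits the wrapped detector's output into boxes and scores and hands them to sony_custom_layers' multiclass_nms, so NMS ends up inside the graph that mct.exporter.pytorch_export_model writes out. A standalone sketch of that call, assuming (this is not stated in the excerpt) that boxes are (batch, n, 4) xyxy coordinates and scores are (batch, n, num_classes) confidences:

import torch
from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms

boxes = torch.rand(1, 100, 4)  # stand-in candidate boxes, assumed xyxy layout
scores = torch.rand(1, 100, 80)  # stand-in per-class scores for an assumed 80 classes
detections = multiclass_nms(  # same keyword arguments as NMSWrapper.forward above
    boxes=boxes,
    scores=scores,
    score_threshold=0.001,
    iou_threshold=0.7,
    max_detections=300,
)

The thresholds mirror the wrapper's defaults; in the exporter they are overridden from self.args.conf, self.args.iou and self.args.max_det.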
|
|
|