Added Sony post-processing

mct-2.1.1
Francesco Mattioli 6 months ago
parent 7deebfecea
commit adaee64f8f
  1. ultralytics/engine/exporter.py (44)
  2. ultralytics/nn/autobackend.py (1)

@@ -1023,9 +1023,43 @@ class Exporter:
    @try_export
    def export_mct(self, prefix=colorstr("Sony MCT:")):
        # pip install --upgrade --force-reinstall git+https://github.com/ambitious-octopus/model_optimization.git@get-output-fix
        import model_compression_toolkit as mct
        from torch import nn

        # pip install sony-custom-layers[torch]
        from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms

        class PostProcessWrapper(nn.Module):
            def __init__(self,
                         model: nn.Module,
                         score_threshold: float = 0.001,
                         iou_threshold: float = 0.7,
                         max_detections: int = 300):
                """
                Wrap a PyTorch module with the multiclass_nms layer from sony_custom_layers.

                Args:
                    model (nn.Module): Model instance.
                    score_threshold (float): Score threshold for non-maximum suppression.
                    iou_threshold (float): Intersection over union threshold for non-maximum suppression.
                    max_detections (int): The maximum number of detections to return.
                """
                super(PostProcessWrapper, self).__init__()
                self.model = model
                self.score_threshold = score_threshold
                self.iou_threshold = iou_threshold
                self.max_detections = max_detections

            def forward(self, images):
                # Model inference
                outputs = self.model(images)

                boxes = outputs[0]
                scores = outputs[1]

                # Append Sony's NMS so it becomes part of the exported graph
                nms = multiclass_nms(boxes=boxes, scores=scores, score_threshold=self.score_threshold,
                                     iou_threshold=self.iou_threshold, max_detections=self.max_detections)
                return nms

        def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
            for batch in dataloader:
                img = batch["img"]
@@ -1054,11 +1088,15 @@ class Exporter:
            target_resource_utilization=resource_utilization,
            core_config=config,
            target_platform_capabilities=tpc)

        # Get working device and wrap the quantized model with the NMS post-processing layer
        device = mct.core.pytorch.pytorch_device_config.get_working_device()
        quant_model_pp = PostProcessWrapper(model=quant_model).to(device=device)

        f = str(self.file).replace(self.file.suffix, "_mct_model.onnx")
-       mct.exporter.pytorch_export_model(model=quant_model,
+       mct.exporter.pytorch_export_model(model=quant_model_pp,
                                          save_model_path=f,
                                          repr_dataset=representative_dataset_gen)
        return f, None

    def _add_tflite_metadata(self, file):
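
For reference, a minimal standalone sketch of the wrapper pattern added above, using a stub detector in place of the quantized YOLO model. Only the multiclass_nms call signature is taken from the diff; the (batch, n_boxes, 4) box layout and (batch, n_boxes, n_classes) score layout are assumptions about what sony_custom_layers expects, and StubDetector is a hypothetical stand-in.

# Sketch only (not part of the commit): exercise the PostProcessWrapper pattern with a stub model.
# Assumes `pip install sony-custom-layers[torch]`; tensor shapes below are assumptions, not taken from this diff.
import torch
from torch import nn
from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms


class StubDetector(nn.Module):
    """Hypothetical stand-in for the quantized YOLO model: returns (boxes, scores)."""

    def forward(self, images):
        b = images.shape[0]
        boxes = torch.rand(b, 100, 4)    # assumed layout: (batch, n_boxes, 4)
        scores = torch.rand(b, 100, 80)  # assumed layout: (batch, n_boxes, n_classes)
        return boxes, scores


class PostProcessWrapper(nn.Module):
    """Same idea as the wrapper in the diff: run the model, then apply Sony's multiclass_nms."""

    def __init__(self, model, score_threshold=0.001, iou_threshold=0.7, max_detections=300):
        super().__init__()
        self.model = model
        self.score_threshold = score_threshold
        self.iou_threshold = iou_threshold
        self.max_detections = max_detections

    def forward(self, images):
        boxes, scores = self.model(images)
        return multiclass_nms(boxes=boxes, scores=scores, score_threshold=self.score_threshold,
                              iou_threshold=self.iou_threshold, max_detections=self.max_detections)


if __name__ == "__main__":
    wrapped = PostProcessWrapper(StubDetector())
    out = wrapped(torch.rand(1, 3, 640, 640))  # NMS results for a dummy image batch
    print(type(out))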

@@ -121,6 +121,7 @@ class AutoBackend(nn.Module):
            paddle,
            ncnn,
            triton,
            mct
        ) = self._model_type(w)
        fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
        nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCHW)
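
For context, a hedged sketch of the idea behind the new flag: AutoBackend unpacks one more boolean from _model_type(w) so MCT-exported models can be routed like any other backend. The detect_backend helper below is an illustrative assumption, not code from this diff; the "_mct_model.onnx" suffix comes from the exporter hunk above, but whether _model_type keys off it is not shown here.

# Illustrative only: mimics suffix-based backend detection with the new 'mct' flag appended.
# The helper name and suffix rules are assumptions; the real logic lives in AutoBackend._model_type.
from pathlib import Path


def detect_backend(weights: str) -> dict:
    """Hypothetical helper: map a weights path to backend flags, including the new 'mct' flag."""
    name = Path(weights).name.lower()
    return {
        "pt": name.endswith(".pt"),
        "onnx": name.endswith(".onnx") and not name.endswith("_mct_model.onnx"),
        "mct": name.endswith("_mct_model.onnx"),  # ONNX written by export_mct (see exporter hunk above)
    }


if __name__ == "__main__":
    print(detect_backend("yolov8n_mct_model.onnx"))  # {'pt': False, 'onnx': False, 'mct': True}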
