mct export refactoring

mct-2.1.1
Francesco Mattioli 3 months ago
parent b17c77f3da
commit dac52b9be3
  1. 2
      ultralytics/cfg/default.yaml
  2. 186
      ultralytics/engine/exporter.py
  3. 1
      ultralytics/nn/autobackend.py

@ -85,7 +85,7 @@ simplify: False # (bool) ONNX: simplify model using `onnxslim`
opset: # (int, optional) ONNX: opset version
workspace: 4 # (int) TensorRT: workspace size (GB)
nms: False # (bool) CoreML: add NMS
gptq: False
# Hyperparameters ------------------------------------------------------------------------------------------------------
lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
lrf: 0.01 # (float) final learning rate (lr0 * lrf)

@ -1037,55 +1037,57 @@ class Exporter:
def export_mct(self, prefix=colorstr("Sony MCT:")):
# pip install --upgrade -force-reinstall git+https://github.com/ambitious-octopus/model_optimization.git@get-output-fix
import model_compression_toolkit as mct
# from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
# pip install sony-custom-layers[torch]
# from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
#
# class PostProcessWrapper(torch.nn.Module):
# def __init__(
# self,
# model: torch.nn.Module,
# score_threshold: float = 0.001,
# iou_threshold: float = 0.7,
# max_detections: int = 300,
# ):
# """
# Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
#
# Args:
# model (nn.Module): Model instance.
# score_threshold (float): Score threshold for non-maximum suppression.
# iou_threshold (float): Intersection over union threshold for non-maximum suppression.
# max_detections (float): The number of detections to return.
# """
# super(PostProcessWrapper, self).__init__()
# self.model = model
# self.score_threshold = score_threshold
# self.iou_threshold = iou_threshold
# self.max_detections = max_detections
#
# def forward(self, images):
# # model inference
# outputs = self.model(images)
#
# boxes = outputs[0]
# scores = outputs[1]
# nms = multiclass_nms(
# boxes=boxes,
# scores=scores,
# score_threshold=self.score_threshold,
# iou_threshold=self.iou_threshold,
# max_detections=self.max_detections,
# )
# return nms
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device, set_working_device
from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
import onnx
set_working_device(str(self.device))
class PostProcessWrapper(torch.nn.Module):
def __init__(
self,
model: torch.nn.Module,
score_threshold: float = 0.001,
iou_threshold: float = 0.7,
max_detections: int = 300,
):
"""
Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
Args:
model (nn.Module): Model instance.
score_threshold (float): Score threshold for non-maximum suppression.
iou_threshold (float): Intersection over union threshold for non-maximum suppression.
max_detections (float): The number of detections to return.
"""
super(PostProcessWrapper, self).__init__()
self.model = model
self.score_threshold = score_threshold
self.iou_threshold = iou_threshold
self.max_detections = max_detections
def forward(self, images):
# model inference
outputs = self.model(images)
boxes = outputs[0]
scores = outputs[1]
nms = multiclass_nms(
boxes=boxes,
scores=scores,
score_threshold=self.score_threshold,
iou_threshold=self.iou_threshold,
max_detections=self.max_detections,
)
return nms
def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
for batch in dataloader:
img = batch["img"]
img = img / 255.0
yield [img]
tpc = mct.get_target_platform_capabilities(
fw_name="pytorch", target_platform_name="imx500", target_platform_version="v3"
)
@ -1095,69 +1097,51 @@ class Exporter:
)
resource_utilization = mct.core.ResourceUtilization(weights_memory=3146176 * 0.76)
if not self.args.gptq:
# Perform post training quantization
quant_model, _ = mct.ptq.pytorch_post_training_quantization(
in_module=self.model,
representative_data_gen=representative_dataset_gen,
target_resource_utilization=resource_utilization,
core_config=config,
target_platform_capabilities=tpc,
)
print("Quantized model is ready")
# Perform post training quantization
quant_model, _ = mct.ptq.pytorch_post_training_quantization(
in_module=self.model,
representative_data_gen=representative_dataset_gen,
target_resource_utilization=resource_utilization,
core_config=config,
target_platform_capabilities=tpc,
)
print("Quantized model is ready")
# Define PostProcess params
# score_threshold = 0.001
# iou_threshold = 0.7
# max_detections = 300
# Get working device
# device = get_working_device()
# quant_model_pp = PostProcessWrapper(
# model=quant_model,
# score_threshold=score_threshold,
# iou_threshold=iou_threshold,
# max_detections=max_detections,
# ).to(device=device)
f = Path(str(self.file).replace(self.file.suffix, "_ptq_mct_model.onnx")) # js dir
mct.exporter.pytorch_export_model(model=quant_model, save_model_path=f, repr_dataset=representative_dataset_gen)
# add metadata
import onnx
model_onnx = onnx.load(f) # load onnx model
for k, v in self.metadata.items():
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)
onnx.save(model_onnx, f)
gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False)
else:
# Perform Gradient-Based Post Training Quantization
gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False)
gptq_quant_model, _ = mct.gptq.pytorch_gradient_post_training_quantization(
model=self.model,
representative_data_gen=representative_dataset_gen,
target_resource_utilization=resource_utilization,
gptq_config=gptq_config,
core_config=config,
target_platform_capabilities=tpc,
)
# Perform Gradient-Based Post Training Quantization
print("Quantized-PTQ model is ready")
quant_model, _ = mct.gptq.pytorch_gradient_post_training_quantization(
model=self.model,
representative_data_gen=representative_dataset_gen,
target_resource_utilization=resource_utilization,
gptq_config=gptq_config,
core_config=config,
target_platform_capabilities=tpc,
)
# gptq_quant_model_pp = PostProcessWrapper(
# model=gptq_quant_model,
# score_threshold=score_threshold,
# iou_threshold=iou_threshold,
# max_detections=max_detections,
# ).to(device=device)
f = Path(str(self.file).replace(self.file.suffix, "_gptq_mct_model.onnx")) # js dir
print("Quantized-PTQ model is ready")
if self.args.nms:
# Define PostProcess params
score_threshold = 0.001
iou_threshold = 0.7
max_detections = 300
quant_model = PostProcessWrapper(
model=quant_model,
score_threshold=score_threshold,
iou_threshold=iou_threshold,
max_detections=max_detections,
).to(device=get_working_device())
f = Path(str(self.file).replace(self.file.suffix, "_mct_model.onnx")) # js dir
mct.exporter.pytorch_export_model(
model=gptq_quant_model, save_model_path=f, repr_dataset=representative_dataset_gen
model=quant_model, save_model_path=f, repr_dataset=representative_dataset_gen
)
model_onnx = onnx.load(f) # load onnx model
@ -1166,7 +1150,7 @@ class Exporter:
meta.key, meta.value = k, str(v)
onnx.save(model_onnx, f)
return f, model_onnx
return f, None
def _add_tflite_metadata(self, file):
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""

@ -177,6 +177,7 @@ class AutoBackend(nn.Module):
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
if mct:
LOGGER.info(f"Loading {w} for ONNX MCT quantization inference...")
import mct_quantizers as mctq
from sony_custom_layers.pytorch.object_detection import nms_ort # noqa

Loading…
Cancel
Save