diff --git a/ultralytics/cfg/default.yaml b/ultralytics/cfg/default.yaml
index 205dd6559..16276f9e2 100644
--- a/ultralytics/cfg/default.yaml
+++ b/ultralytics/cfg/default.yaml
@@ -85,7 +85,7 @@ simplify: False # (bool) ONNX: simplify model using `onnxslim`
 opset: # (int, optional) ONNX: opset version
 workspace: 4 # (int) TensorRT: workspace size (GB)
 nms: False # (bool) CoreML: add NMS
-
+gptq: False
 # Hyperparameters ------------------------------------------------------------------------------------------------------
 lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
 lrf: 0.01 # (float) final learning rate (lr0 * lrf)
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index e7663a7b3..68747c447 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -1037,55 +1037,57 @@ class Exporter:
     def export_mct(self, prefix=colorstr("Sony MCT:")):
         # pip install --upgrade -force-reinstall git+https://github.com/ambitious-octopus/model_optimization.git@get-output-fix
         import model_compression_toolkit as mct
-        # from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
-
-        # pip install sony-custom-layers[torch]
-        # from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
-        #
-        # class PostProcessWrapper(torch.nn.Module):
-        #     def __init__(
-        #         self,
-        #         model: torch.nn.Module,
-        #         score_threshold: float = 0.001,
-        #         iou_threshold: float = 0.7,
-        #         max_detections: int = 300,
-        #     ):
-        #         """
-        #         Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
-        #
-        #         Args:
-        #             model (nn.Module): Model instance.
-        #             score_threshold (float): Score threshold for non-maximum suppression.
-        #             iou_threshold (float): Intersection over union threshold for non-maximum suppression.
-        #             max_detections (float): The number of detections to return.
-        #         """
-        #         super(PostProcessWrapper, self).__init__()
-        #         self.model = model
-        #         self.score_threshold = score_threshold
-        #         self.iou_threshold = iou_threshold
-        #         self.max_detections = max_detections
-        #
-        #     def forward(self, images):
-        #         # model inference
-        #         outputs = self.model(images)
-        #
-        #         boxes = outputs[0]
-        #         scores = outputs[1]
-        #         nms = multiclass_nms(
-        #             boxes=boxes,
-        #             scores=scores,
-        #             score_threshold=self.score_threshold,
-        #             iou_threshold=self.iou_threshold,
-        #             max_detections=self.max_detections,
-        #         )
-        #         return nms
+        from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device, set_working_device
+        from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
+        import onnx
+
+        set_working_device(str(self.device))
+
+        class PostProcessWrapper(torch.nn.Module):
+            def __init__(
+                self,
+                model: torch.nn.Module,
+                score_threshold: float = 0.001,
+                iou_threshold: float = 0.7,
+                max_detections: int = 300,
+            ):
+                """
+                Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
+
+                Args:
+                    model (nn.Module): Model instance.
+                    score_threshold (float): Score threshold for non-maximum suppression.
+                    iou_threshold (float): Intersection over union threshold for non-maximum suppression.
+                    max_detections (float): The number of detections to return.
+                """
+                super(PostProcessWrapper, self).__init__()
+                self.model = model
+                self.score_threshold = score_threshold
+                self.iou_threshold = iou_threshold
+                self.max_detections = max_detections
+
+            def forward(self, images):
+                # model inference
+                outputs = self.model(images)
+
+                boxes = outputs[0]
+                scores = outputs[1]
+                nms = multiclass_nms(
+                    boxes=boxes,
+                    scores=scores,
+                    score_threshold=self.score_threshold,
+                    iou_threshold=self.iou_threshold,
+                    max_detections=self.max_detections,
+                )
+                return nms

         def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
             for batch in dataloader:
                 img = batch["img"]
                 img = img / 255.0
                 yield [img]
-
+
+
         tpc = mct.get_target_platform_capabilities(
             fw_name="pytorch", target_platform_name="imx500", target_platform_version="v3"
         )
@@ -1095,69 +1097,51 @@ class Exporter:
         )

         resource_utilization = mct.core.ResourceUtilization(weights_memory=3146176 * 0.76)
+
+        if not self.args.gptq:
+            # Perform post training quantization
+            quant_model, _ = mct.ptq.pytorch_post_training_quantization(
+                in_module=self.model,
+                representative_data_gen=representative_dataset_gen,
+                target_resource_utilization=resource_utilization,
+                core_config=config,
+                target_platform_capabilities=tpc,
+            )
+            print("Quantized model is ready")
-        # Perform post training quantization
-        quant_model, _ = mct.ptq.pytorch_post_training_quantization(
-            in_module=self.model,
-            representative_data_gen=representative_dataset_gen,
-            target_resource_utilization=resource_utilization,
-            core_config=config,
-            target_platform_capabilities=tpc,
-        )
-        print("Quantized model is ready")
-
-        # Define PostProcess params
-        # score_threshold = 0.001
-        # iou_threshold = 0.7
-        # max_detections = 300
-
-        # Get working device
-        # device = get_working_device()
-
-        # quant_model_pp = PostProcessWrapper(
-        #     model=quant_model,
-        #     score_threshold=score_threshold,
-        #     iou_threshold=iou_threshold,
-        #     max_detections=max_detections,
-        # ).to(device=device)
-
-        f = Path(str(self.file).replace(self.file.suffix, "_ptq_mct_model.onnx"))  # js dir
-        mct.exporter.pytorch_export_model(model=quant_model, save_model_path=f, repr_dataset=representative_dataset_gen)
-
-        # add metadata
-        import onnx
-
-        model_onnx = onnx.load(f)  # load onnx model
-        for k, v in self.metadata.items():
-            meta = model_onnx.metadata_props.add()
-            meta.key, meta.value = k, str(v)
-
-        onnx.save(model_onnx, f)
-
-        gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False)
+        else:
-        # Perform Gradient-Based Post Training Quantization
+            gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False)
-        gptq_quant_model, _ = mct.gptq.pytorch_gradient_post_training_quantization(
-            model=self.model,
-            representative_data_gen=representative_dataset_gen,
-            target_resource_utilization=resource_utilization,
-            gptq_config=gptq_config,
-            core_config=config,
-            target_platform_capabilities=tpc,
-        )
+            # Perform Gradient-Based Post Training Quantization
-        print("Quantized-PTQ model is ready")
+            quant_model, _ = mct.gptq.pytorch_gradient_post_training_quantization(
+                model=self.model,
+                representative_data_gen=representative_dataset_gen,
+                target_resource_utilization=resource_utilization,
+                gptq_config=gptq_config,
+                core_config=config,
+                target_platform_capabilities=tpc,
+            )
-        # gptq_quant_model_pp = PostProcessWrapper(
-        #     model=gptq_quant_model,
-        #     score_threshold=score_threshold,
-        #     iou_threshold=iou_threshold,
-        #     max_detections=max_detections,
-        # ).to(device=device)
-        f = Path(str(self.file).replace(self.file.suffix, "_gptq_mct_model.onnx"))  # js dir
+            print("Quantized-PTQ model is ready")
+
+        if self.args.nms:
+            # Define PostProcess params
+            score_threshold = 0.001
+            iou_threshold = 0.7
+            max_detections = 300
+
+            quant_model = PostProcessWrapper(
+                model=quant_model,
+                score_threshold=score_threshold,
+                iou_threshold=iou_threshold,
+                max_detections=max_detections,
+            ).to(device=get_working_device())
+
+        f = Path(str(self.file).replace(self.file.suffix, "_mct_model.onnx"))  # js dir
         mct.exporter.pytorch_export_model(
-            model=gptq_quant_model, save_model_path=f, repr_dataset=representative_dataset_gen
+            model=quant_model, save_model_path=f, repr_dataset=representative_dataset_gen
         )

         model_onnx = onnx.load(f)  # load onnx model
@@ -1166,7 +1150,7 @@ class Exporter:
             meta.key, meta.value = k, str(v)

         onnx.save(model_onnx, f)
-        return f, model_onnx
+        return f, None

     def _add_tflite_metadata(self, file):
         """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index 20c2fec9d..795e76a83 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -177,6 +177,7 @@ class AutoBackend(nn.Module):
             providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]

             if mct:
+                LOGGER.info(f"Loading {w} for ONNX MCT quantization inference...")
                 import mct_quantizers as mctq
                 from sony_custom_layers.pytorch.object_detection import nms_ort  # noqa
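Usage sketch (reviewer aid, not part of the patch): a minimal example of how the new options above could be exercised from the Python API. Only the `gptq` key added to default.yaml, the existing `nms` export argument, and the `*_mct_model.onnx` output suffix are taken from this diff; the `format="mct"` string and the weights filename are illustrative assumptions, and the exact format name should be checked against the exporter's format table.

    from ultralytics import YOLO

    # Hypothetical driver for the export path touched by this patch.
    # Assumes the Sony MCT exporter is selected via format="mct" and that the new
    # `gptq` flag in default.yaml switches plain PTQ to gradient-based PTQ (GPTQ).
    model = YOLO("yolov8n.pt")  # illustrative weights file

    # Default path: post-training quantization, saved as *_mct_model.onnx
    model.export(format="mct", gptq=False, nms=False)

    # GPTQ path, with the PostProcessWrapper NMS baked into the exported graph
    model.export(format="mct", gptq=True, nms=True)

With this patch both branches now produce a single `quant_model` and share one export and metadata-writing path, so the output filename no longer distinguishes PTQ from GPTQ.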