diff --git a/.gitignore b/.gitignore
index 4e0f0845b2..0d4b744d3f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,6 +163,7 @@ weights/
 *_openvino_model/
 *_paddle_model/
 *_ncnn_model/
+*_imx_model/
 pnnx*
 
 # Autogenerated files for tests
diff --git a/docs/en/integrations/index.md b/docs/en/integrations/index.md
index 05af439936..8ed822bda7 100644
--- a/docs/en/integrations/index.md
+++ b/docs/en/integrations/index.md
@@ -61,6 +61,8 @@ Welcome to the Ultralytics Integrations page! This page provides an overview of
 
 - [Albumentations](albumentations.md): Enhance your Ultralytics models with powerful image augmentations to improve model robustness and generalization.
 
+- [Sony IMX500](sony-imx500.md): Optimize and deploy [Ultralytics YOLOv8](https://docs.ultralytics.com/models/yolov8/) models on Raspberry Pi AI Cameras with the IMX500 sensor for fast, low-power performance.
+
 ## Deployment Integrations
 
 - [CoreML](coreml.md): CoreML, developed by [Apple](https://www.apple.com/), is a framework designed for efficiently integrating machine learning models into applications across iOS, macOS, watchOS, and tvOS, using Apple's hardware for effective and secure [model deployment](https://www.ultralytics.com/glossary/model-deployment).
diff --git a/docs/en/integrations/sony-imx500.md b/docs/en/integrations/sony-imx500.md
index 43dbc133f8..335daf51fc 100644
--- a/docs/en/integrations/sony-imx500.md
+++ b/docs/en/integrations/sony-imx500.md
@@ -4,7 +4,7 @@ description: Learn to export Ultralytics YOLOv8 models to Sony's IMX500 format t
 keywords: Sony, IMX500, IMX 500, Atrios, MCT, model export, quantization, pruning, deep learning optimization, Raspberry Pi AI Camera, edge AI, PyTorch, IMX
 ---
 
-# IMX500 Export for Ultralytics YOLOv8
+# Sony IMX500 Export for Ultralytics YOLOv8
 
 This guide covers exporting and deploying Ultralytics YOLOv8 models to Raspberry Pi AI Cameras that feature the Sony IMX500 sensor.
diff --git a/docs/en/macros/export-table.md b/docs/en/macros/export-table.md
index b7134f42b8..ac9b352a26 100644
--- a/docs/en/macros/export-table.md
+++ b/docs/en/macros/export-table.md
@@ -14,3 +14,4 @@
 | [PaddlePaddle](../integrations/paddlepaddle.md) | `paddle` | `{{ model_name or "yolo11n" }}_paddle_model/` | ✅ | `imgsz`, `batch` |
 | [MNN](../integrations/mnn.md) | `mnn` | `{{ model_name or "yolo11n" }}.mnn` | ✅ | `imgsz`, `batch`, `int8`, `half` |
 | [NCNN](../integrations/ncnn.md) | `ncnn` | `{{ model_name or "yolo11n" }}_ncnn_model/` | ✅ | `imgsz`, `half`, `batch` |
+| [IMX500](../integrations/sony-imx500.md) | `imx` | `{{ model_name or "yolo11n" }}_imx_model/` | ✅ | `imgsz`, `int8` |
diff --git a/docs/en/reference/utils/torch_utils.md b/docs/en/reference/utils/torch_utils.md
index ac31ec2c33..8ec53d8269 100644
--- a/docs/en/reference/utils/torch_utils.md
+++ b/docs/en/reference/utils/torch_utils.md
@@ -19,6 +19,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere
 
 <br><br>
 
+## ::: ultralytics.utils.torch_utils.FXModel
+
+<br><br>
+
 ## ::: ultralytics.utils.torch_utils.torch_distributed_zero_first
 
 <br><br>
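The new export-table row above documents the `imx` format argument. A minimal usage sketch of what that row promises (assumes a Linux host with the IMX export dependencies installed; it mirrors the `test_export_imx` test added below):

```python
from ultralytics import YOLO

# Export a YOLOv8n detection model to the Sony IMX500 format.
# int8 quantization is implied; the exporter forces int8=True for IMX.
model = YOLO("yolov8n.pt")
model.export(format="imx", imgsz=640)  # creates yolov8n_imx_model/

# Run inference with the exported model directory.
imx_model = YOLO("yolov8n_imx_model")
results = imx_model("https://ultralytics.com/images/bus.jpg")
```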
diff --git a/docs/mkdocs_github_authors.yaml b/docs/mkdocs_github_authors.yaml
index 6d91127d59..49360cf687 100644
--- a/docs/mkdocs_github_authors.yaml
+++ b/docs/mkdocs_github_authors.yaml
@@ -109,6 +109,9 @@ chr043416@gmail.com:
 davis.justin@mssm.org:
   avatar: https://avatars.githubusercontent.com/u/23462437?v=4
   username: justincdavis
+francesco.mttl@gmail.com:
+  avatar: https://avatars.githubusercontent.com/u/3855193?v=4
+  username: ambitious-octopus
 glenn.jocher@ultralytics.com:
   avatar: https://avatars.githubusercontent.com/u/26833433?v=4
   username: glenn-jocher
diff --git a/mkdocs.yml b/mkdocs.yml
index 20d8ec3bf1..04d734430d 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -412,12 +412,13 @@ nav:
       - TF.js: integrations/tfjs.md
       - TFLite: integrations/tflite.md
       - TFLite Edge TPU: integrations/edge-tpu.md
      - TensorBoard: integrations/tensorboard.md
       - TensorRT: integrations/tensorrt.md
       - TorchScript: integrations/torchscript.md
       - VS Code: integrations/vscode.md
       - Weights & Biases: integrations/weights-biases.md
       - Albumentations: integrations/albumentations.md
+      - Sony IMX500: integrations/sony-imx500.md
   - HUB:
       - hub/index.md
       - Web:
@@ -559,7 +560,6 @@ nav:
              - utils: reference/nn/modules/utils.md
          - tasks: reference/nn/tasks.md
      - solutions:
-         - solutions: reference/solutions/solutions.md
          - ai_gym: reference/solutions/ai_gym.md
          - analytics: reference/solutions/analytics.md
          - distance_calculation: reference/solutions/distance_calculation.md
@@ -567,6 +567,7 @@ nav:
          - object_counter: reference/solutions/object_counter.md
          - parking_management: reference/solutions/parking_management.md
          - queue_management: reference/solutions/queue_management.md
+         - solutions: reference/solutions/solutions.md
          - speed_estimation: reference/solutions/speed_estimation.md
          - streamlit_inference: reference/solutions/streamlit_inference.md
      - trackers:
diff --git a/tests/test_exports.py b/tests/test_exports.py
index 5a54b1afa6..e540e7d757 100644
--- a/tests/test_exports.py
+++ b/tests/test_exports.py
@@ -205,3 +205,12 @@ def test_export_ncnn():
     """Test YOLO exports to NCNN format."""
     file = YOLO(MODEL).export(format="ncnn", imgsz=32)
     YOLO(file)(SOURCE, imgsz=32)  # exported model inference
+
+
+@pytest.mark.skipif(True, reason="Test disabled as Keras and TensorFlow versions conflict with TFLite export.")
+@pytest.mark.skipif(not LINUX or MACOS, reason="Skipping test on Windows and macOS")
+def test_export_imx():
+    """Test YOLOv8n exports to IMX format."""
+    model = YOLO("yolov8n.pt")
+    file = model.export(format="imx", imgsz=32)
+    YOLO(file)(SOURCE, imgsz=32)
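One subtlety in the platform guard above: because `not` binds tighter than `or`, the second `skipif` predicate reads `(not LINUX) or MACOS`, so the test runs only on Linux and the `MACOS` term is technically redundant. A quick truth-table check, with hypothetical lowercase flags standing in for the real `LINUX`/`MACOS` constants:

```python
# Truth-table sketch of the skip predicate `not LINUX or MACOS`.
for linux, macos in [(True, False), (False, True), (False, False)]:
    skipped = not linux or macos
    print(f"LINUX={linux!s:5} MACOS={macos!s:5} -> {'skip' if skipped else 'run'}")
# LINUX=True  MACOS=False -> run   (Linux)
# LINUX=False MACOS=True  -> skip  (macOS)
# LINUX=False MACOS=False -> skip  (Windows)
```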
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index f6b1d2e783..2ff53681a3 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.3.28"
+__version__ = "8.3.29"
 
 import os
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index 9fca6c28d9..c618e794b5 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -18,6 +18,7 @@ TensorFlow.js | `tfjs` | yolo11n_web_model/
 PaddlePaddle | `paddle` | yolo11n_paddle_model/
 MNN | `mnn` | yolo11n.mnn
 NCNN | `ncnn` | yolo11n_ncnn_model/
+IMX | `imx` | yolo11n_imx_model/
 
 Requirements:
     $ pip install "ultralytics[export]"
@@ -44,6 +45,7 @@ Inference:
         yolo11n_paddle_model  # PaddlePaddle
         yolo11n.mnn  # MNN
         yolo11n_ncnn_model  # NCNN
+        yolo11n_imx_model  # IMX
 
 TensorFlow.js:
     $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
@@ -94,7 +96,7 @@ from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requ
 from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
 from ultralytics.utils.files import file_size, spaces_in_path
 from ultralytics.utils.ops import Profile
-from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device, smart_inference_mode
+from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
 
 
 def export_formats():
@@ -114,6 +116,7 @@ def export_formats():
         ["PaddlePaddle", "paddle", "_paddle_model", True, True],
         ["MNN", "mnn", ".mnn", True, True],
         ["NCNN", "ncnn", "_ncnn_model", True, True],
+        ["IMX", "imx", "_imx_model", True, True],
     ]
     return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x)))
@@ -171,7 +174,6 @@ class Exporter:
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
         callbacks.add_integration_callbacks(self)
 
-    @smart_inference_mode()
     def __call__(self, model=None) -> str:
         """Returns list of exported files/dirs after running callbacks."""
         self.run_callbacks("on_export_start")
@@ -194,9 +196,22 @@ class Exporter:
         flags = [x == fmt for x in fmts]
         if sum(flags) != 1:
             raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
-        jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, mnn, ncnn = (
-            flags  # export booleans
-        )
+        (
+            jit,
+            onnx,
+            xml,
+            engine,
+            coreml,
+            saved_model,
+            pb,
+            tflite,
+            edgetpu,
+            tfjs,
+            paddle,
+            mnn,
+            ncnn,
+            imx,
+        ) = flags  # export booleans
         is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
 
         # Device
@@ -210,6 +225,9 @@ class Exporter:
         self.device = select_device("cpu" if self.args.device is None else self.args.device)
 
         # Checks
+        if imx and not self.args.int8:
+            LOGGER.warning("WARNING ⚠️ IMX only supports int8 export, setting int8=True.")
+            self.args.int8 = True
         if not hasattr(model, "names"):
             model.names = default_class_names()
         model.names = check_class_names(model.names)
@@ -249,6 +267,7 @@ class Exporter:
             )
         if mnn and (IS_RASPBERRYPI or IS_JETSON):
             raise SystemError("MNN export not supported on Raspberry Pi and NVIDIA Jetson")
+
         # Input
         im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
         file = Path(
@@ -264,6 +283,11 @@ class Exporter:
         model.eval()
         model.float()
         model = model.fuse()
+
+        if imx:
+            from ultralytics.utils.torch_utils import FXModel
+
+            model = FXModel(model)  # wrap for the torch.fx tracing required by MCT quantization
         for m in model.modules():
             if isinstance(m, (Detect, RTDETRDecoder)):  # includes all Detect subclasses like Segment, Pose, OBB
                 m.dynamic = self.args.dynamic
@@ -273,6 +297,15 @@ class Exporter:
             elif isinstance(m, C2f) and not is_tf_format:
                 # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
                 m.forward = m.forward_split
+            if isinstance(m, Detect) and imx:  # precompute static anchors/strides from imgsz for the frozen graph
+                from ultralytics.utils.tal import make_anchors
+
+                m.anchors, m.strides = (
+                    x.transpose(0, 1)
+                    for x in make_anchors(
+                        torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
+                    )
+                )
 
         y = None
         for _ in range(2):
@@ -347,6 +380,8 @@
             f[11], _ = self.export_mnn()
         if ncnn:  # NCNN
             f[12], _ = self.export_ncnn()
+        if imx:  # IMX
+            f[13], _ = self.export_imx()
 
         # Finish
         f = [str(x) for x in f if x]  # filter out '' and None
@@ -1068,6 +1103,137 @@
             yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
         return f, None
 
+    @try_export
+    def export_imx(self, prefix=colorstr("IMX:")):
+        """YOLO IMX export."""
+        gptq = False  # set True to use gradient-based post-training quantization instead of plain PTQ
+        assert LINUX, "export only supported on Linux. See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter"
+        if getattr(self.model, "end2end", False):
+            raise ValueError("IMX export is not supported for end2end models.")
+        if "C2f" not in self.model.__str__():
+            raise ValueError("IMX export is only supported for YOLOv8 detection models")
+        check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers==0.2.0", "tensorflow==2.12.0"))
+        check_requirements("imx500-converter[pt]==3.14.3")  # separate requirements for imx500-converter
+
+        import model_compression_toolkit as mct
+        import onnx
+        from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
+
+        try:
+            out = subprocess.run(
+                ["java", "--version"], check=True, capture_output=True
+            )  # Java 17 is required for imx500-converter
+            if "openjdk 17" not in str(out.stdout):
+                raise FileNotFoundError
+        except FileNotFoundError:
+            subprocess.run(["sudo", "apt", "install", "-y", "openjdk-17-jdk", "openjdk-17-jre"], check=True)
+
+        def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
+            for batch in dataloader:
+                img = batch["img"]
+                img = img / 255.0  # normalize images to [0, 1] for calibration
+                yield [img]
+
+        tpc = mct.get_target_platform_capabilities(
+            fw_name="pytorch", target_platform_name="imx500", target_platform_version="v1"
+        )
+
+        config = mct.core.CoreConfig(
+            mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
+            quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
+        )
+
+        resource_utilization = mct.core.ResourceUtilization(weights_memory=3146176 * 0.76)
+
+        quant_model = (
+            mct.gptq.pytorch_gradient_post_training_quantization(  # perform Gradient-Based Post Training Quantization
+                model=self.model,
+                representative_data_gen=representative_dataset_gen,
+                target_resource_utilization=resource_utilization,
+                gptq_config=mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False),
+                core_config=config,
+                target_platform_capabilities=tpc,
+            )[0]
+            if gptq
+            else mct.ptq.pytorch_post_training_quantization(  # perform post training quantization
+                in_module=self.model,
+                representative_data_gen=representative_dataset_gen,
+                target_resource_utilization=resource_utilization,
+                core_config=config,
+                target_platform_capabilities=tpc,
+            )[0]
+        )
+
+        class NMSWrapper(torch.nn.Module):
+            def __init__(
+                self,
+                model: torch.nn.Module,
+                score_threshold: float = 0.001,
+                iou_threshold: float = 0.7,
+                max_detections: int = 300,
+            ):
+                """
+                Wrap a PyTorch module with the multiclass_nms layer from sony_custom_layers.
+
+                Args:
+                    model (nn.Module): Model instance.
+                    score_threshold (float): Score threshold for non-maximum suppression.
+                    iou_threshold (float): Intersection over union threshold for non-maximum suppression.
+                    max_detections (int): The maximum number of detections to return.
+                """
+                super().__init__()
+                self.model = model
+                self.score_threshold = score_threshold
+                self.iou_threshold = iou_threshold
+                self.max_detections = max_detections
+
+            def forward(self, images):
+                # model inference
+                outputs = self.model(images)
+
+                boxes = outputs[0]
+                scores = outputs[1]
+                nms = multiclass_nms(
+                    boxes=boxes,
+                    scores=scores,
+                    score_threshold=self.score_threshold,
+                    iou_threshold=self.iou_threshold,
+                    max_detections=self.max_detections,
+                )
+                return nms
+
+        quant_model = NMSWrapper(
+            model=quant_model,
+            score_threshold=self.args.conf or 0.001,
+            iou_threshold=self.args.iou,
+            max_detections=self.args.max_det,
+        ).to(self.device)
+
+        f = Path(str(self.file).replace(self.file.suffix, "_imx_model"))
+        f.mkdir(exist_ok=True)
+        onnx_model = f / Path(str(self.file).replace(self.file.suffix, "_imx.onnx"))  # quantized ONNX model path
+        mct.exporter.pytorch_export_model(
+            model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
+        )
+
+        model_onnx = onnx.load(onnx_model)  # load onnx model
+        for k, v in self.metadata.items():
+            meta = model_onnx.metadata_props.add()
+            meta.key, meta.value = k, str(v)
+
+        onnx.save(model_onnx, onnx_model)
+
+        subprocess.run(
+            ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
+            check=True,
+        )
+
+        # labels.txt is needed for IMX models
+        with open(f / "labels.txt", "w") as file:
+            file.writelines([f"{name}\n" for _, name in self.model.names.items()])
+
+        return f, None
+
     def _add_tflite_metadata(self, file):
         """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
         import flatbuffers
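Two annotations on the exporter changes. First, `@smart_inference_mode()` is dropped from `Exporter.__call__`, presumably because MCT's gradient-based quantization needs autograd, which inference-mode tensors forbid. Second, the NMS parameters baked into the exported graph come from the standard export arguments, so they can be tuned at export time. A hedged sketch (assumes the usual Ultralytics argument plumbing, where extra `export()` kwargs override `self.args`; `conf` falls back to 0.001 when unset):

```python
from ultralytics import YOLO

# The exporter reads self.args.conf / self.args.iou / self.args.max_det when
# building NMSWrapper, so these choices are frozen into the IMX graph.
model = YOLO("yolov8n.pt")
model.export(format="imx", conf=0.25, iou=0.6, max_det=100)
# -> NMSWrapper(score_threshold=0.25, iou_threshold=0.6, max_detections=100)
```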
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index cef05a3571..60b9f6389a 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -123,6 +123,7 @@ class AutoBackend(nn.Module):
             paddle,
             mnn,
             ncnn,
+            imx,
             triton,
         ) = self._model_type(w)
         fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
@@ -182,8 +183,8 @@ class AutoBackend(nn.Module):
             check_requirements("opencv-python>=4.5.4")
             net = cv2.dnn.readNetFromONNX(w)
 
-        # ONNX Runtime
-        elif onnx:
+        # ONNX Runtime and IMX
+        elif onnx or imx:
             LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
             check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
             if IS_RASPBERRYPI or IS_JETSON:
@@ -199,7 +200,22 @@ class AutoBackend(nn.Module):
                 device = torch.device("cpu")
                 cuda = False
                 LOGGER.info(f"Preferring ONNX Runtime {providers[0]}")
-            session = onnxruntime.InferenceSession(w, providers=providers)
+            if onnx:
+                session = onnxruntime.InferenceSession(w, providers=providers)
+            else:
+                check_requirements(
+                    ["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"]
+                )
+                w = next(Path(w).glob("*.onnx"))  # locate the .onnx file inside the *_imx_model directory
+                LOGGER.info(f"Loading {w} for ONNX IMX inference...")
+                import mct_quantizers as mctq
+                from sony_custom_layers.pytorch.object_detection import nms_ort  # noqa (registers the custom NMS op)
+
+                session = onnxruntime.InferenceSession(
+                    w, mctq.get_ort_session_options(), providers=["CPUExecutionProvider"]
+                )
+                task = "detect"  # IMX models are currently detection-only
+                output_names = [x.name for x in session.get_outputs()]
             metadata = session.get_modelmeta().custom_metadata_map
             dynamic = isinstance(session.get_outputs()[0].shape[0], str)
@@ -520,7 +536,7 @@ class AutoBackend(nn.Module):
             y = self.net.forward()
 
         # ONNX Runtime
-        elif self.onnx:
+        elif self.onnx or self.imx:
             if self.dynamic:
                 im = im.cpu().numpy()  # torch to numpy
                 y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
@@ -537,6 +553,9 @@ class AutoBackend(nn.Module):
                 )
                 self.session.run_with_iobinding(self.io)
                 y = self.bindings
+            if self.imx:
+                # boxes, conf, cls
+                y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)
 
         # OpenVINO
         elif self.xml:
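The concatenation at the end of the IMX branch repacks the three NMS outputs into the single (batch, max_det, 6) array the rest of AutoBackend expects. A shape-only sketch (dimensions assumed from the exporter's `max_detections=300` default):

```python
import numpy as np

batch, max_det = 1, 300
boxes = np.zeros((batch, max_det, 4), dtype=np.float32)  # xyxy from the NMS layer
scores = np.zeros((batch, max_det), dtype=np.float32)  # confidence per detection
classes = np.zeros((batch, max_det), dtype=np.float32)  # class index per detection
y = np.concatenate([boxes, scores[:, :, None], classes[:, :, None]], axis=-1)
print(y.shape)  # (1, 300, 6) -> [x1, y1, x2, y2, conf, cls] per row
```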
+ """ + super().__init__() + self.model = model + self.score_threshold = score_threshold + self.iou_threshold = iou_threshold + self.max_detections = max_detections + + def forward(self, images): + # model inference + outputs = self.model(images) + + boxes = outputs[0] + scores = outputs[1] + nms = multiclass_nms( + boxes=boxes, + scores=scores, + score_threshold=self.score_threshold, + iou_threshold=self.iou_threshold, + max_detections=self.max_detections, + ) + return nms + + quant_model = NMSWrapper( + model=quant_model, + score_threshold=self.args.conf or 0.001, + iou_threshold=self.args.iou, + max_detections=self.args.max_det, + ).to(self.device) + + f = Path(str(self.file).replace(self.file.suffix, "_imx_model")) + f.mkdir(exist_ok=True) + onnx_model = f / Path(str(self.file).replace(self.file.suffix, "_imx.onnx")) # js dir + mct.exporter.pytorch_export_model( + model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen + ) + + model_onnx = onnx.load(onnx_model) # load onnx model + for k, v in self.metadata.items(): + meta = model_onnx.metadata_props.add() + meta.key, meta.value = k, str(v) + + onnx.save(model_onnx, onnx_model) + + subprocess.run( + ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"], + check=True, + ) + + # Needed for imx models. + with open(f / "labels.txt", "w") as file: + file.writelines([f"{name}\n" for _, name in self.model.names.items()]) + + return f, None + def _add_tflite_metadata(self, file): """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata.""" import flatbuffers diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index cef05a3571..60b9f6389a 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -123,6 +123,7 @@ class AutoBackend(nn.Module): paddle, mnn, ncnn, + imx, triton, ) = self._model_type(w) fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16 @@ -182,8 +183,8 @@ class AutoBackend(nn.Module): check_requirements("opencv-python>=4.5.4") net = cv2.dnn.readNetFromONNX(w) - # ONNX Runtime - elif onnx: + # ONNX Runtime and IMX + elif onnx or imx: LOGGER.info(f"Loading {w} for ONNX Runtime inference...") check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) if IS_RASPBERRYPI or IS_JETSON: @@ -199,7 +200,22 @@ class AutoBackend(nn.Module): device = torch.device("cpu") cuda = False LOGGER.info(f"Preferring ONNX Runtime {providers[0]}") - session = onnxruntime.InferenceSession(w, providers=providers) + if onnx: + session = onnxruntime.InferenceSession(w, providers=providers) + else: + check_requirements( + ["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"] + ) + w = next(Path(w).glob("*.onnx")) + LOGGER.info(f"Loading {w} for ONNX IMX inference...") + import mct_quantizers as mctq + from sony_custom_layers.pytorch.object_detection import nms_ort # noqa + + session = onnxruntime.InferenceSession( + w, mctq.get_ort_session_options(), providers=["CPUExecutionProvider"] + ) + task = "detect" + output_names = [x.name for x in session.get_outputs()] metadata = session.get_modelmeta().custom_metadata_map dynamic = isinstance(session.get_outputs()[0].shape[0], str) @@ -520,7 +536,7 @@ class AutoBackend(nn.Module): y = self.net.forward() # ONNX Runtime - elif self.onnx: + elif self.onnx or self.imx: if self.dynamic: im = im.cpu().numpy() # torch to numpy y = self.session.run(self.output_names, 
diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py
index 13d940780f..24e8ea9a1d 100644
--- a/ultralytics/utils/benchmarks.py
+++ b/ultralytics/utils/benchmarks.py
@@ -118,6 +118,11 @@
                 assert not IS_JETSON, "MNN export not supported on NVIDIA Jetson"
             if i == 13:  # NCNN
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
+            if i == 14:  # IMX
+                assert not is_end2end, "End-to-end models not supported by IMX yet"
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
+                assert model.task == "detect", "IMX only supported for detection task"
+                assert "C2f" in model.__str__(), "IMX only supported for YOLOv8"
             if "cpu" in device.type:
                 assert cpu, "inference not supported on CPU"
             if "cuda" in device.type:
diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py
index 74604eda23..9fb5020923 100644
--- a/ultralytics/utils/tal.py
+++ b/ultralytics/utils/tal.py
@@ -306,7 +306,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
     assert feats is not None
     dtype, device = feats[0].dtype, feats[0].device
     for i, stride in enumerate(strides):
-        _, _, h, w = feats[i].shape
+        h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
         sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
         sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
         sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
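With this change `make_anchors` accepts either a list of feature maps (the training/inference path) or a (num_levels, 2) tensor of precomputed (h, w) grid sizes, which is what the exporter builds from `imgsz` for IMX. A minimal sketch of the new tensor path, assuming a 640x640 input and strides 8/16/32:

```python
import torch
from ultralytics.utils.tal import make_anchors

imgsz = (640, 640)
strides = torch.tensor([8.0, 16.0, 32.0])

# Same expression the exporter uses: one (h, w) row per detection level
feats = torch.cat([s / strides.unsqueeze(-1) for s in imgsz], dim=1)
print(feats)  # tensor([[80., 80.], [40., 40.], [20., 20.]])

anchor_points, stride_tensor = make_anchors(feats, strides, 0.5)
print(anchor_points.shape, stride_tensor.shape)  # torch.Size([8400, 2]) torch.Size([8400, 1])
```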
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 0dbc728e23..966e980f1b 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -729,3 +729,48 @@ class EarlyStopping:
                 f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
             )
         return stop
+
+
+class FXModel(nn.Module):
+    """
+    A custom model class for torch.fx compatibility.
+
+    This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
+    manipulation. It copies attributes from an existing model and explicitly sets the `model` attribute to ensure
+    proper copying.
+
+    Args:
+        model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
+    """
+
+    def __init__(self, model):
+        """
+        Initialize the FXModel.
+
+        Args:
+            model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
+        """
+        super().__init__()
+        copy_attr(self, model)
+        # Explicitly set `model` since `copy_attr` skips underscore-prefixed attributes and `nn.Module`
+        # stores submodules in `_modules`.
+        self.model = model.model
+
+    def forward(self, x):
+        """
+        Forward pass through the model.
+
+        This method performs the forward pass through the model, handling the dependencies between layers and saving
+        intermediate outputs.
+
+        Args:
+            x (torch.Tensor): The input tensor to the model.
+
+        Returns:
+            (torch.Tensor): The output tensor from the model.
+        """
+        y = []  # outputs
+        for m in self.model:
+            if m.f != -1:  # if not from previous layer
+                # from earlier layers
+                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
+            x = m(x)  # run
+            y.append(x)  # save output
+        return x
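To close the loop, a hedged smoke test of the wrapper (assumes a local `yolov8n.pt`; the flattened `forward` mirrors the base model's layer-by-layer predict loop so that MCT can symbolically trace the graph):

```python
import torch
from ultralytics import YOLO
from ultralytics.utils.torch_utils import FXModel

model = YOLO("yolov8n.pt").model.eval()  # underlying DetectionModel
fx_model = FXModel(model)
with torch.no_grad():
    y = fx_model(torch.zeros(1, 3, 640, 640))  # same outputs as model(x), fx-friendly structure
```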