diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index c9c9093722..bb6bf9f78b 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -694,7 +694,7 @@ def torchfx(): which are not supported by torch.fx. Returns: - Callable: A decorator function that wraps the original function to handle torch.fx tracing. + (Callable): A decorator function that wraps the original function to handle torch.fx tracing. """ try: from torch.fx._symbolic_trace import is_fx_tracing @@ -709,7 +709,7 @@ def torchfx(): func (Callable): The function to be decorated. Returns: - Callable: The wrapped function that handles torch.fx tracing. + (Callable): The wrapped function that handles torch.fx tracing. """ def wrapper(self, x, *args, **kwargs): @@ -723,7 +723,7 @@ def torchfx(): **kwargs: Additional keyword arguments. Returns: - Any: The result of the original function call. + (Any): The result of the original function call. """ if is_fx_tracing is not None and is_fx_tracing(): result = func(self, x=x) # torch.fx does not work with `*args` and `**kwargs` function arguments