Update torch_utils.py

pull/17463/head^2
Laughing-q 3 months ago
parent 1704b9a9f4
commit b339394537
  1. 6
      ultralytics/utils/torch_utils.py

@ -694,7 +694,7 @@ def torchfx():
which are not supported by torch.fx.
Returns:
Callable: A decorator function that wraps the original function to handle torch.fx tracing.
(Callable): A decorator function that wraps the original function to handle torch.fx tracing.
"""
try:
from torch.fx._symbolic_trace import is_fx_tracing
@ -709,7 +709,7 @@ def torchfx():
func (Callable): The function to be decorated.
Returns:
Callable: The wrapped function that handles torch.fx tracing.
(Callable): The wrapped function that handles torch.fx tracing.
"""
def wrapper(self, x, *args, **kwargs):
@ -723,7 +723,7 @@ def torchfx():
**kwargs: Additional keyword arguments.
Returns:
Any: The result of the original function call.
(Any): The result of the original function call.
"""
if is_fx_tracing is not None and is_fx_tracing():
result = func(self, x=x) # torch.fx does not work with `*args` and `**kwargs` function arguments

Loading…
Cancel
Save