update torch.fx

test-quan
Laughing-q 1 month ago
parent a93248e89c
commit 9bd4de7431
  1. 23
      ultralytics/nn/tasks.py
  2. 51
      ultralytics/utils/torch_utils.py

@ -82,6 +82,7 @@ from ultralytics.utils.torch_utils import (
model_info,
scale_img,
time_sync,
torchfx,
)
try:
@ -90,28 +91,6 @@ except ImportError:
thop = None
def torchfx():
try:
from torch.fx._symbolic_trace import is_fx_tracing
except ModuleNotFoundError: # 1.x torch versions does not have this module.
is_fx_tracing = None
def decorator(func):
"""Decorator to apply temporary rc parameters and backend to a function."""
def wrapper(self, x, *args, **kwargs):
"""Sets rc parameters and backend, calls the original function, and restores the settings."""
if is_fx_tracing is not None and is_fx_tracing():
result = func(self, x=x) # torch.fx does not work with `*args` and `**kwargs` function argument
else:
result = func(self, x=x, *args, **kwargs)
return result
return wrapper
return decorator
class BaseModel(nn.Module):
"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""

@ -685,6 +685,57 @@ def profile(input, ops, n=10, device=None):
return results
def torchfx():
"""
Decorator factory to handle torch.fx tracing in a function.
This decorator checks if torch.fx tracing is active and modifies the function call accordingly.
If torch.fx tracing is active, it avoids passing `*args` and `**kwargs` to the function,
which are not supported by torch.fx.
Returns:
Callable: A decorator function that wraps the original function to handle torch.fx tracing.
"""
try:
from torch.fx._symbolic_trace import is_fx_tracing
except ModuleNotFoundError: # 1.x torch versions do not have this module.
is_fx_tracing = None
def decorator(func):
"""
Decorator to handle torch.fx tracing within a function.
Args:
func (Callable): The function to be decorated.
Returns:
Callable: The wrapped function that handles torch.fx tracing.
"""
def wrapper(self, x, *args, **kwargs):
"""
Wrapper function that checks for torch.fx tracing and calls the original function accordingly.
Args:
self (object): The instance of the class.
x (Any): The main input to the function.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
Any: The result of the original function call.
"""
if is_fx_tracing is not None and is_fx_tracing():
result = func(self, x=x) # torch.fx does not work with `*args` and `**kwargs` function arguments
else:
result = func(self, x=x, *args, **kwargs)
return result
return wrapper
return decorator
class EarlyStopping:
"""Early stopping class that stops training when a specified number of epochs have passed without improvement."""

Loading…
Cancel
Save