|
|
|
@ -685,6 +685,57 @@ def profile(input, ops, n=10, device=None): |
|
|
|
|
return results |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def torchfx(): |
|
|
|
|
""" |
|
|
|
|
Decorator factory to handle torch.fx tracing in a function. |
|
|
|
|
|
|
|
|
|
This decorator checks if torch.fx tracing is active and modifies the function call accordingly. |
|
|
|
|
If torch.fx tracing is active, it avoids passing `*args` and `**kwargs` to the function, |
|
|
|
|
which are not supported by torch.fx. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Callable: A decorator function that wraps the original function to handle torch.fx tracing. |
|
|
|
|
""" |
|
|
|
|
try: |
|
|
|
|
from torch.fx._symbolic_trace import is_fx_tracing |
|
|
|
|
except ModuleNotFoundError: # 1.x torch versions do not have this module. |
|
|
|
|
is_fx_tracing = None |
|
|
|
|
|
|
|
|
|
def decorator(func): |
|
|
|
|
""" |
|
|
|
|
Decorator to handle torch.fx tracing within a function. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
func (Callable): The function to be decorated. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Callable: The wrapped function that handles torch.fx tracing. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def wrapper(self, x, *args, **kwargs): |
|
|
|
|
""" |
|
|
|
|
Wrapper function that checks for torch.fx tracing and calls the original function accordingly. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
self (object): The instance of the class. |
|
|
|
|
x (Any): The main input to the function. |
|
|
|
|
*args: Additional positional arguments. |
|
|
|
|
**kwargs: Additional keyword arguments. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Any: The result of the original function call. |
|
|
|
|
""" |
|
|
|
|
if is_fx_tracing is not None and is_fx_tracing(): |
|
|
|
|
result = func(self, x=x) # torch.fx does not work with `*args` and `**kwargs` function arguments |
|
|
|
|
else: |
|
|
|
|
result = func(self, x=x, *args, **kwargs) |
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
return decorator |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EarlyStopping: |
|
|
|
|
"""Early stopping class that stops training when a specified number of epochs have passed without improvement.""" |
|
|
|
|
|
|
|
|
|