|
|
@ -694,7 +694,7 @@ def torchfx(): |
|
|
|
which are not supported by torch.fx. |
|
|
|
which are not supported by torch.fx. |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
Returns: |
|
|
|
Callable: A decorator function that wraps the original function to handle torch.fx tracing. |
|
|
|
(Callable): A decorator function that wraps the original function to handle torch.fx tracing. |
|
|
|
""" |
|
|
|
""" |
|
|
|
try: |
|
|
|
try: |
|
|
|
from torch.fx._symbolic_trace import is_fx_tracing |
|
|
|
from torch.fx._symbolic_trace import is_fx_tracing |
|
|
@ -709,7 +709,7 @@ def torchfx(): |
|
|
|
func (Callable): The function to be decorated. |
|
|
|
func (Callable): The function to be decorated. |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
Returns: |
|
|
|
Callable: The wrapped function that handles torch.fx tracing. |
|
|
|
(Callable): The wrapped function that handles torch.fx tracing. |
|
|
|
""" |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def wrapper(self, x, *args, **kwargs): |
|
|
|
def wrapper(self, x, *args, **kwargs): |
|
|
@ -723,7 +723,7 @@ def torchfx(): |
|
|
|
**kwargs: Additional keyword arguments. |
|
|
|
**kwargs: Additional keyword arguments. |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
Returns: |
|
|
|
Any: The result of the original function call. |
|
|
|
(Any): The result of the original function call. |
|
|
|
""" |
|
|
|
""" |
|
|
|
if is_fx_tracing is not None and is_fx_tracing(): |
|
|
|
if is_fx_tracing is not None and is_fx_tracing(): |
|
|
|
result = func(self, x=x) # torch.fx does not work with `*args` and `**kwargs` function arguments |
|
|
|
result = func(self, x=x) # torch.fx does not work with `*args` and `**kwargs` function arguments |
|
|
|