[dynamo] Support nonstrict_trace on class method (#147571)

As title, also see
1. new test `test_nonstrict_trace_on_method` for example.
2. newly added comments for why we need special treatment on methods.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147571
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367, #146950
This commit is contained in:
Ryan Guo 2025-02-25 14:33:47 -08:00 committed by PyTorch MergeBot
parent 7e0ef2c844
commit 73e963459e
2 changed files with 44 additions and 0 deletions

View File

@ -490,6 +490,33 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nonstrict_trace_on_method(self):
class Num:
def __init__(self, n):
self.n = n
@torch._dynamo.nonstrict_trace
def trace_me(self, t):
torch._dynamo.graph_break()
return t + self.n
torch.utils._pytree.register_pytree_node(
Num,
lambda num: ((num.n,), ()),
lambda n, _: Num(n[0]),
)
def fn(x, n):
num = Num(n)
return num.trace_me(x)
x, n = torch.randn(10), 42
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, n)
res = opt_fn(x, n)
self.assertEqual(ref, res)
def test_nonstrict_trace_no_action_at_a_distance(self):
def trace_me(x):
torch._dynamo.graph_break()

View File

@ -879,6 +879,23 @@ class UserMethodVariable(UserFunctionVariable):
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# NOTE this is to handle methods annotated by `nonstrict_trace`. Usually
# a `nonstrict_trace`-ed function will be wrapped by
# `VariableTracker.build` and route to `TorchInGraphFunctionVariable`,
# but in the case of method, we manually wrap it with `UserMethodVariable`
# inside `UserDefinedObjectVariable.var_getattr`.
#
# We might be able to simplify this away by canonicalizing the
# function/method wrapping code paths.
from ..trace_rules import is_nonstrict_trace_callable
if is_nonstrict_trace_callable(self.fn):
call_args = [*self.self_args(), *args]
var = variables.TorchInGraphFunctionVariable(
self.fn, nonstrict_traceable=True
)
return var.call_function(tx, call_args, kwargs)
# For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution
# rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method
# since we ensure `forward` of allowed modules can be traced by AOT safely.