[dynamo] Support nonstrict_trace on class method (#147571)
As the title says. Also see:
1. the new test `test_nonstrict_trace_on_method` for an example;
2. the newly added comments explaining why methods need special treatment.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147571
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367, #146950
This commit is contained in:
parent 7e0ef2c844
commit 73e963459e
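For readers new to the API: judging from the tests in this diff, `torch._dynamo.nonstrict_trace` marks a callable so that Dynamo traces it non-strictly rather than strictly stepping into its body; notably, a `graph_break()` inside it no longer defeats `fullgraph=True`. A minimal function-form sketch of that behavior (the diff below extends the same treatment to bound methods):

    import torch

    @torch._dynamo.nonstrict_trace
    def trace_me(t):
        # Inside a nonstrict_trace-ed callable, a graph break does not
        # abort fullgraph compilation (see the tests in the diff below).
        torch._dynamo.graph_break()
        return t + 1

    @torch.compile(fullgraph=True, backend="aot_eager")
    def fn(x):
        return trace_me(x)

    print(fn(torch.randn(4)))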
@@ -490,6 +490,33 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
         res = opt_fn(x)
         self.assertEqual(ref, res)
 
+    def test_nonstrict_trace_on_method(self):
+        class Num:
+            def __init__(self, n):
+                self.n = n
+
+            @torch._dynamo.nonstrict_trace
+            def trace_me(self, t):
+                torch._dynamo.graph_break()
+                return t + self.n
+
+        torch.utils._pytree.register_pytree_node(
+            Num,
+            lambda num: ((num.n,), ()),
+            lambda n, _: Num(n[0]),
+        )
+
+        def fn(x, n):
+            num = Num(n)
+            return num.trace_me(x)
+
+        x, n = torch.randn(10), 42
+        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
+
+        ref = fn(x, n)
+        res = opt_fn(x, n)
+        self.assertEqual(ref, res)
+
     def test_nonstrict_trace_no_action_at_a_distance(self):
         def trace_me(x):
             torch._dynamo.graph_break()
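An aside on the `register_pytree_node` call in the test above: registering `Num` tells PyTorch how to flatten an instance into leaves and rebuild it, which is presumably what lets the `nonstrict_trace` machinery carry a `Num` across the tracing boundary. A self-contained sketch of the round trip the registration enables, using the same flatten/unflatten lambdas as the test:

    import torch.utils._pytree as pytree

    class Num:
        def __init__(self, n):
            self.n = n

    # Same registration as in the test: flatten to a one-leaf tuple plus an
    # empty context, and rebuild an instance from the leaves.
    pytree.register_pytree_node(
        Num,
        lambda num: ((num.n,), ()),
        lambda leaves, _: Num(leaves[0]),
    )

    leaves, spec = pytree.tree_flatten(Num(42))
    print(leaves)                                 # [42]
    print(pytree.tree_unflatten(leaves, spec).n)  # 42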
@@ -879,6 +879,23 @@ class UserMethodVariable(UserFunctionVariable):
         args: "list[VariableTracker]",
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
+        # NOTE this is to handle methods annotated by `nonstrict_trace`.
+        # Usually a `nonstrict_trace`-ed function will be wrapped by
+        # `VariableTracker.build` and routed to `TorchInGraphFunctionVariable`,
+        # but in the case of a method, we manually wrap it with
+        # `UserMethodVariable` inside `UserDefinedObjectVariable.var_getattr`.
+        #
+        # We might be able to simplify this away by canonicalizing the
+        # function/method wrapping code paths.
+        from ..trace_rules import is_nonstrict_trace_callable
+
+        if is_nonstrict_trace_callable(self.fn):
+            call_args = [*self.self_args(), *args]
+            var = variables.TorchInGraphFunctionVariable(
+                self.fn, nonstrict_traceable=True
+            )
+            return var.call_function(tx, call_args, kwargs)
+
         # For nn.Module methods, redirect to NNModuleVariable.call_method for an
         # optimized path rather than simple inlining, e.g. putting a `call_method` op in
         # the FX graph for `forward`, since we ensure `forward` of allowed modules can be traced safely by AOT.
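To see the shape of the dispatch the second hunk adds, here is a purely illustrative sketch (hypothetical tag and helper names, not Dynamo's real internals): a decorator tags the underlying function object, and the method-dispatch path checks that tag via `__func__` before choosing how to handle the call; the receiver is prepended to the arguments, mirroring `call_args = [*self.self_args(), *args]` above.

    # Hypothetical sketch, not Dynamo's actual implementation.
    def nonstrict_trace(fn):
        fn._is_nonstrict_trace = True  # tag the raw function object
        return fn

    def is_nonstrict_trace_callable(fn):
        # For a bound method the tag lives on __func__, which is why the
        # method path above needs its own check before dispatching.
        return getattr(getattr(fn, "__func__", fn), "_is_nonstrict_trace", False)

    class Num:
        def __init__(self, n):
            self.n = n

        @nonstrict_trace
        def trace_me(self, t):
            return t + self.n

    num = Num(1)
    if is_nonstrict_trace_callable(num.trace_me):
        # Mirror the rerouting: call the underlying function with the
        # receiver prepended, as the real code does via self_args().
        fn = num.trace_me.__func__
        print(fn(*[num, 41]))  # 42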