[dynamo] Initial support for nonstrict_trace (#146367)

## Context
> **Note:** `mark_traceable` was renamed to `nonstrict_trace` after an
> offline discussion. The reasons are (1) it aligns with `torch.export`'s
> non-strict notion, and (2) it conveys the intended behavior more clearly.

1. [Overall Design](https://docs.google.com/document/d/1O-dR2ZQaJQVt_v67AVcDCw2yJLtqgkZFwoXK0buEWRg/edit?tab=t.0)
2. [Dynamo graph representation with `torch._higher_order_ops.flat_apply`](https://docs.google.com/document/d/1YHl5nPTJvYeCPE5TO9uA18DPWNgUYGE4gCn6bFvXcBM/edit?tab=t.0#heading=h.xtw3hhbro4gn)

## Summary
This patch adds a `torch._dynamo.nonstrict_trace` decorator, which is
currently an enhanced version of `torch._dynamo.allow_in_graph` (see the
docstring for their differences). Specifically, this patch focuses on the
UI and on prototyping/plumbing the core functionality.

The main enhancement is support for more input types, and the
implementation challenge lies in reconstructing the input objects from
Dynamo `VariableTracker`s (while accounting for buffered side effects and
guards). This patch takes a middle ground (a simple implementation at the
cost of a bit of user labor), by
1. asking the user to provide a pytree registration for non-proxy-able
   input types,
2. letting Dynamo trace through `pytree_flatten` (which accounts for
   buffered side effects and guards automatically),
3. and passing the `TreeSpec` into `torch._higher_order_ops.flat_apply`
   (which unflattens the inputs and invokes the underlying function) as a
   constant graph attribute; a minimal usage sketch follows this list.
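
For illustration, here is a minimal usage sketch mirroring the tests added in this patch (the `Point` class, its pytree registration, and the `aot_eager` backend are lifted from those tests; this is one example, not the only supported pattern):

```python
import torch
import torch._dynamo
import torch.utils._pytree as pytree


class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y


# 1. Register the non-proxy-able input type with pytree.
pytree.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), ()),
    lambda xy, _: Point(xy[0], xy[1]),
)


# The decorated function is not traced by Dynamo, so the graph break inside
# it is irrelevant; the call shows up as `flat_apply` in the captured graph.
@torch._dynamo.nonstrict_trace
def trace_me(p):
    torch._dynamo.graph_break()
    return p.x * p.y


@torch.compile(fullgraph=True, backend="aot_eager")
def fn(p):
    return trace_me(p) + 1


out = fn(Point(torch.ones(10), torch.ones(1)))
```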

## Next Steps
In subsequent patches, we will try to support the following:
- annotating class methods
- reads of global tensors
- inputs that contain `pytree.register_constant`-ed instances
- functions as inputs
- more output types (e.g., any pytree-registered type)
- `torch.nn.Module` instances as inputs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146367
Approved by: https://github.com/zou3519
ghstack dependencies: #146714
Authored by Ryan Guo on 2025-02-25 14:33:44 -08:00; committed by PyTorch MergeBot.
Parent: bab84f0bd9, commit: f46f0e465c
12 changed files with 777 additions and 32 deletions


@@ -211,6 +211,439 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def fn1(x):
return x.cos()
def test_nonstrict_trace_tensor_args(self):
@torch._dynamo.nonstrict_trace
def trace_me(x, y, z):
torch._dynamo.graph_break()
return x * y + z
def fn(x, y):
t0 = x + 1
t1 = trace_me(x, y, t0)
t2 = t1 + y
return t0 * t2
x, y = torch.randn(10), torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, y)
res = opt_fn(x, y)
self.assertEqual(ref, res)
def test_nonstrict_trace_pre_existing_dict(self):
@torch._dynamo.nonstrict_trace
def trace_me(x, d):
torch._dynamo.graph_break()
return x * d["a"]
def fn(x, d):
t0 = trace_me(x, d)
return t0 + 1
x = torch.randn(10)
d = {"a": 2}
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, d)
res = opt_fn(x, d)
self.assertEqual(ref, res)
def test_nonstrict_trace_newly_constructed_dict_with_side_effects(self):
@torch._dynamo.nonstrict_trace
def trace_me(x, d):
torch._dynamo.graph_break()
return x * d["a"]
def fn(x):
d = {}
d["a"] = 2
t0 = trace_me(x, d)
return t0 + 1
x = torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nonstrict_trace_pre_existing_dict_with_side_effects(self):
@torch._dynamo.nonstrict_trace
def trace_me(x, d):
torch._dynamo.graph_break()
return x * d["a"]
def fn(x, d):
d["a"] = x + 1
t0 = trace_me(x, d)
return t0 + 2
x = torch.randn(10)
d0 = {"a": 0}
d1 = dict(d0)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, d0)
res = opt_fn(x, d1)
self.assertEqual(ref, res)
self.assertEqual(d0, d1)
def test_nonstrict_trace_pre_existing_custom_class(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
torch.utils._pytree.register_pytree_node(
Point,
lambda p: ((p.x, p.y), ()),
lambda xy, _: Point(xy[0], xy[1]),
)
@torch._dynamo.nonstrict_trace
def trace_me(p):
torch._dynamo.graph_break()
return p.x * p.y
def fn(p):
res = trace_me(p)
return res, p.x, p.y
p = Point(torch.ones(10), torch.ones(1))
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(p)
res = opt_fn(p)
self.assertEqual(ref, res)
def test_nonstrict_trace_pre_existing_custom_class_with_side_effects(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
torch.utils._pytree.register_pytree_node(
Point,
lambda p: ((p.x, p.y), ()),
lambda xy, _: Point(xy[0], xy[1]),
)
@torch._dynamo.nonstrict_trace
def trace_me(p):
torch._dynamo.graph_break()
return p.x * p.y
def fn(p):
p.x = p.x + 1
p.y = p.y + 2
res = trace_me(p)
return res, p.x, p.y
p1 = Point(torch.ones(10), torch.ones(1))
p2 = Point(torch.ones(10), torch.ones(1))
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(p1)
res = opt_fn(p2)
self.assertEqual(ref, res)
self.assertEqual(p1.x, p2.x)
self.assertEqual(p1.y, p2.y)
def test_nonstrict_trace_newly_constructed_custom_class_with_side_effects(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
torch.utils._pytree.register_pytree_node(
Point,
lambda p: ((p.x, p.y), ()),
lambda xy, _: Point(xy[0], xy[1]),
)
@torch._dynamo.nonstrict_trace
def trace_me(p):
torch._dynamo.graph_break()
return p.x * p.y
def fn(x, y):
p = Point(x, y)
p.x = p.x + 1
p.y = p.y + 2
res = trace_me(p)
return res, p.x, p.y
x, y = torch.ones(10), torch.ones(1)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, y)
res = opt_fn(x, y)
self.assertEqual(ref, res)
def test_nonstrict_trace_nested_custom_class(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
class PointTensor:
p: Point
t: torch.Tensor
def __init__(self, p, t):
self.p = p
self.t = t
torch.utils._pytree.register_pytree_node(
PointTensor,
lambda pt: ((pt.p, pt.t), ()),
lambda pt, _: PointTensor(pt[0], pt[1]),
)
torch.utils._pytree.register_pytree_node(
Point,
lambda p: ((p.x, p.y), ()),
lambda xy, _: Point(xy[0], xy[1]),
)
def trace_point(p):
torch._dynamo.graph_break()
return p.x * p.y
@torch._dynamo.nonstrict_trace
def trace_point_tensor(pt):
torch._dynamo.graph_break()
return pt.t + trace_point(pt.p)
def fn(x, y):
p = Point(x, y)
t = x + y
pt = PointTensor(p, t)
res = trace_point_tensor(pt)
return res
x, y = torch.ones(10), torch.ones(1)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, y)
res = opt_fn(x, y)
self.assertEqual(ref, res)
def test_nonstrict_trace_tuple_and_sym_int_output(self):
@torch._dynamo.nonstrict_trace
def trace_me(x):
torch._dynamo.graph_break()
return x + 1, x.size(0)
def fn(x):
t0, n = trace_me(x)
return t0 * n
x = torch.randn(10)
opt_fn = torch.compile(fn, dynamic=True, fullgraph=True, backend="aot_eager")
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nonstrict_trace_inside_compiled_function(self):
def trace_me(x):
torch._dynamo.graph_break()
return x + 42
def fn(x):
res = torch._dynamo.nonstrict_trace(trace_me)(x)
return res + 1
x = torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nonstrict_trace_inside_compiled_function_kwarg(self):
def trace_me(x):
torch._dynamo.graph_break()
return x + 42
def fn(x):
res = torch._dynamo.nonstrict_trace(traceable_fn=trace_me)(x)
return res + 1
x = torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nonstrict_trace_no_action_at_a_distance(self):
def trace_me(x):
torch._dynamo.graph_break()
return x + 42
# No effect on traceability of `trace_me`
torch._dynamo.nonstrict_trace(trace_me)
def fn(x):
res = trace_me(x)
return res + 1
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
opt_fn = torch.compile(fn, backend=cnts)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
# There should be 1 graph break
self.assertEqual(cnts.frame_count, 2)
def test_nonstrict_trace_inside_compiled_function_error(self):
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x, y):
def trace_me(x, y):
torch._dynamo.graph_break()
return x * y
res = torch._dynamo.nonstrict_trace(trace_me)(x, y)
return res + 1
try:
fn(torch.ones(10), torch.ones(1))
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
msg = """
Applying `nonstrict_trace` to function <trace_me>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region.
""" # NOQA: B950
self.assertIn(msg, str(e))
def test_nonstrict_trace_custom_class_error(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
@torch._dynamo.nonstrict_trace
def trace_me(p):
torch._dynamo.graph_break()
return p.x * p.y
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(p):
res = trace_me(p)
return res + 1
try:
p = Point(torch.ones(10), torch.ones(1))
fn(p)
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
msg = """
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
""" # NOQA: B950
self.assertIn(msg, str(e))
def test_nonstrict_trace_nested_custom_class_error(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
class PointTensor:
p: Point
t: torch.Tensor
def __init__(self, p, t):
self.p = p
self.t = t
torch.utils._pytree.register_pytree_node(
PointTensor,
lambda pt: ((pt.p, pt.t), ()),
lambda pt, _: PointTensor(pt[0], pt[1]),
)
def trace_point(p):
torch._dynamo.graph_break()
return p.x * p.y
@torch._dynamo.nonstrict_trace
def trace_point_tensor(pt):
torch._dynamo.graph_break()
return pt.t + trace_point(pt.p)
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x, y):
p = Point(x, y)
t = x + y
pt = PointTensor(p, t)
res = trace_point_tensor(pt)
return res
try:
fn(torch.ones(10), torch.ones(1))
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
msg = """
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_nested_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
""" # NOQA: B950
self.assertIn(msg, str(e))
def test_nonstrict_trace_pytree_register_constant_error(self):
class Point:
x: int
y: int
def __init__(self, x, y):
self.x = x
self.y = y
torch.utils._pytree.register_constant(Point)
@torch._dynamo.nonstrict_trace
def trace_me(x, p):
torch._dynamo.graph_break()
return x * p.x + p.y
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x, p):
res = trace_me(x, p)
return res + 1
try:
p = Point(3, 4)
fn(torch.ones(10), p)
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
msg = """
This error is most likely due to a call to `nonstrict_trace`-ed function, where one of the argument contains object of a type that has been (or needs to be) `torch.utils._pytree.register_constant`-ed. We currently don't support that.
""" # NOQA: B950
self.assertIn(msg, str(e))
def test_graph_break(self):
cnts = torch._dynamo.testing.CompileCounter()


@@ -4,6 +4,7 @@ from dataclasses import dataclass
import torch
import torch._dynamo.test_case
import torch.utils._pytree as pytree
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
from torch._higher_order_ops.flat_apply import (
    flat_apply,
    func_to_graphable,
@@ -83,6 +84,72 @@ class FlatApplyTests(torch._dynamo.test_case.TestCase):
result = flat_apply(func_spec, in_spec, *flat_args)
self.assertEqual(result, f(*args, **kwargs))
def test_nonstrict_trace_dynamo_graph(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
class PointTensor:
p: Point
t: torch.Tensor
def __init__(self, p, t):
self.p = p
self.t = t
torch.utils._pytree.register_pytree_node(
PointTensor,
lambda pt: ((pt.p, pt.t), ()),
lambda pt, _: PointTensor(pt[0], pt[1]),
)
torch.utils._pytree.register_pytree_node(
Point,
lambda p: ((p.x, p.y), ()),
lambda xy, _: Point(xy[0], xy[1]),
)
def trace_point(p):
torch._dynamo.graph_break()
return p.x * p.y
@torch._dynamo.nonstrict_trace
def trace_point_tensor(pt):
torch._dynamo.graph_break()
return pt.t + trace_point(pt.p)
backend = EagerAndRecordGraphs()
@torch.compile(fullgraph=True, backend=backend)
def fn(x, y):
p = Point(x, y)
t = x + y
pt = PointTensor(p, t)
res = trace_point_tensor(pt)
return res
fn(torch.randn(10), torch.randn(10))
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"):
l_x_ = L_x_
l_y_ = L_y_
t: "f32[10]" = l_x_ + l_y_
trace_point_tensor_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_spec
trace_point_tensor_input_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_input_spec
res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
return (res,)
""", # NOQA: B950
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests


@@ -26,6 +26,7 @@ from .decorators import (
    mark_static,
    mark_static_address,
    maybe_mark_dynamic,
    nonstrict_trace,
    run,
    set_stance,
    substitute_in_graph,
@@ -63,6 +64,7 @@ __all__ = [
    "maybe_mark_dynamic",
    "mark_static",
    "mark_static_address",
    "nonstrict_trace",
    "optimize",
    "optimize_assert",
    "export",


@@ -160,6 +160,39 @@ def allow_in_graph(fn):
    return fn


def nonstrict_trace(traceable_fn):
    # Like `allow_in_graph`, but with the following enhancements/differences:
    #
    # 1. Supports user-defined classes as inputs, as long as the class has
    #    been registered with pytree.
    # 2. Reads of global/captured tensors force the underlying graph to treat
    #    those tensors as constants, and we _assume_ they will not be updated.
    #    This is similar to FX tracing.
    # 3. In the resulting Dynamo graph, the call to a `nonstrict_trace`-ed
    #    function will be represented as a call to
    #    `torch._higher_order_ops.flat_apply`, which takes in the
    #    `nonstrict_trace`-ed function and pytree-flattened inputs.
    # 4. Only the returned function is traceable, and the original function
    #    will not be. Moreover, `nonstrict_trace` can be used inside a
    #    `torch.compile` region.
    #
    # NOTE: like `allow_in_graph`, aliasing information is neither preserved
    # between inputs themselves, nor between inputs and outputs.
    assert callable(traceable_fn), "nonstrict_trace expects a callable"

    @functools.wraps(traceable_fn)
    def wrapped(*args, **kwargs):
        return traceable_fn(*args, **kwargs)

    # This line allows us to reuse much of the `allow_in_graph` impl.
    trace_rules._allowed_callable_ids.add(id(wrapped))
    # This line allows us to diverge the impl from `allow_in_graph`.
    trace_rules._nonstrict_trace_callable_ids.add(id(wrapped))
    return wrapped


def _disallow_in_graph_helper(throw_if_not_allowed):
    def inner(fn):
        if isinstance(fn, (list, tuple)):
@@ -176,6 +209,7 @@ def _disallow_in_graph_helper(throw_if_not_allowed):
                "Allowed callables means callables that TorchDynamo puts as-is in the extracted graph."
            )
        trace_rules._allowed_callable_ids.remove(id(fn))
        trace_rules._nonstrict_trace_callable_ids.remove(id(fn))
        trace_rules._disallowed_callable_ids.add(id(fn))
        return fn
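
To make point 4 in the comment above concrete, here is a short sketch based on `test_nonstrict_trace_no_action_at_a_distance` and `test_nonstrict_trace_inside_compiled_function` from this patch (the helper name and the `aot_eager` backend are illustrative):

```python
import torch
import torch._dynamo


def helper(x):
    return x + 42


# Calling the decorator without keeping the returned wrapper has no effect on
# `helper` itself; only the returned function is marked nonstrict-traceable.
torch._dynamo.nonstrict_trace(helper)


@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x):
    # Applying `nonstrict_trace` inside the compiled region is allowed, as
    # long as the target function is defined outside of it.
    return torch._dynamo.nonstrict_trace(helper)(x) + 1


fn(torch.randn(10))
```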


@@ -115,6 +115,7 @@ from .utils import (
    get_instruction_source_311,
    get_locals_to_steal,
    get_static_address_type,
    get_unique_name_wrt,
    graph_break_reasons,
    increment_op_count,
    lazy_format_graph_code,
@@ -753,6 +754,17 @@ class OutputGraph:
        return name

    def register_static_attr_and_return_proxy(
        self, attr_prefix: str, attr_value: Any
    ) -> fx.Proxy:
        attr_name = get_unique_name_wrt(attr_prefix, self.nn_modules)
        # TODO `nn_modules` has been historically overloaded to store a lot more
        # than just nn module objects, fix that.
        self.nn_modules[attr_name] = attr_value
        proxy = self.create_proxy("get_attr", attr_name, (), {})
        set_example_value(proxy.node, attr_value)
        return proxy

    def register_attr_or_module(
        self,
        target: Union[torch.nn.Module, torch.Tensor, Any],
@@ -864,36 +876,30 @@ class OutputGraph:
return wrap_name(k)
name = OutputGraph.module_key_name(*names)
name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
self.nn_modules[name] = target
if isinstance(target, torch.nn.Module):
base = name
for i in itertools.count():
if name not in self.nn_modules and name not in self.global_scope:
self.nn_modules[name] = target
if isinstance(target, torch.nn.Module):
def register_leaf_name(leaf_name):
assert self.param_name_to_source is not None
new_source = ParamBufferSource(source, leaf_name)
new_name = f"{name}.{leaf_name}"
self.param_name_to_source[new_name] = new_source
if isinstance(source, LocalSource):
self.dynamo_flat_name_to_original_fqn[
OutputGraph.module_key_name(new_source.name())
] = leaf_name
def register_leaf_name(leaf_name):
assert self.param_name_to_source is not None
new_source = ParamBufferSource(source, leaf_name)
new_name = f"{name}.{leaf_name}"
self.param_name_to_source[new_name] = new_source
if isinstance(source, LocalSource):
self.dynamo_flat_name_to_original_fqn[
OutputGraph.module_key_name(new_source.name())
] = leaf_name
# annoying, but there are cases when we do not have parameters
# see test_nn_moduledict_contains
if hasattr(target, "_parameters"):
for leaf_name, _ in target.named_parameters():
register_leaf_name(leaf_name)
if hasattr(target, "_buffers"):
for leaf_name, _ in target.named_buffers():
register_leaf_name(leaf_name)
# annoying, but there are cases when we do not have parameters
# see test_nn_moduledict_contains
if hasattr(target, "_parameters"):
for leaf_name, _ in target.named_parameters():
register_leaf_name(leaf_name)
if hasattr(target, "_buffers"):
for leaf_name, _ in target.named_buffers():
register_leaf_name(leaf_name)
return wrap_name(name)
name = f"{base}_{i}"
raise AssertionError("unreachable")
return wrap_name(name)
def handle_aliases_for_stolen_lists(self, tx):
# If list inputs are stolen, but still needed after the function call, create aliases to keep them alive


@@ -122,7 +122,7 @@ If you are removing an existing torch level API:
"""

manual_torch_name_rule_map = {
manual_torch_name_rule_map: dict[str, Any] = {
    "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
    "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
    "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@@ -306,6 +306,7 @@ manual_torch_name_rule_map = {
    "torch.jit._unwrap_optional": UserFunctionVariable,
    "torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
    "torch._dynamo.mark_static": UserFunctionVariable,
    "torch._dynamo.nonstrict_trace": UserFunctionVariable,
    "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
    "torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
    "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
@@ -2998,6 +2999,12 @@ def _disallowed_callable_ids() -> dict[int, str]:
    return rv


@FunctionIdSet
def _nonstrict_trace_callable_ids() -> dict[int, str]:
    rv: dict[int, str] = {}
    return rv


@FunctionIdSet
def _builtin_function_ids() -> dict[int, str]:
    # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids
@@ -3103,6 +3110,11 @@ def is_callable_allowed(obj) -> bool:
    return id(obj) in _allowed_callable_ids


def is_nonstrict_trace_callable(obj) -> bool:
    _maybe_init_lazy_module(obj)
    return id(obj) in _nonstrict_trace_callable_ids


def is_callable_disallowed(obj) -> bool:
    _maybe_init_lazy_module(obj)
    return id(obj) in _disallowed_callable_ids
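
As a sanity check, a small sketch of how the decorator and the new rule check line up (not part of this patch's tests; the helper function is made up, and `trace_rules` is assumed to be imported as a module):

```python
import torch._dynamo
import torch._dynamo.trace_rules as trace_rules


def f(x):
    return x + 1


wrapped = torch._dynamo.nonstrict_trace(f)

# The decorator registers the wrapper's id with `_nonstrict_trace_callable_ids`,
# so the rule check picks it up; the original function stays unregistered.
assert trace_rules.is_nonstrict_trace_callable(wrapped)
assert not trace_rules.is_nonstrict_trace_callable(f)
```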


@@ -2583,6 +2583,27 @@ def get_safe_global_name(tx, root, obj):
    return f"{root}_{id(obj)}_c{tx.output.compile_id}"


def get_unique_name_wrt(prefix: str, *containers) -> str:
    """
    Return a name that starts with `prefix` and is not in any of the
    `containers` (e.g., map, set).
    """
    name = prefix
    for i in itertools.count():
        found = False
        for container in containers:
            if name in container:
                found = True
                break
        if not found:
            return name
        # else update and retry
        name = f"{prefix}_{i}"

    raise AssertionError("unreachable")


def wrap_fake_exception(fn):
    try:
        return fn()
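
For reference, the intended behavior of `get_unique_name_wrt` (a quick illustration; the import path follows the `from .utils import` hunk earlier in this diff):

```python
from torch._dynamo.utils import get_unique_name_wrt

# The prefix itself is returned when it is free ...
assert get_unique_name_wrt("spec", {}) == "spec"
# ... otherwise the first "<prefix>_<i>" absent from every container is used.
assert get_unique_name_wrt("spec", {"spec": 1}, {"spec_0"}) == "spec_1"
```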


@@ -363,6 +363,27 @@ class UserFunctionVariable(BaseUserFunctionVariable):
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Handle a `nonstrict_trace(fn)` call
        if self.fn is torch._dynamo.nonstrict_trace:
            bound = inspect.signature(self.fn).bind(*args, **kwargs)
            fn_var = bound.args[0]
            if not isinstance(fn_var, BaseUserFunctionVariable):
                typ = fn_var.python_type()
                unimplemented(
                    f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
                )

            if not isinstance(fn_var, UserFunctionVariable):
                fn_name = fn_var.get_name()
                unimplemented(
                    f"""
Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region.
"""  # NOQA: B950
                )

            fn = fn_var.fn
            return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True)

        if self.is_constant:
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), args, kwargs


@@ -70,6 +70,7 @@ from .ctx_manager import (
    ProfilerContextVariable,
    TorchFunctionDisableVariable,
)
from .dicts import ConstDictVariable
from .distributed import DistributedVariable, ProcessGroupVariable
from .lists import ListVariable, TupleVariable
from .torch_function import (
@@ -401,8 +402,16 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):


class TorchInGraphFunctionVariable(BaseTorchVariable):
    """Points to a torch function/method that should be put in FX graph"""

    def __init__(self, value, nonstrict_traceable=None, **kwargs) -> None:
        super().__init__(value, **kwargs)
        from ..trace_rules import is_nonstrict_trace_callable

        if nonstrict_traceable is None:
            nonstrict_traceable = is_nonstrict_trace_callable(value)
        self.nonstrict_traceable = nonstrict_traceable

    def __repr__(self) -> str:
        return f"TorchInGraphFunctionVariable({self.value})"
        return f"TorchInGraphFunctionVariable({self.value}, nonstrict_traceable={self.nonstrict_traceable})"

    def get_function(self):
        return self.value
@@ -976,6 +985,92 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
from . import ConstantVariable, SymNodeVariable, TensorVariable
from .builder import wrap_fx_proxy
if self.nonstrict_traceable:
import torch._higher_order_ops.flat_apply as flat_apply
from torch._higher_order_ops.flat_apply import (
func_to_graphable,
is_graphable_type,
)
from torch.utils._pytree import tree_flatten
fn = self.value
# 1. Convert `args, kwargs` into pytree-flattened proxy forms.
#
# Rather than reconstructing `args, kwargs` into python objects and
# then tree_flatten them, we just let Dynamo symbolically interpret
# `tree_flatten((args, kwargs))`. This saves us from having to
# worry about the reconstruction logic, side effects, and guards.
packed_input_vt = TupleVariable.build(
tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
)
out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
tx, [packed_input_vt], {}
)
assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2
flat_args_vts, input_spec_vt = out_vt.items
assert isinstance(flat_args_vts, ListVariable)
# Handle the case when the input contains a non-graphable type.
for flat_arg_vt in flat_args_vts.items:
arg_type = flat_arg_vt.python_type()
if not is_graphable_type(arg_type):
type_name = flat_arg_vt.python_type().__qualname__
unimplemented(
f"""
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <{type_name}>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
""" # NOQA: B950
)
# Since we checked with `is_graphable` above, `as_proxy` on the
# flat_arg VT should always work.
proxified_flat_args = [
flat_arg_vt.as_proxy() for flat_arg_vt in flat_args_vts.items
]
# The downstream `flat_apply` call requires the input spec; however,
# the spec is not a graphable type, so we still have to reconstruct it
# into a python object, and store it as a constant attribute on the
# fx graph.
#
# TODO handle `pytree._register_constant`-ed values.
try:
input_spec = input_spec_vt.as_python_constant()
except NotImplementedError:
unimplemented(
"""
This error is most likely due to a call to `nonstrict_trace`-ed function, where one of the argument contains object of a type that has been (or needs to be) `torch.utils._pytree.register_constant`-ed. We currently don't support that.
""" # NOQA: B950
)
# `flat_apply` wants a TreeSpec for the function input.
_, f_spec = func_to_graphable(fn)
# TreeSpec isn't graphable, so we register the function and input
# specs as attributes on the graph module.
f_spec_proxy = tx.output.register_static_attr_and_return_proxy(
f"{fn.__name__}_spec", f_spec
)
input_spec_proxy = tx.output.register_static_attr_and_return_proxy(
fn.__name__ + "_input_spec", input_spec
)
f_spec_proxy.node.type = type(f_spec)
input_spec_proxy.node.type = type(input_spec)
all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args)
# 2. Create a proxy call to `flat_apply`, then fake-tensor propagate
# the call and wrap output into a VariableTracker.
proxy = tx.output.create_proxy("call_function", flat_apply, all_args, {})
out_vt = wrap_fx_proxy(tx, proxy)
# TODO support more output types
# Q: flat_apply will likely pytree_flatten the output for this, then
# how do we intercept the output before flatten, and wrap those?
# - Maybe we can have `flat_apply` return the output spec, so that
# Dynamo can unflatten and wrap the result.
return out_vt
if self.torch_function_override_enabled(tx, args, kwargs):
return dispatch_torch_function(tx, self, args, kwargs)
@@ -1034,6 +1129,10 @@ For now, dynamo will explicitly graph break when it encounters user code with th
):
fn_ = getattr(torch, torch_sym_op)
# TODO for each of the following check on `out=` or `requires_grad=`
# variant torch ops, the original function could come from a user
# defined `@allow_in_graph` function as well, which doesn't have the
# same semantics as the torch ops.
fake_out_shape = None
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
# Calling fake tensor propagation can mutate the out= tensor in
@@ -1053,6 +1152,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
),
)
# Handle e.g., `torch.ones(10, requires_grad=True)`
if (
isinstance(tensor_variable, TensorVariable)
and "requires_grad" in kwargs
@@ -1063,6 +1163,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
)
# Handle e.g., `torch.add(a, b, out=result)`
if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
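
Outside of Dynamo, the `flat_apply` lowering above corresponds roughly to this eager sequence (a sketch modeled on the `FlatApplyTests` usage earlier in this diff; the sample function, inputs, and `scale` keyword are made up for illustration):

```python
import torch
from torch._higher_order_ops.flat_apply import flat_apply, func_to_graphable
from torch.utils._pytree import tree_flatten


def f(x, y, *, scale):
    return x * y + scale


args, kwargs = (torch.randn(3), torch.randn(3)), {"scale": 2.0}

# 1. Flatten the inputs; `in_spec` records how to rebuild (args, kwargs).
flat_args, in_spec = tree_flatten((args, kwargs))

# 2. Wrap the function itself into a graphable "spec".
_, func_spec = func_to_graphable(f)

# 3. `flat_apply` unflattens the inputs and invokes the underlying function.
result = flat_apply(func_spec, in_spec, *flat_args)
assert torch.equal(result, f(*args, **kwargs))
```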


@@ -1345,6 +1345,34 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
fields = {}
self.fields = fields
def as_python_constant(self):
# NOTE: this is an intentionally limited version of
# `as_python_constant` for the `nonstrict_trace` implementation.
from dataclasses import fields
import torch.utils._pytree as pytree
if not istype(self.value, (pytree.TreeSpec, pytree.LeafSpec)):
# TODO loosen this restriction and fix `as_proxy`.
raise NotImplementedError(
"currently can't reconstruct arbitrary frozen dataclass instances"
)
args = []
kwargs = {}
for field in fields(self.value):
if field.init:
data = self.fields[field.name].as_python_constant()
if getattr(field, "kw_only", False):
kwargs[field.name] = data
else:
args.append(data)
# This is safe because we know the TreeSpec classes' constructors don't
# have external side effects.
ctor = self.python_type()
return ctor(*args, **kwargs)
def as_proxy(self):
from dataclasses import fields
@@ -1357,7 +1385,13 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
else:
args.append(proxy)
return self.python_type()(*args, **kwargs)
# TODO this isn't really safe, because
# 1. it could invoke a user defined `__post_init__`.
# 2. it could invoke a user defined `__init__` if the class _subclasses_
# a frozen dataclass.
# Either of the above could end up mutating external state.
ctor = self.python_type()
return ctor(*args, **kwargs)
# NB: This is called during __init__ for a frozen dataclass
# use this to accumulate the most up-to-date field values


@@ -13,6 +13,11 @@ def is_graphable(val) -> bool:
    return isinstance(val, torch.fx.node.base_types)


def is_graphable_type(typ) -> bool:
    """Return whether the given type is graphable"""
    return issubclass(typ, torch.fx.node.base_types)


def to_graphable(stuff):
    """Flattens stuff into a flat list of graphable types."""
    # We can consider preserving things like List[int] to improve
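
A quick illustration of the new predicate (the `Point` class is just an arbitrary non-graphable stand-in):

```python
import torch
from torch._higher_order_ops.flat_apply import is_graphable_type


class Point:
    pass


# Tensors and Python scalars are graphable (they are in torch.fx.node.base_types);
# arbitrary user-defined classes are not, until flattened away via pytree.
assert is_graphable_type(torch.Tensor)
assert is_graphable_type(int)
assert not is_graphable_type(Point)
```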


@@ -968,11 +968,20 @@ class TreeSpec:
        return unflatten_fn(child_pytrees, self.context)


# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
# this class with `dataclasses.fields`, etc., while having a simplified
# constructor that takes no arguments, we wrap with `dataclass(init=True, ...)`
# again, with fields that have `init=False`.
@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
class LeafSpec(TreeSpec):
    def __init__(self) -> None:
        super().__init__(None, None, [])

    type: Any = dataclasses.field(default=None, init=False)
    context: Context = dataclasses.field(default=None, init=False)
    children_specs: list["TreeSpec"] = dataclasses.field(
        default_factory=list, init=False
    )

    def __post_init__(self) -> None:
        # Override `__post_init__` for `num_leaves` derivation.
        object.__setattr__(self, "num_nodes", 1)
        object.__setattr__(self, "num_leaves", 1)
        object.__setattr__(self, "num_children", 0)
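
To make the subtlety concrete, here is a small standalone sketch of the same pattern with generic names (not the actual `TreeSpec`/`LeafSpec` code):

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Base:
    kind: object
    context: object
    children: list


@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
class Leaf(Base):
    # Re-declare every field with `init=False` so that `Leaf()` takes no
    # arguments, while `dataclasses.fields(Leaf())` still reports all fields.
    kind: object = dataclasses.field(default=None, init=False)
    context: object = dataclasses.field(default=None, init=False)
    children: list = dataclasses.field(default_factory=list, init=False)


leaf = Leaf()  # no-argument construction, analogous to LeafSpec()
assert [f.name for f in dataclasses.fields(leaf)] == ["kind", "context", "children"]
```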