[dynamo] Initial support for nonstrict_trace (#146367)
## Context

> **Note:** `mark_traceable` got renamed to `nonstrict_trace` after offline discussion. The reasons are (1) it aligns with `torch.export`'s `nonstrict` notion, and (2) it's more definitive in behavior suggestion.

1. [Overall Design](https://docs.google.com/document/d/1O-dR2ZQaJQVt_v67AVcDCw2yJLtqgkZFwoXK0buEWRg/edit?tab=t.0)
2. [Dynamo graph representation with `torch._higher_order_ops.flat_apply`](https://docs.google.com/document/d/1YHl5nPTJvYeCPE5TO9uA18DPWNgUYGE4gCn6bFvXcBM/edit?tab=t.0#heading=h.xtw3hhbro4gn)

## Summary

This patch adds a `torch._dynamo.nonstrict_trace` decorator, which currently is an enhanced version of `torch._dynamo.allow_in_graph` (see the docstring for their differences). Specifically, this patch focuses on the UI and on functionality prototyping/plumbing.

The main enhancement is supporting more input types, and the implementation challenge lies in reconstructing the input objects from Dynamo `VariableTracker`s (while accounting for buffered side effects and guards). This patch takes a middle ground (a simple implementation with a bit of user labor), by

1. asking the user to provide a pytree registration for non-proxy-able input types,
2. letting Dynamo trace through `pytree_flatten` (which accounts for buffered side effects and guards automatically),
3. and passing the TreeSpec as a graph attribute constant into `torch._higher_order_ops.flat_apply` (which unflattens the inputs and invokes the underlying function).

## Next Steps

In subsequent patches, we will try to support the following:
- annotating on class methods
- reads to global tensors
- inputs that contain `pytree.register_constant`-ed instances
- functions as inputs
- more output types (e.g., any pytree-registered type)
- `torch.nn.Module` as inputs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146367
Approved by: https://github.com/zou3519
ghstack dependencies: #146714
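For orientation, here is a minimal usage sketch distilled from the tests added in this patch (`Point`, `trace_me`, and `fn` are illustrative names; the pytree registration is what makes the non-proxy-able `Point` input legal):

```python
import torch
import torch.utils._pytree as pytree


# Illustrative custom input type; the registration below tells pytree (and
# therefore Dynamo) how to flatten/unflatten it around the traced call.
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y


pytree.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), ()),  # flatten into (children, context)
    lambda xy, _: Point(xy[0], xy[1]),  # unflatten children back into a Point
)


@torch._dynamo.nonstrict_trace
def trace_me(p):
    # The body is not traced by Dynamo, so e.g. a graph break here does not
    # split the surrounding `fullgraph=True` compilation.
    torch._dynamo.graph_break()
    return p.x * p.y


@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x, y):
    return trace_me(Point(x, y)) + 1


out = fn(torch.randn(4), torch.randn(4))
```

In the captured Dynamo graph, the call shows up as a single `torch.ops.higher_order.flat_apply` node whose first two arguments are the function and input TreeSpecs stored as graph attributes (see the expected-graph test further down).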
This commit is contained in:
parent bab84f0bd9
commit f46f0e465c
@@ -211,6 +211,439 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
            def fn1(x):
                return x.cos()

    def test_nonstrict_trace_tensor_args(self):
        @torch._dynamo.nonstrict_trace
        def trace_me(x, y, z):
            torch._dynamo.graph_break()
            return x * y + z

        def fn(x, y):
            t0 = x + 1
            t1 = trace_me(x, y, t0)
            t2 = t1 + y
            return t0 * t2

        x, y = torch.randn(10), torch.randn(10)
        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")

        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertEqual(ref, res)

    def test_nonstrict_trace_pre_existing_dict(self):
        @torch._dynamo.nonstrict_trace
        def trace_me(x, d):
            torch._dynamo.graph_break()
            return x * d["a"]

        def fn(x, d):
            t0 = trace_me(x, d)
            return t0 + 1

        x = torch.randn(10)
        d = {"a": 2}
        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")

        ref = fn(x, d)
        res = opt_fn(x, d)
        self.assertEqual(ref, res)

    def test_nonstrict_trace_newly_constructed_dict_with_side_effects(self):
        @torch._dynamo.nonstrict_trace
        def trace_me(x, d):
            torch._dynamo.graph_break()
            return x * d["a"]

        def fn(x):
            d = {}
            d["a"] = 2
            t0 = trace_me(x, d)
            return t0 + 1

        x = torch.randn(10)
        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")

        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_nonstrict_trace_pre_existing_dict_with_side_effects(self):
        @torch._dynamo.nonstrict_trace
        def trace_me(x, d):
            torch._dynamo.graph_break()
            return x * d["a"]

        def fn(x, d):
            d["a"] = x + 1
            t0 = trace_me(x, d)
            return t0 + 2

        x = torch.randn(10)
        d0 = {"a": 0}
        d1 = dict(d0)
        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")

        ref = fn(x, d0)
        res = opt_fn(x, d1)
        self.assertEqual(ref, res)
        self.assertEqual(d0, d1)

    def test_nonstrict_trace_pre_existing_custom_class(self):
        class Point:
            x: torch.Tensor
            y: torch.Tensor

            def __init__(self, x, y):
                self.x = x
                self.y = y

        torch.utils._pytree.register_pytree_node(
            Point,
            lambda p: ((p.x, p.y), ()),
            lambda xy, _: Point(xy[0], xy[1]),
        )

        @torch._dynamo.nonstrict_trace
        def trace_me(p):
            torch._dynamo.graph_break()
            return p.x * p.y

        def fn(p):
            res = trace_me(p)
            return res, p.x, p.y

        p = Point(torch.ones(10), torch.ones(1))
        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")

        ref = fn(p)
        res = opt_fn(p)
        self.assertEqual(ref, res)

    def test_nonstrict_trace_pre_existing_custom_class_with_side_effects(self):
        class Point:
            x: torch.Tensor
            y: torch.Tensor

            def __init__(self, x, y):
                self.x = x
                self.y = y

        torch.utils._pytree.register_pytree_node(
            Point,
            lambda p: ((p.x, p.y), ()),
            lambda xy, _: Point(xy[0], xy[1]),
        )

        @torch._dynamo.nonstrict_trace
        def trace_me(p):
            torch._dynamo.graph_break()
            return p.x * p.y

        def fn(p):
            p.x = p.x + 1
            p.y = p.y + 2
            res = trace_me(p)
            return res, p.x, p.y

        p1 = Point(torch.ones(10), torch.ones(1))
        p2 = Point(torch.ones(10), torch.ones(1))
        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")

        ref = fn(p1)
        res = opt_fn(p2)
        self.assertEqual(ref, res)
        self.assertEqual(p1.x, p2.x)
        self.assertEqual(p1.y, p2.y)

    def test_nonstrict_trace_newly_constructed_custom_class_with_side_effects(self):
        class Point:
            x: torch.Tensor
            y: torch.Tensor

            def __init__(self, x, y):
                self.x = x
                self.y = y

        torch.utils._pytree.register_pytree_node(
            Point,
            lambda p: ((p.x, p.y), ()),
            lambda xy, _: Point(xy[0], xy[1]),
        )

        @torch._dynamo.nonstrict_trace
        def trace_me(p):
            torch._dynamo.graph_break()
            return p.x * p.y

        def fn(x, y):
            p = Point(x, y)
            p.x = p.x + 1
            p.y = p.y + 2
            res = trace_me(p)
            return res, p.x, p.y

        x, y = torch.ones(10), torch.ones(1)
        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")

        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertEqual(ref, res)

    def test_nonstrict_trace_nested_custom_class(self):
        class Point:
            x: torch.Tensor
            y: torch.Tensor

            def __init__(self, x, y):
                self.x = x
                self.y = y

        class PointTensor:
            p: Point
            t: torch.Tensor

            def __init__(self, p, t):
                self.p = p
                self.t = t

        torch.utils._pytree.register_pytree_node(
            PointTensor,
            lambda pt: ((pt.p, pt.t), ()),
            lambda pt, _: PointTensor(pt[0], pt[1]),
        )

        torch.utils._pytree.register_pytree_node(
            Point,
            lambda p: ((p.x, p.y), ()),
            lambda xy, _: Point(xy[0], xy[1]),
        )

        def trace_point(p):
            torch._dynamo.graph_break()
            return p.x * p.y

        @torch._dynamo.nonstrict_trace
        def trace_point_tensor(pt):
            torch._dynamo.graph_break()
            return pt.t + trace_point(pt.p)

        def fn(x, y):
            p = Point(x, y)
            t = x + y
            pt = PointTensor(p, t)
            res = trace_point_tensor(pt)
            return res

        x, y = torch.ones(10), torch.ones(1)
        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")

        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertEqual(ref, res)

    def test_nonstrict_trace_tuple_and_sym_int_output(self):
        @torch._dynamo.nonstrict_trace
        def trace_me(x):
            torch._dynamo.graph_break()
            return x + 1, x.size(0)

        def fn(x):
            t0, n = trace_me(x)
            return t0 * n

        x = torch.randn(10)
        opt_fn = torch.compile(fn, dynamic=True, fullgraph=True, backend="aot_eager")

        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_nonstrict_trace_inside_compiled_function(self):
        def trace_me(x):
            torch._dynamo.graph_break()
            return x + 42

        def fn(x):
            res = torch._dynamo.nonstrict_trace(trace_me)(x)
            return res + 1

        x = torch.randn(10)
        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")

        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_nonstrict_trace_inside_compiled_function_kwarg(self):
        def trace_me(x):
            torch._dynamo.graph_break()
            return x + 42

        def fn(x):
            res = torch._dynamo.nonstrict_trace(traceable_fn=trace_me)(x)
            return res + 1

        x = torch.randn(10)
        opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")

        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_nonstrict_trace_no_action_at_a_distance(self):
        def trace_me(x):
            torch._dynamo.graph_break()
            return x + 42

        # No effect on traceability of `trace_me`
        torch._dynamo.nonstrict_trace(trace_me)

        def fn(x):
            res = trace_me(x)
            return res + 1

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        opt_fn = torch.compile(fn, backend=cnts)

        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        # There should be 1 graph break
        self.assertEqual(cnts.frame_count, 2)

    def test_nonstrict_trace_inside_compiled_function_error(self):
        @torch.compile(fullgraph=True, backend="aot_eager")
        def fn(x, y):
            def trace_me(x, y):
                torch._dynamo.graph_break()
                return x * y

            res = torch._dynamo.nonstrict_trace(trace_me)(x, y)
            return res + 1

        try:
            fn(torch.ones(10), torch.ones(1))
            self.assertFalse(True)  # must raise error before this
        except torch._dynamo.exc.Unsupported as e:
            msg = """
Applying `nonstrict_trace` to function <trace_me>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region.
"""  # NOQA: B950
            self.assertIn(msg, str(e))

    def test_nonstrict_trace_custom_class_error(self):
        class Point:
            x: torch.Tensor
            y: torch.Tensor

            def __init__(self, x, y):
                self.x = x
                self.y = y

        @torch._dynamo.nonstrict_trace
        def trace_me(p):
            torch._dynamo.graph_break()
            return p.x * p.y

        @torch.compile(fullgraph=True, backend="aot_eager")
        def fn(p):
            res = trace_me(p)
            return res + 1

        try:
            p = Point(torch.ones(10), torch.ones(1))
            fn(p)
            self.assertFalse(True)  # must raise error before this
        except torch._dynamo.exc.Unsupported as e:
            msg = """
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
"""  # NOQA: B950
            self.assertIn(msg, str(e))

    def test_nonstrict_trace_nested_custom_class_error(self):
        class Point:
            x: torch.Tensor
            y: torch.Tensor

            def __init__(self, x, y):
                self.x = x
                self.y = y

        class PointTensor:
            p: Point
            t: torch.Tensor

            def __init__(self, p, t):
                self.p = p
                self.t = t

        torch.utils._pytree.register_pytree_node(
            PointTensor,
            lambda pt: ((pt.p, pt.t), ()),
            lambda pt, _: PointTensor(pt[0], pt[1]),
        )

        def trace_point(p):
            torch._dynamo.graph_break()
            return p.x * p.y

        @torch._dynamo.nonstrict_trace
        def trace_point_tensor(pt):
            torch._dynamo.graph_break()
            return pt.t + trace_point(pt.p)

        @torch.compile(fullgraph=True, backend="aot_eager")
        def fn(x, y):
            p = Point(x, y)
            t = x + y
            pt = PointTensor(p, t)
            res = trace_point_tensor(pt)
            return res

        try:
            fn(torch.ones(10), torch.ones(1))
            self.assertFalse(True)  # must raise error before this
        except torch._dynamo.exc.Unsupported as e:
            msg = """
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_nested_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
"""  # NOQA: B950
            self.assertIn(msg, str(e))

    def test_nonstrict_trace_pytree_register_constant_error(self):
        class Point:
            x: int
            y: int

            def __init__(self, x, y):
                self.x = x
                self.y = y

        torch.utils._pytree.register_constant(Point)

        @torch._dynamo.nonstrict_trace
        def trace_me(x, p):
            torch._dynamo.graph_break()
            return x * p.x + p.y

        @torch.compile(fullgraph=True, backend="aot_eager")
        def fn(x, p):
            res = trace_me(x, p)
            return res + 1

        try:
            p = Point(3, 4)
            fn(torch.ones(10), p)
            self.assertFalse(True)  # must raise error before this
        except torch._dynamo.exc.Unsupported as e:
            msg = """
This error is most likely due to a call to `nonstrict_trace`-ed function, where one of the argument contains object of a type that has been (or needs to be) `torch.utils._pytree.register_constant`-ed. We currently don't support that.
"""  # NOQA: B950
            self.assertIn(msg, str(e))

    def test_graph_break(self):
        cnts = torch._dynamo.testing.CompileCounter()
@@ -4,6 +4,7 @@ from dataclasses import dataclass

import torch
import torch._dynamo.test_case
import torch.utils._pytree as pytree
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
from torch._higher_order_ops.flat_apply import (
    flat_apply,
    func_to_graphable,
@@ -83,6 +84,72 @@ class FlatApplyTests(torch._dynamo.test_case.TestCase):
        result = flat_apply(func_spec, in_spec, *flat_args)
        self.assertEqual(result, f(*args, **kwargs))

    def test_nonstrict_trace_dynamo_graph(self):
        class Point:
            x: torch.Tensor
            y: torch.Tensor

            def __init__(self, x, y):
                self.x = x
                self.y = y

        class PointTensor:
            p: Point
            t: torch.Tensor

            def __init__(self, p, t):
                self.p = p
                self.t = t

        torch.utils._pytree.register_pytree_node(
            PointTensor,
            lambda pt: ((pt.p, pt.t), ()),
            lambda pt, _: PointTensor(pt[0], pt[1]),
        )

        torch.utils._pytree.register_pytree_node(
            Point,
            lambda p: ((p.x, p.y), ()),
            lambda xy, _: Point(xy[0], xy[1]),
        )

        def trace_point(p):
            torch._dynamo.graph_break()
            return p.x * p.y

        @torch._dynamo.nonstrict_trace
        def trace_point_tensor(pt):
            torch._dynamo.graph_break()
            return pt.t + trace_point(pt.p)

        backend = EagerAndRecordGraphs()

        @torch.compile(fullgraph=True, backend=backend)
        def fn(x, y):
            p = Point(x, y)
            t = x + y
            pt = PointTensor(p, t)
            res = trace_point_tensor(pt)
            return res

        fn(torch.randn(10), torch.randn(10))
        self.assertExpectedInline(
            normalize_gm(backend.graphs[0].print_readable(print_output=False)),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"):
        l_x_ = L_x_
        l_y_ = L_y_

        t: "f32[10]" = l_x_ + l_y_

        trace_point_tensor_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_spec
        trace_point_tensor_input_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_input_spec
        res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
        return (res,)
""",  # NOQA: B950
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@@ -26,6 +26,7 @@ from .decorators import (
    mark_static,
    mark_static_address,
    maybe_mark_dynamic,
    nonstrict_trace,
    run,
    set_stance,
    substitute_in_graph,
@@ -63,6 +64,7 @@ __all__ = [
    "maybe_mark_dynamic",
    "mark_static",
    "mark_static_address",
    "nonstrict_trace",
    "optimize",
    "optimize_assert",
    "export",
@@ -160,6 +160,39 @@ def allow_in_graph(fn):
    return fn


def nonstrict_trace(traceable_fn):
    # Like `allow_in_graph`, but with the following enhancements/differences:
    #
    # 1. Supports user-defined class as inputs, as long as the class has been
    #    registered with pytree.
    # 2. Reads to global/captured tensors forces the underlying graph to treat
    #    those tensors as constant, and we _assume_ they will not be updated. This
    #    is similar to FX tracing.
    # 3. In the resulting Dynamo graph, the call to a `nonstrict_trace`-ed function
    #    will be represented as a call to `torch._higher_order_ops.flat_apply`,
    #    which takes in the `nonstrict_trace`-ed function and pytree-flattened
    #    inputs.
    # 4. Only the returned function is traceable, and the original function will
    #    not be. Moreover, `nonstrict_trace` can be used inside a `torch.compile`
    #    region.
    #
    # NOTE: like `allow_in_graph`, aliasing information is neither preserved
    # between inputs themselves, nor between inputs and outputs.
    assert callable(traceable_fn), "nonstrict_trace expects a callable"

    @functools.wraps(traceable_fn)
    def wrapped(*args, **kwargs):
        return traceable_fn(*args, **kwargs)

    # This line allows us to reuse much of the `allow_in_graph` impl.
    trace_rules._allowed_callable_ids.add(id(wrapped))

    # This line allows us to diverge the impl from `allow_in_graph`.
    trace_rules._nonstrict_trace_callable_ids.add(id(wrapped))

    return wrapped


def _disallow_in_graph_helper(throw_if_not_allowed):
    def inner(fn):
        if isinstance(fn, (list, tuple)):
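Since registration is keyed on the `id()` of the returned wrapper, only that wrapper is treated specially; the original function is unaffected (this is what `test_nonstrict_trace_no_action_at_a_distance` above checks). A small sketch with illustrative names:

```python
import torch


def helper(x):
    return x + 42


# Only `traced` is nonstrict-traceable; calling `helper` directly inside a
# compiled region is inlined by Dynamo as usual.
traced = torch._dynamo.nonstrict_trace(helper)


@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x):
    return traced(x) + helper(x)


fn(torch.randn(4))
```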
@@ -176,6 +209,7 @@ def _disallow_in_graph_helper(throw_if_not_allowed):
                "Allowed callables means callables that TorchDynamo puts as-is in the extracted graph."
            )
        trace_rules._allowed_callable_ids.remove(id(fn))
        trace_rules._nonstrict_trace_callable_ids.remove(id(fn))
        trace_rules._disallowed_callable_ids.add(id(fn))
        return fn
@@ -115,6 +115,7 @@ from .utils import (
    get_instruction_source_311,
    get_locals_to_steal,
    get_static_address_type,
    get_unique_name_wrt,
    graph_break_reasons,
    increment_op_count,
    lazy_format_graph_code,
@@ -753,6 +754,17 @@ class OutputGraph:

        return name

    def register_static_attr_and_return_proxy(
        self, attr_prefix: str, attr_value: Any
    ) -> fx.Proxy:
        attr_name = get_unique_name_wrt(attr_prefix, self.nn_modules)
        # TODO `nn_modules` has been historically overloaded to store a lot more
        # than just nn module objects, fix that.
        self.nn_modules[attr_name] = attr_value
        proxy = self.create_proxy("get_attr", attr_name, (), {})
        set_example_value(proxy.node, attr_value)
        return proxy

    def register_attr_or_module(
        self,
        target: Union[torch.nn.Module, torch.Tensor, Any],
@@ -864,36 +876,30 @@ class OutputGraph:
                return wrap_name(k)

        name = OutputGraph.module_key_name(*names)
-        base = name
-        for i in itertools.count():
-            if name not in self.nn_modules and name not in self.global_scope:
-                self.nn_modules[name] = target
-                if isinstance(target, torch.nn.Module):
-
-                    def register_leaf_name(leaf_name):
-                        assert self.param_name_to_source is not None
-                        new_source = ParamBufferSource(source, leaf_name)
-                        new_name = f"{name}.{leaf_name}"
-                        self.param_name_to_source[new_name] = new_source
-                        if isinstance(source, LocalSource):
-                            self.dynamo_flat_name_to_original_fqn[
-                                OutputGraph.module_key_name(new_source.name())
-                            ] = leaf_name
-
-                    # annoying, but there are cases when we do not have parameters
-                    # see test_nn_moduledict_contains
-                    if hasattr(target, "_parameters"):
-                        for leaf_name, _ in target.named_parameters():
-                            register_leaf_name(leaf_name)
-                    if hasattr(target, "_buffers"):
-                        for leaf_name, _ in target.named_buffers():
-                            register_leaf_name(leaf_name)
-
-                return wrap_name(name)
-            name = f"{base}_{i}"
-
-        raise AssertionError("unreachable")
+        name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
+        self.nn_modules[name] = target
+        if isinstance(target, torch.nn.Module):
+
+            def register_leaf_name(leaf_name):
+                assert self.param_name_to_source is not None
+                new_source = ParamBufferSource(source, leaf_name)
+                new_name = f"{name}.{leaf_name}"
+                self.param_name_to_source[new_name] = new_source
+                if isinstance(source, LocalSource):
+                    self.dynamo_flat_name_to_original_fqn[
+                        OutputGraph.module_key_name(new_source.name())
+                    ] = leaf_name
+
+            # annoying, but there are cases when we do not have parameters
+            # see test_nn_moduledict_contains
+            if hasattr(target, "_parameters"):
+                for leaf_name, _ in target.named_parameters():
+                    register_leaf_name(leaf_name)
+            if hasattr(target, "_buffers"):
+                for leaf_name, _ in target.named_buffers():
+                    register_leaf_name(leaf_name)
+
+        return wrap_name(name)

    def handle_aliases_for_stolen_lists(self, tx):
        # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive
@@ -122,7 +122,7 @@ If you are removing an existing torch level API:


"""
-manual_torch_name_rule_map = {
+manual_torch_name_rule_map: dict[str, Any] = {
    "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
    "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
    "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@@ -306,6 +306,7 @@ manual_torch_name_rule_map = {
    "torch.jit._unwrap_optional": UserFunctionVariable,
    "torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
    "torch._dynamo.mark_static": UserFunctionVariable,
    "torch._dynamo.nonstrict_trace": UserFunctionVariable,
    "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
    "torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
    "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
@@ -2998,6 +2999,12 @@ def _disallowed_callable_ids() -> dict[int, str]:
    return rv


@FunctionIdSet
def _nonstrict_trace_callable_ids() -> dict[int, str]:
    rv: dict[int, str] = {}
    return rv


@FunctionIdSet
def _builtin_function_ids() -> dict[int, str]:
    # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids
@@ -3103,6 +3110,11 @@ def is_callable_allowed(obj) -> bool:
    return id(obj) in _allowed_callable_ids


def is_nonstrict_trace_callable(obj) -> bool:
    _maybe_init_lazy_module(obj)
    return id(obj) in _nonstrict_trace_callable_ids


def is_callable_disallowed(obj) -> bool:
    _maybe_init_lazy_module(obj)
    return id(obj) in _disallowed_callable_ids
@@ -2583,6 +2583,27 @@ def get_safe_global_name(tx, root, obj):
    return f"{root}_{id(obj)}_c{tx.output.compile_id}"


def get_unique_name_wrt(prefix: str, *containers) -> str:
    """
    Return a name that starts with `prefix` and is not in any of the
    `containers` (e.g., map, set).
    """
    name = prefix
    for i in itertools.count():
        found = False
        for container in containers:
            if name in container:
                found = True
                break

        if not found:
            return name
        # else update and retry
        name = f"{prefix}_{i}"

    raise AssertionError("unreachable")


def wrap_fake_exception(fn):
    try:
        return fn()
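A quick worked example of the retry loop above (the container contents are hypothetical):

```python
# "foo" collides with the first container and "foo_0" with the second,
# so the first free candidate, "foo_1", is returned.
assert get_unique_name_wrt("foo", {"foo": 1}, {"foo_0"}) == "foo_1"
```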
@@ -363,6 +363,27 @@ class UserFunctionVariable(BaseUserFunctionVariable):
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Handle a `nonstrict_trace(fn)` call
        if self.fn is torch._dynamo.nonstrict_trace:
            bound = inspect.signature(self.fn).bind(*args, **kwargs)
            fn_var = bound.args[0]
            if not isinstance(fn_var, BaseUserFunctionVariable):
                typ = fn_var.python_type()
                unimplemented(
                    f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
                )

            if not isinstance(fn_var, UserFunctionVariable):
                fn_name = fn_var.get_name()
                unimplemented(
                    f"""
Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region.
"""  # NOQA: B950
                )

            fn = fn_var.fn
            return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True)

        if self.is_constant:
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), args, kwargs
@@ -70,6 +70,7 @@ from .ctx_manager import (
    ProfilerContextVariable,
    TorchFunctionDisableVariable,
)
from .dicts import ConstDictVariable
from .distributed import DistributedVariable, ProcessGroupVariable
from .lists import ListVariable, TupleVariable
from .torch_function import (
@@ -401,8 +402,16 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
class TorchInGraphFunctionVariable(BaseTorchVariable):
    """Points to a torch function/method that should be put in FX graph"""

    def __init__(self, value, nonstrict_traceable=None, **kwargs) -> None:
        super().__init__(value, **kwargs)
        from ..trace_rules import is_nonstrict_trace_callable

        if nonstrict_traceable is None:
            nonstrict_traceable = is_nonstrict_trace_callable(value)
        self.nonstrict_traceable = nonstrict_traceable

    def __repr__(self) -> str:
-        return f"TorchInGraphFunctionVariable({self.value})"
+        return f"TorchInGraphFunctionVariable({self.value}, nonstrict_traceable={self.nonstrict_traceable})"

    def get_function(self):
        return self.value
@@ -976,6 +985,92 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
        from . import ConstantVariable, SymNodeVariable, TensorVariable
        from .builder import wrap_fx_proxy

        if self.nonstrict_traceable:
            import torch._higher_order_ops.flat_apply as flat_apply
            from torch._higher_order_ops.flat_apply import (
                func_to_graphable,
                is_graphable_type,
            )
            from torch.utils._pytree import tree_flatten

            fn = self.value
            # 1. Convert `args, kwargs` into pytree-flattened proxy forms.
            #
            # Rather than reconstructing `args, kwargs` into python objects and
            # then tree_flatten them, we just let Dynamo symbolically interpret
            # `tree_flatten((args, kwargs))`. This saves us from having to
            # worry about the reconstruction logic, side effects, and guards.
            packed_input_vt = TupleVariable.build(
                tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
            )
            out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
                tx, [packed_input_vt], {}
            )
            assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2
            flat_args_vts, input_spec_vt = out_vt.items
            assert isinstance(flat_args_vts, ListVariable)

            # Handle the case when the input contains a non-graphable type.
            for flat_arg_vt in flat_args_vts.items:
                arg_type = flat_arg_vt.python_type()
                if not is_graphable_type(arg_type):
                    type_name = flat_arg_vt.python_type().__qualname__
                    unimplemented(
                        f"""
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <{type_name}>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
"""  # NOQA: B950
                    )

            # Since we checked with `is_graphable` above, `as_proxy` on the
            # flat_arg VT should always work.
            proxified_flat_args = [
                flat_arg_vt.as_proxy() for flat_arg_vt in flat_args_vts.items
            ]

            # The downstream `flat_apply` call requires the input spec; however,
            # the spec is not a graphable type, so we still have to reconstruct it
            # into a python object, and store it as a constant attribute on the
            # fx graph.
            #
            # TODO handle `pytree._register_constant`-ed values.
            try:
                input_spec = input_spec_vt.as_python_constant()
            except NotImplementedError:
                unimplemented(
                    """
This error is most likely due to a call to `nonstrict_trace`-ed function, where one of the argument contains object of a type that has been (or needs to be) `torch.utils._pytree.register_constant`-ed. We currently don't support that.
"""  # NOQA: B950
                )

            # `flat_apply` wants a TreeSpec for the function input.
            _, f_spec = func_to_graphable(fn)

            # TreeSpec isn't graphable, so we register the function and input
            # specs as attributes on the graph module.
            f_spec_proxy = tx.output.register_static_attr_and_return_proxy(
                f"{fn.__name__}_spec", f_spec
            )
            input_spec_proxy = tx.output.register_static_attr_and_return_proxy(
                fn.__name__ + "_input_spec", input_spec
            )
            f_spec_proxy.node.type = type(f_spec)
            input_spec_proxy.node.type = type(input_spec)
            all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args)

            # 2. Create a proxy call to `flat_apply`, then fake-tensor propagate
            # the call and wrap output into a VariableTracker.
            proxy = tx.output.create_proxy("call_function", flat_apply, all_args, {})
            out_vt = wrap_fx_proxy(tx, proxy)
            # TODO support more output types
            # Q: flat_apply will likely pytree_flatten the output for this, then
            # how do we intercept the output before flatten, and wrap those?
            # - Maybe we can have `flat_apply` return the output spec, so that
            #   Dynamo can unflatten and wrap the result.

            return out_vt

        if self.torch_function_override_enabled(tx, args, kwargs):
            return dispatch_torch_function(tx, self, args, kwargs)
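As a rough mental model for the lowering above (a simplified sketch, not the internal `flat_apply` implementation): Dynamo symbolically runs `tree_flatten((args, kwargs))`, stores the resulting TreeSpec as a graph attribute, and the emitted `flat_apply` node then behaves roughly like:

```python
from torch.utils._pytree import tree_unflatten


def conceptual_flat_apply(fn, input_spec, *flat_args):
    # Rebuild the original (args, kwargs) from the flat graphable leaves,
    # then invoke the `nonstrict_trace`-ed function on them.
    args, kwargs = tree_unflatten(list(flat_args), input_spec)
    return fn(*args, **kwargs)
```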
@@ -1034,6 +1129,10 @@ For now, dynamo will explicitly graph break when it encounters user code with th
        ):
            fn_ = getattr(torch, torch_sym_op)

        # TODO for each of the following check on `out=` or `requires_grad=`
        # variant torch ops, the original function could come from a user
        # defined `@allow_in_graph` function as well, which doesn't have the
        # same semantics as the torch ops.
        fake_out_shape = None
        if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
            # Calling fake tensor propagation can mutate the out= tensor in
@@ -1053,6 +1152,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
            ),
        )

        # Handle e.g., `torch.ones(10, requires_grad=True)`
        if (
            isinstance(tensor_variable, TensorVariable)
            and "requires_grad" in kwargs
@@ -1063,6 +1163,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
            )

        # Handle e.g., `torch.add(a, b, out=result)`
        if "out" in kwargs and not (
            isinstance(kwargs["out"], variables.ConstantVariable)
            and kwargs["out"].as_python_constant() is None
@@ -1345,6 +1345,34 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
            fields = {}
        self.fields = fields

    def as_python_constant(self):
        # NOTE: this is an intentionally limited version of
        # `as_python_constant` for `nonstrict_trace` implementation.
        from dataclasses import fields

        import torch.utils._pytree as pytree

        if not istype(self.value, (pytree.TreeSpec, pytree.LeafSpec)):
            # TODO loosen this restriction and fix `as_proxy`.
            raise NotImplementedError(
                "currently can't reconstruct arbitrary frozen dataclass instances"
            )

        args = []
        kwargs = {}
        for field in fields(self.value):
            if field.init:
                data = self.fields[field.name].as_python_constant()
                if getattr(field, "kw_only", False):
                    kwargs[field.name] = data
                else:
                    args.append(data)

        # This is safe because we know the TreeSpec classes constructors don't
        # have external side effects.
        ctor = self.python_type()
        return ctor(*args, **kwargs)

    def as_proxy(self):
        from dataclasses import fields
@@ -1357,7 +1385,13 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
            else:
                args.append(proxy)

-        return self.python_type()(*args, **kwargs)
+        # TODO this isn't really safe, because
+        # 1. it could invoke a user defined `__post_init__`.
+        # 2. it could invoke a user defined `__init__` if the class _subclasses_
+        #    a frozen dataclass.
+        # Either of the above could end up mutating external state.
+        ctor = self.python_type()
+        return ctor(*args, **kwargs)

    # NB: This is called during __init__ for a frozen dataclass
    # use this to accumulate the most up-to-date field values
@@ -13,6 +13,11 @@ def is_graphable(val) -> bool:
    return isinstance(val, torch.fx.node.base_types)


def is_graphable_type(typ) -> bool:
    """Return whether the given type is graphable"""
    return issubclass(typ, torch.fx.node.base_types)


def to_graphable(stuff):
    """Flattens stuff into a flat list of graphable types."""
    # We can consider preserving things like List[int] to improve
@@ -968,11 +968,20 @@ class TreeSpec:
        return unflatten_fn(child_pytrees, self.context)


+# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
+# this class with `dataclasses.fields`, etc., while having a simplified
+# constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
+# again, with fields that have `init=False`.
+@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)
class LeafSpec(TreeSpec):
-    def __init__(self) -> None:
-        super().__init__(None, None, [])
+    type: Any = dataclasses.field(default=None, init=False)
+    context: Context = dataclasses.field(default=None, init=False)
+    children_specs: list["TreeSpec"] = dataclasses.field(
+        default_factory=list, init=False
+    )

    def __post_init__(self) -> None:
+        # Override `__post_init__` for `num_leaves` derivation.
        object.__setattr__(self, "num_nodes", 1)
        object.__setattr__(self, "num_leaves", 1)
        object.__setattr__(self, "num_children", 0)