Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Added two new utils to help with turning python functionalization on in AOTAutograd (next PR):

(1) Updated `torch._sync()`. Previously, this API could only handle `torch.Tensor` instances that had a `FunctionalTensorWrapper` TensorImpl. It now needs to handle python `FunctionalTensor`s as well. In theory I can probably break BC and change this API (since it's private?), but I decided not to do that in this PR stack, to minimize the chance of reverts. Instead of updating that API directly (which is in C++), I just added a python shim that first tries to unwrap the python `FunctionalTensor` if there is one, then calls the existing C++ logic.

(2) `mirror_autograd_meta` is now a standalone API that tries to mirror the `requires_grad` and `is_leaf` autograd metadata from one tensor to another. Previously this was hardcoded into `torch._to_functional_tensor()`. But I now need to use it in a more standalone way: later in AOTAutograd, when we unwrap and re-wrap a tensor subclass, we need to manually mirror the autograd metadata from the original to the updated version of the subclass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107917
Approved by: https://github.com/ezyang
ghstack dependencies: #106404
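For illustration only, the intent of the two utilities can be sketched roughly as follows; the helper names, the `elem` attribute, and the `view_as` trick below are assumptions made for exposition, not the code that actually landed in the PR:

    # Hypothetical sketch of the two helpers described above.
    import torch

    def sync(t):
        # (1) Python shim around the existing C++ torch._sync(): if t is a
        # python-level FunctionalTensor wrapper, unwrap it first, then fall
        # back to the existing C++ logic unchanged.
        inner = getattr(t, "elem", None)  # assumed attribute holding the inner tensor
        if inner is not None and torch._is_functional_tensor(inner):
            torch._sync(inner)
        else:
            torch._sync(t)

    def mirror_autograd_meta(src, dst):
        # (2) Standalone helper that mirrors requires_grad / is_leaf from src
        # onto dst; previously this logic was hard-coded inside
        # torch._to_functional_tensor().
        dst.requires_grad_(src.requires_grad)
        if src.requires_grad and not src.is_leaf:
            # Make dst a non-leaf too by routing it through a no-op view.
            dst = dst.view_as(dst)
        return dst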
540 lines
19 KiB
Python
# Owner(s): ["module: dynamo"]
import contextlib
import functools

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.testing import normalize_gm
from torch._higher_order_ops.wrap import wrap

from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv


class MockSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


class EagerRecordGraphAndInputs:
    def __init__(self):
        self.graphs = []
        self.example_inputs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.graphs.append(gm)
        self.example_inputs.append(example_inputs)
        return gm


@contextlib.contextmanager
def preserve_subclass_config():
    old_subclass_set = set(torch._dynamo.config.traceable_tensor_subclasses)
    try:
        torch._dynamo.config.traceable_tensor_subclasses.add(MockSubclass)
        yield
    finally:
        torch._dynamo.config.traceable_tensor_subclasses.clear()
        torch._dynamo.config.traceable_tensor_subclasses.update(old_subclass_set)


class SubclassTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(preserve_subclass_config())

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()

    def test_torch_function_state_graph_break(self):
        @torch.compile(backend="eager")
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch._dynamo.graph_break()
                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res, _ = fn(input)
        self.assertFalse(res)

    def test_torch_function_state_tracing(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch.add(x, 1.0)

        input = torch.ones(2, 2)

        res = fn(input)

    def test_torch_function_state_guards(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            torch.add(x, 1.0)

        input = torch.ones(2, 2)

        with torch._C.DisableTorchFunctionSubclass():
            res = fn(input)

        res = fn(input)

        self.assertEqual(cnt.frame_count, 2)

    def test_return_subclass(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return MockSubclass(torch.add(x, 1.0))

        input = torch.ones(2, 2)

        res = fn(input)
        self.assertIsInstance(res, MockSubclass)

    def test_return_local_subclass(self):
        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        torch._dynamo.config.traceable_tensor_subclasses.add(LocalSubclass)

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return LocalSubclass(torch.add(x, 1.0))

        input = torch.ones(2, 2)

        res = fn(input)
        self.assertIsInstance(res, LocalSubclass)

    def test_compile_with_fake_tensor_dynamic_dim(self):
        x = torch.randn([3, 4])

        def f(x):
            return torch.sin(x)

        def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count):
            torch._dynamo.reset()
            cnt = torch._dynamo.testing.CompileCounter()

            opt_f = torch.compile(f, backend=cnt, fullgraph=True)

            x1 = torch.rand_like(x)
            f(x)
            f(torch.randn([4, 3]))
            shape_env = ShapeEnv()
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                x_fake = fake_mode.from_tensor(
                    x, dynamic_dims=[dim_dynamic for i in range(x.dim())]
                )
                x1_fake = fake_mode.from_tensor(
                    x1, dynamic_dims=[dim_dynamic for i in range(x.dim())]
                )
                opt_f(x_fake)
                opt_f(x1_fake)

            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)

        test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1)
        test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1)
        test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1)

    def test_compile_with_fake_tensor_automatic_dynamic(self):
        def f(x):
            return torch.sin(x)

        def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count):
            torch._dynamo.reset()
            cnt = torch._dynamo.testing.CompileCounter()
            opt_f = torch.compile(f, backend=cnt, fullgraph=True)

            shape_env = ShapeEnv()
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                for inp in inps:
                    fake_inp = fake_mode.from_tensor(
                        inp, dynamic_dims=[dim_dynamic for i in range(inp.dim())]
                    )
                    opt_f(fake_inp)
            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)

        x = torch.randn([3, 4])
        y = torch.randn([4, 5])
        z = torch.randn([5, 6])
        a = torch.randn([3, 5])
        b = torch.randn([4, 4])
        for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK, DimDynamic.STATIC]:
            # Recompiles once: on the second call, dims 0 and 1 both become dynamic.
            test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2)
            # Recompiles twice: first dim 1 becomes dynamic, then dim 0 becomes dynamic.
            test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3)
            # Recompiles twice: first dim 0 becomes dynamic, then dim 1 becomes dynamic.
            test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3)

    def test_compile_with_functionalization(self):
        x = torch.randn([3, 4])
        x_clone = x.clone()
        x_clone2 = x.clone()
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return x.add_(1.0) + torch.nn.functional.relu_(x)

        f_out = f(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)
        self.assertEqual(len(backend.graphs), 1)
        self.assertEqual(len(backend.example_inputs), 1)

        expected = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        add_ = l_x_.add_(1.0)
        relu_ = torch.relu_(l_x_); l_x_ = None
        add = add_ + relu_; add_ = relu_ = None
        return (add,)
"""
        actual = normalize_gm(backend.graphs[0].print_readable(print_output=False))
        self.assertEqual(actual, expected)

        ff = torch.func.functionalize(f)
        ff_out = ff(x_clone)

        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 6)
        self.assertEqual(len(backend.graphs), 2)
        self.assertEqual(len(backend.example_inputs), 2)
        actual = normalize_gm(backend.graphs[1].print_readable(print_output=False))
        self.assertEqual(actual, expected)
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

        # Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
        def to_fun(x):
            x_functional = torch._to_functional_tensor(x)
            torch._mirror_autograd_meta_to(x, x_functional)
            return x_functional

        def aot_f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    func_args = pytree.tree_map(to_fun, args)
                    func_kwargs = pytree.tree_map(to_fun, kwargs)
                    return func(*func_args, **func_kwargs)
                finally:
                    torch._disable_functionalization()

            return wrapper

        aot_ff = aot_f_wrapper(f)
        aot_ff_out = aot_ff(x_clone2)

        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 9)
        self.assertEqual(len(backend.graphs), 3)
        self.assertEqual(len(backend.example_inputs), 3)
        actual = normalize_gm(backend.graphs[2].print_readable(print_output=False))
        self.assertEqual(actual, expected)
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

        self.assertEqual(f_out, ff_out)
        self.assertEqual(f_out, aot_ff_out)

        try:
            torch._enable_functionalization(reapply_views=False)
            xf = pytree.tree_map(to_fun, x)
            x_view = xf.t()
            with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"):
                f(x_view)
        finally:
            torch._disable_functionalization()

    def test_compile_higher_order_with_functionalization(self):
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return wrap(lambda x: x.add_(1.0), x)

        def check_count_and_graph(
            exp_frame_count, exp_op_count, exp_n_graph, exp_graph
        ):
            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)
            self.assertEqual(len(backend.graphs), exp_n_graph)
            actual = normalize_gm(
                backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
            )
            self.assertExpectedInline(actual, exp_graph)

        t = torch.randn([3, 4])
        t_clone = t.clone()
        t_clone2 = t.clone()
        f(t)

        expected_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        wrap_body_0 = self.wrap_body_0
        wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
        return (wrap,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_x_):
            add_ = l_x_.add_(1.0); l_x_ = None
            return add_
"""
        check_count_and_graph(1, 1, 1, expected_graph)

        ff = torch.func.functionalize(f)
        ff_out = ff(t_clone)
        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(2, 2, 2, expected_graph)

        try:
            x = torch._to_functional_tensor(t_clone2)
            torch._mirror_autograd_meta_to(t_clone2, x)
            torch._enable_functionalization(reapply_views=False)
            aot_f_out = f(x)
        finally:
            torch._disable_functionalization()

        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(3, 3, 3, expected_graph)

    def test_wrapper_subclass_guards_on_inner_tensor(self):
        # Holds an inner tensor that has a distinct shape from the outer wrapper tensor.
        # Also adds additional guards on the inner tensor's sizes.
        # When the first input to an op has an inner tensor with shape[0] > 3, we insert an extra add node.
        class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, inner):
                # Double the outer-most dimension
                outer_shape = (inner.shape[0] * 2,) + inner.shape[1:]
                return torch.Tensor._make_wrapper_subclass(
                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
                    # Calling the overload that has kwargs causes us to go down the first overload path,
                    # which will **always** specialize sizes.
                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
                    cls,
                    outer_shape,
                    inner.stride(),
                    None,
                    None,
                    inner.dtype,
                    inner.layout,
                    inner.device,
                    False,
                    inner.requires_grad,
                )

            def __init__(self, inner):
                self.inner_elem = inner

            def __tensor_flatten__(self):
                return ["inner_elem"], None

            @staticmethod
            def __tensor_unflatten__(inner_tensors, _):
                return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"])

            def __repr__(self):
                return f"DoubleSizeMaybeAddGeThreeTensor({repr(self.inner_elem)})"

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                args_inner = torch.utils._pytree.tree_map_only(
                    DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args
                )
                out_inner = func(*args_inner, **kwargs)

                # Add guards on the inner tensor's sizes
                if args_inner[0].shape[0] > 3:
                    out_inner += 2

                return DoubleSizeMaybeAddGeThreeTensor(out_inner)

        lower_bound_str = None
        upper_bound_str = None
        curr_var_to_val = None
        curr_var_to_sources = None

        def backend(gm, args):
            print(gm.code)
            context = torch._guards.TracingContext.get()
            val_to_guards = list(context.fake_mode.shape_env.var_to_guards.values())

            # Grab info on sources and guards from the ShapeEnv
            nonlocal lower_bound_str
            nonlocal upper_bound_str
            nonlocal curr_var_to_val
            nonlocal curr_var_to_sources

            lower_bound_str = str(val_to_guards[0][0].expr)
            upper_bound_str = str(val_to_guards[0][1].expr)
            curr_var_to_val = {
                str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items()
            }
            curr_var_to_sources = {
                str(k): v[0].name()
                for k, v in context.fake_mode.shape_env.var_to_sources.items()
            }
            return gm

        @torch.compile(backend=backend)
        def fn(x):
            if x.shape[0] < 10:
                return torch.mul(x, x)
            else:
                return torch.div(x, x)

        inp = torch.ones(4, 4)

        x = DoubleSizeMaybeAddGeThreeTensor(inp)
        torch._dynamo.mark_dynamic(x, 0)
        res = fn(x)
        # During fakeifying, we end up allocating a separate symint
        # for the outer and inner tensor (in this test, s0 is unused).
        expected_var_to_val = {
            "s0": 8,
            "s1": 4,
        }
        expected_var_to_sources = {
            "s0": "L['x'].size()[0]",
            "s1": "L['x'].inner_elem.size()[0]",
        }
        # lower bound comes from code underneath torch_dispatch (operating on the inner tensor size)
        expected_lower_bound = "s1 > 3"
        # upper bound comes from user code (operating on the wrapper size)
        expected_upper_bound = "2*s1 < 10"
        self.assertEqual(curr_var_to_val, expected_var_to_val)
        self.assertEqual(curr_var_to_sources, expected_var_to_sources)
        self.assertEqual(lower_bound_str, expected_lower_bound)
        self.assertEqual(upper_bound_str, expected_upper_bound)

    def test_recompile_with_symbool_inputs(self):
        def f(pred: bool):
            if pred:
                return torch.ones([3, 4])
            else:
                return torch.ones([4, 3])

        def test_recompilation(
            f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards
        ):
            torch._dynamo.reset()
            shape_env = ShapeEnv()
            backend = torch._dynamo.testing.EagerAndRecordGraphs()
            cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
            f_cond = torch.compile(f, backend=cnt, fullgraph=True)
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                fake_inp = fake_mode.from_tensor(
                    x, dynamic_dims=[DimDynamic.DYNAMIC for i in range(x.dim())]
                )
                for i, size in enumerate(sizes):
                    pred = fake_inp.size(0) == size
                    f_cond(pred)
                    actual = normalize_gm(
                        backend.graphs[exp_frame_count[i] - 1].print_readable(
                            print_output=False
                        )
                    )
                    actual_guard_str = [str(guard.expr) for guard in shape_env.guards]
                    self.assertExpectedInline(actual, exp_graphs[i])
                    self.assertEqual(cnt.frame_count, exp_frame_count[i])
                    self.assertEqual(actual_guard_str, exp_shape_env_guards[i])

        true_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self):
        ones = torch.ones([3, 4])
        return (ones,)
"""
        false_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self):
        ones = torch.ones([4, 3])
        return (ones,)
"""
        test_recompilation(
            f,
            torch.randn([3, 4]),
            [3, 3, 4, 5],
            exp_graphs=[true_graph, true_graph, false_graph, false_graph],
            exp_frame_count=[1, 1, 2, 2],
            exp_shape_env_guards=[
                [],
                # s0 is specialized and guarded in the outer shape_env when dynamo checks the guards
                ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
                [
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
                ],
                [
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
                ],
            ],
        )

        test_recompilation(
            f,
            torch.randn([3, 4]),
            [4, 5, 3, 3],
            exp_graphs=[false_graph, false_graph, true_graph, true_graph],
            exp_frame_count=[1, 1, 2, 2],
            exp_shape_env_guards=[
                [],
                # s0 is specialized and guarded in the outer shape_env when dynamo checks the guards
                ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
                [
                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                ],
                [
                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                ],
            ],
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()