Unified symbolic shape variables between AOTAutograd and Inductor (#86659)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86659
Approved by: https://github.com/wconstab
Horace He 2022-10-13 20:19:16 +00:00 committed by PyTorch MergeBot
parent c7c09722ad
commit b3b9786fdd
12 changed files with 526 additions and 88 deletions
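
Before this change, symbolic sizes were named per-tensor at the point each fake tensor was created (t0.size(0), t0.stride(1), ...), so AOTAutograd and Inductor could end up with unrelated variables for the same dimension. After it, every entry point funnels through a single ShapeEnv whose create_symbol hands out duck-shaped, positionally named variables (s0, s1, ...). A minimal sketch of the new allocation behavior, assuming the post-commit ShapeEnv API shown in the hunks below:

    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    env = ShapeEnv()
    s_a = env.create_symbol(5)   # allocates s0 for the value 5
    s_b = env.create_symbol(7)   # allocates s1 for the value 7
    s_c = env.create_symbol(5)   # duck-shaping: 5 was seen before, so s0 again
    assert s_a is s_c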


@@ -206,7 +206,7 @@ const char* toString(DispatchKey t) {
   switch (bc) {                                  \
     C10_FORALL_BACKEND_COMPONENTS(ENTRY, prefix) \
     default:                                     \
-      return #prefix "Unknown";                  \
+      return #prefix "Undefined";                \
   }
 C10_FORALL_FUNCTIONALITY_KEYS(FORALL_BC)


@@ -278,6 +278,7 @@ class AOTConfig:
     bw_compiler: Callable
     partition_fn: Callable
     decompositions: Dict[Callable, Callable]
+    num_params_buffers: int


 def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
@@ -491,6 +492,11 @@ def create_aot_dispatcher_function(
     The resulting compiled forward and backward graphs are then wrapped up in a
     ``torch.autograd.Function`` object.
+
+    The calling convention here is that the first aot_config.num_params_buffers
+    inputs in flat_args are parameters and buffers, and the rest are inputs.
+
+    We use this to assume that parameters/buffers' shapes don't change.
     """

     # This is the main entry point.
@@ -514,19 +520,26 @@ def create_aot_dispatcher_function(
         # coordinate flags
         config.use_fake_tensor = False

-    fake_mode = FakeTensorMode() if config.use_fake_tensor else nullcontext()
+    if config.use_dynamic_shapes:
+        assert config.use_fake_tensor, "Dynamic shapes only works with fake tensor"
+
+    shape_env = ShapeEnv() if config.use_dynamic_shapes else None
+    fake_mode = FakeTensorMode(shape_env=shape_env) if config.use_fake_tensor else nullcontext()
     cross_ref = CrossRefFakeMode() if config.debug_fake_cross_ref else nullcontext()
     python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext()
-    shape_env = ShapeEnv() if config.use_dynamic_shapes else None

     with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:

         def process_inputs(flat_args):
             if config.use_fake_tensor:
-                def convert(x):
-                    return fake_mode.from_tensor(x, shape_env=shape_env)
+                def convert(idx, x):
+                    if not isinstance(x, torch.Tensor):
+                        return x
+                    if idx < aot_config.num_params_buffers and config.static_weight_shapes:
+                        return fake_mode.from_tensor(x, static_shapes=True)
+                    return fake_mode.from_tensor(x, static_shapes=False)

-                return pytree.tree_map_only(Tensor, convert, flat_args)
+                return [convert(idx, x) for idx, x in enumerate(flat_args)]
             else:
                 return flat_args
@@ -587,6 +600,7 @@ def aot_function(
     bw_compiler: Optional[Callable] = None,
     partition_fn: Callable = default_partition,
     decompositions: Optional[Dict] = None,
+    num_params_buffers: int = 0,
    hasher_type=None,  # deprecated
    static_argnums: Optional[Tuple[int]] = None,  # deprecated
 ) -> Callable:
@@ -650,6 +664,7 @@ def aot_function(
         bw_compiler=bw_compiler,
         partition_fn=partition_fn,
         decompositions=decompositions,
+        num_params_buffers=num_params_buffers,
     )

     cached_res = None
@@ -734,7 +749,10 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
         params_and_buffers = {**named_params, **named_buffers}
         return stateless.functional_call(mod, params_and_buffers, args, kwargs)

-    compiled_f = aot_function(functional_call, *args, **kwargs)
+    named_params = dict(_named_parameters(mod, remove_duplicate=False))
+    named_buffers = dict(_named_buffers(mod, remove_duplicate=False))
+    num_params_buffers = len(named_params) + len(named_buffers)
+    compiled_f = aot_function(functional_call, num_params_buffers=num_params_buffers, *args, **kwargs)

     class AOTModule(nn.Module):
         def __init__(self):
@@ -743,8 +761,8 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
         def forward(self, *args, **kwargs):
             return compiled_f(
-                dict(_named_parameters(mod, remove_duplicate=False)),
-                dict(_named_buffers(mod, remove_duplicate=False)),
+                named_params,
+                named_buffers,
                 *args,
                 **kwargs,
             )
@@ -812,6 +830,7 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
         bw_compiler=bw_compiler,
         partition_fn=partition_fn,
         decompositions=decompositions,
+        num_params_buffers=params_len,
     )
     compiled_fn = None
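
All three entry points above now record how many leading flat inputs are weights. A minimal usage sketch (the module below is illustrative; the functorch.compile import path and the nop compiler are assumed from this era of the codebase):

    import torch
    import torch.nn as nn
    from functorch.compile import aot_module, nop

    mod = nn.Linear(4, 4)

    # aot_module counts the module's parameters and buffers and forwards that
    # count as num_params_buffers, so process_inputs() fakeifies the first
    # len(params) + len(buffers) flat arguments with static_shapes=True.
    compiled = aot_module(mod, fw_compiler=nop)
    out = compiled(torch.randn(2, 4))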


@@ -23,3 +23,5 @@
 debug_graphs = os.environ.get('AOT_FX_GRAPHS', False)
 debug_joint = os.environ.get('AOT_FX_GRAPHS_JOINT', False)
 use_dynamic_shapes = os.getenv('AOT_DYNAMIC_SHAPES', False)
+
+static_weight_shapes = True


@@ -1060,6 +1060,7 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.avg_pool1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.avg_pool2d', ''),  # aten.avg_pool2d.default - couldn't find symbolic meta function/...
     xfail('nn.functional.avg_pool3d', ''),  # aten.avg_pool3d.default - couldn't find symbolic meta function/...
+    skip('nn.functional.batch_norm', ''),  # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
     xfail('nn.functional.bilinear', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.binary_cross_entropy', ''),  # aten.fill_.Scalar - couldn't find symbolic meta funct...
     xfail('nn.functional.conv1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides


@@ -113,13 +113,14 @@ class FakeSymbolicTensor(torch.Tensor):
 def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
-    sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.size())])
-    sym_strides = tuple([shape_env.create_symint(f"{name}_{idx}_stride", val) for idx, val in enumerate(arg.stride())])
+    sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
     return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)

 CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))

+def create_symint(shape_env, i):
+    return shape_env.create_symintnode(shape_env.create_symbol(i))
+
 @skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
 class TestPySymInt(TestCase):
@@ -128,8 +129,8 @@ class TestPySymInt(TestCase):
     def test_arith_ops(self):
         shape_env = ShapeEnv()
         symints = []
-        for i in range(5):
-            symints.append((i, shape_env.create_symint(f"s{i}", i)))
+        for i in range(2, 5):
+            symints.append((i, create_symint(shape_env, i)))

         ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
@@ -143,10 +144,10 @@ class TestPySymInt(TestCase):
     def test_reverse_arith_ops(self):
         shape_env = ShapeEnv()

-        a = shape_env.create_symint("s1", 2)
+        a = create_symint(shape_env, 2)
         self.assertTrue(5 // a == 5 // 2)

-        a = shape_env.create_symint("s1", 2)
+        a = create_symint(shape_env, 2)
         self.assertTrue(5 * a == 5 * 2)
@@ -172,7 +173,7 @@ class TestPySymInt(TestCase):
         self.assertTrue(x.size(2) == 3)
         self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))

-        offset = shape_env.create_symint("offset", 2)
+        offset = create_symint(shape_env, 2)
         y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
         self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
         self.assertTrue(y.storage_offset() == 2)
@@ -207,7 +208,7 @@ class TestPySymInt(TestCase):
         y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
         LAST_DIM = 2
         z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
-        self.assertTrue(z.shape[2] == int(y.shape[2]))
+        self.assertTrue(z.shape[2] == y.shape[2])

         # arithmetic expr with two symints
         z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
@@ -317,28 +318,28 @@ class TestPySymInt(TestCase):
     @skipIfNoSympy
     def test_meta_symint(self):
         shape_env = ShapeEnv()
-        a0 = shape_env.create_symint("a0", 2)
+        a0 = create_symint(shape_env, 2)
         r = torch.empty(a0, device='meta')
         self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)

     @skipIfNoSympy
     def test_guard_int(self):
         shape_env = ShapeEnv()
-        a0 = shape_env.create_symint("a0", 2)
+        a0 = create_symint(shape_env, 2)
         self.assertEqual(a0.guard_int(), 2)
-        self.assertEqual(str(shape_env.guards[0][0]), "a0")
+        self.assertEqual(str(shape_env.guards[0][0]), "s0")
         self.assertEqual(shape_env.guards[0][1], 2)

     @skipIfNoSympy
     def test_int_conversion(self):
         shape_env = ShapeEnv()
-        a0 = shape_env.create_symint("a0", 2)
+        a0 = create_symint(shape_env, 2)
         self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))

     @skipIfNoSympy
     def test_symint_as_scalar(self):
         shape_env = ShapeEnv()
-        a0 = shape_env.create_symint("a0", 2)
+        a0 = create_symint(shape_env, 2)

         sym_int_encountered = False
@@ -372,18 +373,18 @@ class TestPySymInt(TestCase):
         self.assertExpectedInline(mock_stdout.getvalue().strip(), """\
 class f(torch.nn.Module):
-    def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]):
+    def forward(self, a_1: f32[s0, s1], b_1: f32[s2, s1]):

         # No stacktrace found for following nodes
-        sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
-        sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
-        add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None
-        sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
-        sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
-        add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
-        new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
+        sym_size: Sym(s0) = torch.ops.aten.sym_size(a_1, 0)
+        sym_size_1: Sym(s2) = torch.ops.aten.sym_size(b_1, 0)
+        add: Sym(s0 + s2) = sym_size + sym_size_1; sym_size = sym_size_1 = None
+        sym_size_2: Sym(s1) = torch.ops.aten.sym_size(a_1, 1)
+        sym_size_3: Sym(s1) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
+        add_1: Sym(2*s1) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
+        new_empty: f32[s0 + s2, 2*s1] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
         native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
-        getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0]
-        getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None
+        getitem: f32[s0 + s2, 2*s1] = native_dropout[0]
+        getitem_1: b8[s0 + s2, 2*s1] = native_dropout[1]; native_dropout = None
         return (getitem, getitem_1)""")  # noqa: B950


@@ -0,0 +1,391 @@
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: jit"]

from torch._C import _disabled_torch_function_impl
import torch.fx
import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo
import unittest
import torch
import operator
import itertools
import io
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch.ops.aten

try:
    import sympy
    HAS_SYMPY = True
except ImportError:
    HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")

meta_funcs = {}


def register_meta(op):
    def decorator(f):
        def add_func(op):
            meta_funcs[op] = f
        tree_map(add_func, op)
        return f
    return decorator
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
    return a.new_empty(a.shape)


@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    concat_length = 0
    shape = tensors[0].shape
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)


@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    shape = []
    for i, x in enumerate(a.shape):
        if i == dim:
            shape.append(length)
        else:
            shape.append(x)
    return a.new_empty(tuple(shape))


@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    return a.new_empty(size)


def create_contiguous(shape):
    strides = [1]
    for dim in reversed(shape[:-1]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))
class FakeSymbolicTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
        # TODO: this is wrong in general
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls, sym_shape,
            sym_stride, storage_offset,
            dtype=dtype, layout=layout, requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)

        raise RuntimeError(f"operator {func_overload} not supported")


def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
    sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
    return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)


CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))


def create_symint(shape_env, i):
    return shape_env.create_symintnode(shape_env.create_symbol(i))
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase):
    @skipIfNoSympy
    def test_arith_ops(self):
        shape_env = ShapeEnv()
        symints = []
        for i in range(2, 5):
            symints.append((i, create_symint(shape_env, i)))

        ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]

        for op in ops:
            for args in itertools.permutations(symints, 2):
                if not isinstance(args[0][1], int) and ((op != operator.mod or op != operator.floordiv) and args[1][0] != 0):
                    self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]))

    @skipIfNoSympy
    def test_reverse_arith_ops(self):
        shape_env = ShapeEnv()

        a = create_symint(shape_env, 2)
        self.assertTrue(5 // a == 5 // 2)

        a = create_symint(shape_env, 2)
        self.assertTrue(5 * a == 5 * 2)
    @skipIfNoSympy
    def test_roundtrip(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

        self.assertTrue(not isinstance(x.shape[0], PySymInt))
        self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
        self.assertTrue(x.shape[0] == 5)
        self.assertTrue(x.shape[1] == 4)
        self.assertTrue(x.shape[2], 3)

        self.assertTrue(x.size()[0], 5)
        self.assertTrue(x.size()[1], 4)
        self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
        self.assertTrue(x.size()[2] == 3)

        self.assertTrue(x.size(0) == 5)
        self.assertTrue(x.size(1) == 4)
        self.assertTrue(x.size(2) == 3)
        self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))

        offset = create_symint(shape_env, 2)
        y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
        self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
        self.assertTrue(y.storage_offset() == 2)

        offset = 2
        z = create_symbolic_tensor("z", torch.randn(5, 4, 3), shape_env, offset)
        self.assertTrue(isinstance(z.storage_offset(), int))
        self.assertTrue(z.storage_offset() == 2)
    @skipIfNoSympy
    def test_binary(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # broadcasting
        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

    @skipIfNoSympy
    def test_symint_args(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
        LAST_DIM = 2
        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == y.shape[2])

        # arithmetic expr with two symints
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == 2)

        # arithmetic expr with a symint and python int
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
        self.assertTrue(z.shape[2] == 2)
    @skipIfNoSympy
    def test_symint_vargs(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

        # varargs
        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # shape list
        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(x.shape[0], y.shape[1], 3)
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints in a list
        z = y.expand((x.shape[0], y.shape[1], 3))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(5, y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python ints and symints in a list
        z = y.expand((5, y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        z = y.expand((y.shape[1],))
        z = y.expand(y.shape[1])
    @skipIfNoSympy
    def test_stride(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
        self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS)

    @skipIfNoSympy
    def test_size_expressions(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        expand_x = x.expand(x.shape[0], x.shape[0])
        if expand_x.shape[0] > 3:
            result = expand_x + expand_x
        else:
            result = expand_x + expand_x

        gt_op = shape_env.guards[0][0]
        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
        self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
        self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
        self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))

    @skipIfNoSympy
    def test_int_to_float(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        r = sym_float(x.shape[0])
        self.assertTrue(isinstance(r, torch.SymFloatNode))

    @skipIfNoSympy
    def test_aten_ops(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])

        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])

    def test_fx_trace_intlist(self):
        class CustomModule(torch.nn.Module):
            def forward(self, x):
                bs, c, h, w = x.shape
                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))

        m = CustomModule()
        x = torch.rand(1, 3, 4, 4)
        # should not TypeError: pad(): argument 'pad' (position 2) must be
        # tuple of ints, not tuple
        torch.fx.symbolic_trace(m)
    @skipIfNoSympy
    def test_meta_symint(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        r = torch.empty(a0, device='meta')
        self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)

    @skipIfNoSympy
    def test_guard_int(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        self.assertEqual(a0.guard_int(), 2)
        self.assertEqual(str(shape_env.guards[0][0]), "s0")
        self.assertEqual(shape_env.guards[0][1], 2)

    @skipIfNoSympy
    def test_int_conversion(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))

    @skipIfNoSympy
    def test_symint_as_scalar(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)

        sym_int_encountered = False

        class TestSymInt(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                assert func == torch.ops.aten.add.Tensor

                nonlocal sym_int_encountered
                sym_int_encountered = kwargs["alpha"] is a0
                kwargs["alpha"] = 0
                return func(*args)

        x = torch.rand([4, 4])
        with TestSymInt():
            y = torch.add(x, x, alpha=a0)

        self.assertTrue(sym_int_encountered)
    @skipIfNoSympy
    @unittest.mock.patch('sys.stdout', new_callable=io.StringIO)
    def test_print_readable_with_symints(self, mock_stdout):
        def f(a, b):
            dim0 = a.shape[0] + b.shape[0]
            dim1 = a.shape[1] + b.shape[1]
            d = a.new_empty(dim0, dim1)
            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
        fx_g.print_readable()

        self.assertExpectedInline(mock_stdout.getvalue().strip(), """\
class f(torch.nn.Module):
    def forward(self, a_1: f32[s0, s1], b_1: f32[s2, s1]):
        # No stacktrace found for following nodes
        sym_size: Sym(s0) = torch.ops.aten.sym_size(a_1, 0)
        sym_size_1: Sym(s2) = torch.ops.aten.sym_size(b_1, 0)
        add: Sym(s0 + s2) = sym_size + sym_size_1; sym_size = sym_size_1 = None
        sym_size_2: Sym(s1) = torch.ops.aten.sym_size(a_1, 1)
        sym_size_3: Sym(s1) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
        add_1: Sym(2*s1) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
        new_empty: f32[s0 + s2, 2*s1] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
        getitem: f32[s0 + s2, 2*s1] = native_dropout[0]
        getitem_1: b8[s0 + s2, 2*s1] = native_dropout[1]; native_dropout = None
        return (getitem, getitem_1)""")  # noqa: B950


if __name__ == '__main__':
    run_tests()


@@ -1065,6 +1065,7 @@ symbolic_tensor_failures = {
     xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
     xfail('chunk', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('column_stack', ''),  # Tensors of type TensorImpl do not have numel
+    xfail('combinations', ''),
     xfail('count_nonzero', ''),  # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba...
     xfail('cross', ''),  # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
     xfail('cummax', ''),  # aten.cummax.default - couldn't find symbolic meta function/decomposition
@@ -1290,6 +1291,7 @@ symbolic_tensor_failures = {
     xfail('unbind', ''),  # aten.unbind.int - couldn't find symbolic meta function/decomposition
 }

 symbolic_tensor_segfaults = {
+    skip('nn.functional.batch_norm')  # Segfault??
 }

 symbolic_tensor_failures.update(symbolic_tensor_segfaults)


@@ -622,6 +622,7 @@ class FakeTensorMode(TorchDispatchMode):
         allow_fallback_kernels=True,
         allow_meta=False,
         throw_on_data_dependent_ops=True,
+        shape_env=None,
     ):
         self.allow_fallback_kernels = allow_fallback_kernels
         self.fake_tensor_converter = FakeTensorConverter()
@@ -642,6 +643,8 @@ class FakeTensorMode(TorchDispatchMode):
         # the device property
         self.in_kernel_invocation = False

+        self.shape_env = shape_env
+
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}
@@ -920,8 +923,10 @@ class FakeTensorMode(TorchDispatchMode):
         ):
             self.fake_tensor_converter.invalidate_constant_aliases(v.constant)

-    def from_tensor(self, tensor, shape_env=None):
-        return self.fake_tensor_converter(self, tensor, shape_env=shape_env)
+    def from_tensor(self, tensor, static_shapes=False):
+        if static_shapes:
+            return self.fake_tensor_converter(self, tensor)
+        return self.fake_tensor_converter(self, tensor, shape_env=self.shape_env)


 # NB: returns fake tensors
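
FakeTensorMode now owns the ShapeEnv, and from_tensor() selects between static and symbolic shapes instead of taking a shape_env argument. A minimal usage sketch (import paths assumed from this era of the codebase):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    fake_mode = FakeTensorMode(shape_env=ShapeEnv())

    weight = torch.randn(4, 4)
    inp = torch.randn(2, 4)

    # Parameter-like tensors keep their concrete sizes ...
    fake_w = fake_mode.from_tensor(weight, static_shapes=True)
    # ... while ordinary inputs get symbolic sizes drawn from the shared ShapeEnv.
    fake_x = fake_mode.from_tensor(inp, static_shapes=False)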


@@ -142,39 +142,18 @@ class MetaConverter:
         arg_cnt = self.arg_cnt
         self.arg_cnt += 1

-        # Don't make parameters have symbolic shapes; they are assumed to stay
-        # constant size across training runs
-        make_symbolic = shape_env is not None and not isinstance(t, torch.nn.Parameter)
+        make_symbolic = shape_env is not None

-        def sym(name, x):
+        def sym(x):
             if make_symbolic:
-                return shape_env.create_symint(f"t{arg_cnt}.{name}()", x)
+                return shape_env.create_symbol(x)
             else:
                 return x

-        def sym_list(name, xs):
+        def sym_sizes_strides(t):
             if make_symbolic:
-                return [
-                    shape_env.create_symint(f"t{arg_cnt}.{name}({i})", x)
-                    for i, x in enumerate(xs)
-                ]
-            else:
-                return xs
-
-        def sym_size(t):
-            return sym_list("size", t.size())
-
-        def sym_stride(t):
-            return sym_list("stride", t.stride())
-
-        # NB: Although sym_stride variables initially have no correlation
-        # with size, we will immediately introduce guards based on contiguity.
-        # Thus, if the input tensor is contiguous, the stride variables
-        # will typically immediately get reexpressed in terms of the size
-        # variables.
-
-        def sym_storage_offset(t):
-            return sym("storage_offset", t.storage_offset())
+                return shape_env.create_symbolic_sizes_strides(t)
+            return (t.size(), t.stride())

         # see expired-storages
         self.check_expired_count += 1
@@ -231,9 +210,8 @@ class MetaConverter:
                     base = base.view(t.dtype)
                 with torch.enable_grad():
-                    r = base.as_strided(
-                        sym_size(t), sym_stride(t), sym_storage_offset(t)
-                    )
+                    sizes, strides = sym_sizes_strides(t)
+                    r = base.as_strided(sizes, strides, sym(t.storage_offset()))
             else:
                 is_leaf = safe_is_leaf(t)
                 # Fake up some autograd history.
@@ -257,8 +235,9 @@ class MetaConverter:
                 # meta storage
                 s = self.meta_storage(t.storage())
                 with no_dispatch():
+                    sizes, strides = sym_sizes_strides(t)
                     with torch.no_grad():
-                        r.set_(s, sym_storage_offset(t), sym_size(t), sym_stride(t))
+                        r.set_(s, sym(t.storage_offset()), sizes, strides)

         torch._C._set_conj(r, t.is_conj())
         torch._C._set_neg(r, t.is_neg())


@@ -612,7 +612,8 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
     elif tracing_mode == "fake":
         fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
     elif tracing_mode == "symbolic":
-        fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
+        shape_env = ShapeEnv()
+        fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env)
     else:
         raise AssertionError(f"Unexpected tracing type: {tracing_mode}")
@ -628,15 +629,12 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
return x
shape_env = None
if tracing_mode == "symbolic":
shape_env = ShapeEnv()
sym_mode = proxy_mode.sym_mode
# todo: Figure out a more informative name for symints
def wrap_fake_symbolic(x):
if isinstance(x, torch.Tensor):
return fake_tensor_mode.from_tensor(x, shape_env=shape_env)
return fake_tensor_mode.from_tensor(x)
return x
wrap_fn_map = {
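
With the ShapeEnv now created inside make_fx itself, symbolic tracing needs no extra plumbing. A short sketch mirroring the new test above:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a, b):
        return a.new_empty(a.shape[0] + b.shape[0], a.shape[1])

    # tracing_mode="symbolic" builds a fresh ShapeEnv and hands it to
    # FakeTensorMode, so input dimensions show up as s0, s1, ... in the graph.
    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
    gm.print_readable()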


@@ -1,6 +1,6 @@
 import torch
 import torch.utils._pytree as pytree
-from typing import Set, Dict, List, Type, Optional, cast
+from typing import Set, Dict, List, Type, Optional, cast, Union
 import operator
 import math
 import functools
@@ -331,7 +331,7 @@ class ShapeEnv(object):
         self.divisible: Set["sympy.Expr"] = set()
         # Duck-shaping says that if two input tensors have the same size,
         # they get assigned the same symbolic variable
-        self.val_to_symint: Dict[int, torch.SymIntNode] = {}
+        self.val_to_var: Dict[int, "sympy.Expr"] = {0: sympy.Integer(0), 1: sympy.Integer(1)}

     def _get_key(self):
         """
@@ -340,28 +340,68 @@ class ShapeEnv(object):
         """
         return (len(self.replacements), len(self.divisible))

-    # NB: This is only called for input symbolic sizes; intermediate symbolic
-    # sizes are allocated via a different mechanism
-    def create_symint(self, name, val):
-        assert val >= 0
+    def create_symbolic_sizes_strides(self, ex: torch.Tensor):
+        """
+        Returns a list of symbolic sizes and strides for the given tensor.
+        We try our best to express stride in terms of the sizes, so as to not
+        introduce new symbolic variables.
+        """
+        size = [self.create_symbol(i) for i in ex.size()]
+        stride: List[Optional[sympy.Expr]] = [None] * len(size)
+        for i, val in enumerate(ex.stride()):
+            if val in (0, 1):
+                stride[i] = sympy.Integer(val)
+        while any(x is None for x in stride):
+            candidates = {
+                ex.size(i) * ex.stride()[i]: size[i] * stride[i]
+                for i in range(len(size))
+                if stride[i] is not None and ex.stride()[i] >= 0
+            }
+            # iterate over unbound strides in sorted order
+            val_list = sorted(
+                [(ex.stride()[i], i) for i in range(len(stride)) if stride[i] is None]
+            )
+            for _, i in val_list:
+                if stride[i] is None and ex.stride()[i] in candidates:
+                    stride[i] = candidates[ex.stride()[i]]
+                    candidates[ex.size(i) * ex.stride()[i]] = size[i] * stride[i]
+            if any(x is None for x in stride):
+                # bind the smallest unbound stride to a new variable
+                val, i = sorted(
+                    [
+                        (ex.stride()[i], i)
+                        for i in range(len(stride))
+                        if stride[i] is None
+                    ]
+                )[0]
+                stride[i] = self.create_symbol(val)
+        assert all(x is not None for x in stride)
+        return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride]  # type: ignore[arg-type]
+
+    def create_symintnode(self, expr: Union["sympy.Expr", int]):
+        py_sym_int = PySymInt(expr, self)
+        cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int)  # type: ignore[attr-defined]
+        return cpp_sym_int
+
+    def create_symbol(self, val: int) -> "sympy.Expr":
         if not HAS_SYMPY:
             raise RuntimeError("Need sympy installed to create symbolic shapes")
-        # TODO: Put 0/1 specialization in guards
-        if val == 0 or val == 1:
-            return val
+        if val < 0:
+            # all sympy base variables must be positive and > 1
+            return -self.create_symbol(-val)

         # This implements duck-shaping: input sizes that match are assigned
         # the same symint
         # TODO: Create a guard whenever this happens
         # TODO: But how do I represent the guard in this case?
-        if val in self.val_to_symint:
-            return self.val_to_symint[val]
-        sympy_expr = sympy.Symbol(name, positive=True, integer=True)
-        py_sym_int = PySymInt(sympy_expr, self)
-        cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int)  # type: ignore[attr-defined]
+        # Note: val_to_var is also initialized with 0/1 mapping to constants, so
+        # this also ensures that all symbols are > 1
+        if val in self.val_to_var:
+            return self.val_to_var[val]
+        sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=True, integer=True)
         self.var_to_val[sympy_expr] = sympy.Integer(val)
-        self.val_to_symint[val] = cpp_sym_int
-        return cpp_sym_int
+        self.val_to_var[val] = sympy_expr
+        return sympy_expr

     def evaluate_guards_for_args(self, *args):
         new_env = ShapeEnv()
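
A worked example of the size/stride logic above (a sketch assuming sympy is installed): for a contiguous input, every stride is a product of sizes already seen, so duck-shaping plus the candidates table means no stride ever allocates a fresh symbol.

    import torch
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    env = ShapeEnv()
    t = torch.randn(2, 3, 4)              # contiguous, strides (12, 4, 1)
    sizes, strides = env.create_symbolic_sizes_strides(t)
    # sizes come from create_symbol: s0 for 2, s1 for 3, s2 for 4.
    # stride[2] == 1 is bound to the literal 1; then the candidates table
    # resolves stride[1] = 4 as s2 * 1 and stride[0] = 12 as s1 * s2,
    # so the result is roughly [s0, s1, s2] with strides [s1*s2, s2, 1].
    # Duck-shaping also dedupes: env.create_symbol(4) would return s2 again.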


@@ -475,7 +475,7 @@ class CodeGen(object):
             body.append('\n# No stacktrace found for following nodes\n')

         def stringify_shape(shape : torch.Size) -> str:
-            return f"[{','.join(str(x) for x in shape)}]"
+            return f"[{', '.join(str(x) for x in shape)}]"

         def emit_node(node : Node):
             maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'