Unified symbolic shape variables between AOTAutograd and Inductor (#86659)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86659
Approved by: https://github.com/wconstab

parent: c7c09722ad
commit: b3b9786fdd
@@ -206,7 +206,7 @@ const char* toString(DispatchKey t) {
switch (bc) { \
C10_FORALL_BACKEND_COMPONENTS(ENTRY, prefix) \
default: \
return #prefix "Unknown"; \
return #prefix "Undefined"; \
}

C10_FORALL_FUNCTIONALITY_KEYS(FORALL_BC)
@@ -278,6 +278,7 @@ class AOTConfig:
bw_compiler: Callable
partition_fn: Callable
decompositions: Dict[Callable, Callable]
num_params_buffers: int


def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):

@@ -491,6 +492,11 @@ def create_aot_dispatcher_function(

The resulting compiled forward and backward graphs are then wrapped up in a
``torch.autograd.Function`` object.

The calling convention here is that the first aot_config.num_params_buffers
inputs in flat_args are parameters and buffers, and the rest are inputs.

We use this to assume that parameters/buffer's shapes don't change.
"""

# This is the main entry point.
@@ -514,19 +520,26 @@ def create_aot_dispatcher_function(
# coordinate flags
config.use_fake_tensor = False

fake_mode = FakeTensorMode() if config.use_fake_tensor else nullcontext()
if config.use_dynamic_shapes:
assert config.use_fake_tensor, "Dynamic shapes only works with fake tensor"

shape_env = ShapeEnv() if config.use_dynamic_shapes else None
fake_mode = FakeTensorMode(shape_env=shape_env) if config.use_fake_tensor else nullcontext()
cross_ref = CrossRefFakeMode() if config.debug_fake_cross_ref else nullcontext()
python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext()
shape_env = ShapeEnv() if config.use_dynamic_shapes else None

with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:

def process_inputs(flat_args):
if config.use_fake_tensor:
def convert(x):
return fake_mode.from_tensor(x, shape_env=shape_env)
def convert(idx, x):
if not isinstance(x, torch.Tensor):
return x
if idx < aot_config.num_params_buffers and config.static_weight_shapes:
return fake_mode.from_tensor(x, static_shapes=True)
return fake_mode.from_tensor(x, static_shapes=False)

return pytree.tree_map_only(Tensor, convert, flat_args)
return [convert(idx, x) for idx, x in enumerate(flat_args)]
else:
return flat_args

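A rough usage sketch of the input-conversion convention above (not the diff itself; it assumes the FakeTensorMode/ShapeEnv API shown in this change, and the argument layout is a hypothetical example): the first num_params_buffers entries of flat_args keep static shapes, the remaining inputs get symbolic shapes drawn from one shared ShapeEnv.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)

num_params_buffers = 1                                # hypothetical: first arg is a weight
flat_args = [torch.randn(3, 4), torch.randn(8, 3)]    # weight, then a real input

def convert(idx, x):
    if not isinstance(x, torch.Tensor):
        return x
    # params/buffers keep concrete sizes; other inputs become symbolically shaped fakes
    return fake_mode.from_tensor(x, static_shapes=idx < num_params_buffers)

fake_args = [convert(idx, x) for idx, x in enumerate(flat_args)]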
@@ -587,6 +600,7 @@ def aot_function(
bw_compiler: Optional[Callable] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
num_params_buffers: int = 0,
hasher_type=None,  # deprecated
static_argnums: Optional[Tuple[int]] = None,  # deprecated
) -> Callable:

@@ -650,6 +664,7 @@ def aot_function(
bw_compiler=bw_compiler,
partition_fn=partition_fn,
decompositions=decompositions,
num_params_buffers=num_params_buffers,
)
cached_res = None

@@ -734,7 +749,10 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
params_and_buffers = {**named_params, **named_buffers}
return stateless.functional_call(mod, params_and_buffers, args, kwargs)

compiled_f = aot_function(functional_call, *args, **kwargs)
named_params = dict(_named_parameters(mod, remove_duplicate=False))
named_buffers = dict(_named_buffers(mod, remove_duplicate=False))
num_params_buffers = len(named_params) + len(named_buffers)
compiled_f = aot_function(functional_call, num_params_buffers=num_params_buffers, *args, **kwargs)

class AOTModule(nn.Module):
def __init__(self):

@@ -743,8 +761,8 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:

def forward(self, *args, **kwargs):
return compiled_f(
dict(_named_parameters(mod, remove_duplicate=False)),
dict(_named_buffers(mod, remove_duplicate=False)),
named_params,
named_buffers,
*args,
**kwargs,
)
@@ -812,6 +830,7 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
bw_compiler=bw_compiler,
partition_fn=partition_fn,
decompositions=decompositions,
num_params_buffers=params_len,
)

compiled_fn = None
@@ -23,3 +23,5 @@ debug_graphs = os.environ.get('AOT_FX_GRAPHS', False)
debug_joint = os.environ.get('AOT_FX_GRAPHS_JOINT', False)

use_dynamic_shapes = os.getenv('AOT_DYNAMIC_SHAPES', False)

static_weight_shapes = True
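A hypothetical way to exercise these flags (the module path is an assumption about where this config file lives; the flag names are the ones defined above):

import os
os.environ["AOT_DYNAMIC_SHAPES"] = "1"   # read at import time by the config module above

from functorch._src import config        # assumed location of this AOTAutograd config
config.use_fake_tensor = True            # dynamic shapes only work with fake tensors
config.static_weight_shapes = True       # params/buffers keep concrete sizes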
@@ -1060,6 +1060,7 @@ symbolic_aot_autograd_failures = {
xfail('nn.functional.avg_pool1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.avg_pool2d', ''),  # aten.avg_pool2d.default - couldn't find symbolic meta function/...
xfail('nn.functional.avg_pool3d', ''),  # aten.avg_pool3d.default - couldn't find symbolic meta function/...
skip('nn.functional.batch_norm', ''),  # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
xfail('nn.functional.bilinear', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.binary_cross_entropy', ''),  # aten.fill_.Scalar - couldn't find symbolic meta funct...
xfail('nn.functional.conv1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -113,13 +113,14 @@ class FakeSymbolicTensor(torch.Tensor):


def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.size())])
sym_strides = tuple([shape_env.create_symint(f"{name}_{idx}_stride", val) for idx, val in enumerate(arg.stride())])
sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)


CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))

def create_symint(shape_env, i):
return shape_env.create_symintnode(shape_env.create_symbol(i))

@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase):
@@ -128,8 +129,8 @@ class TestPySymInt(TestCase):
def test_arith_ops(self):
shape_env = ShapeEnv()
symints = []
for i in range(5):
symints.append((i, shape_env.create_symint(f"s{i}", i)))
for i in range(2, 5):
symints.append((i, create_symint(shape_env, i)))

ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]


@@ -143,10 +144,10 @@ class TestPySymInt(TestCase):
def test_reverse_arith_ops(self):
shape_env = ShapeEnv()

a = shape_env.create_symint("s1", 2)
a = create_symint(shape_env, 2)
self.assertTrue(5 // a == 5 // 2)

a = shape_env.create_symint("s1", 2)
a = create_symint(shape_env, 2)
self.assertTrue(5 * a == 5 * 2)


@@ -172,7 +173,7 @@ class TestPySymInt(TestCase):
self.assertTrue(x.size(2) == 3)
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))

offset = shape_env.create_symint("offset", 2)
offset = create_symint(shape_env, 2)
y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
self.assertTrue(y.storage_offset() == 2)

@@ -207,7 +208,7 @@ class TestPySymInt(TestCase):
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
LAST_DIM = 2
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == int(y.shape[2]))
self.assertTrue(z.shape[2] == y.shape[2])

# arithmetic expr with two symints
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
@@ -317,28 +318,28 @@ class TestPySymInt(TestCase):
@skipIfNoSympy
def test_meta_symint(self):
shape_env = ShapeEnv()
a0 = shape_env.create_symint("a0", 2)
a0 = create_symint(shape_env, 2)
r = torch.empty(a0, device='meta')
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)

@skipIfNoSympy
def test_guard_int(self):
shape_env = ShapeEnv()
a0 = shape_env.create_symint("a0", 2)
a0 = create_symint(shape_env, 2)
self.assertEqual(a0.guard_int(), 2)
self.assertEqual(str(shape_env.guards[0][0]), "a0")
self.assertEqual(str(shape_env.guards[0][0]), "s0")
self.assertEqual(shape_env.guards[0][1], 2)

@skipIfNoSympy
def test_int_conversion(self):
shape_env = ShapeEnv()
a0 = shape_env.create_symint("a0", 2)
a0 = create_symint(shape_env, 2)
self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))

@skipIfNoSympy
def test_symint_as_scalar(self):
shape_env = ShapeEnv()
a0 = shape_env.create_symint("a0", 2)
a0 = create_symint(shape_env, 2)

sym_int_encountered = False


@@ -372,18 +373,18 @@ class TestPySymInt(TestCase):

self.assertExpectedInline(mock_stdout.getvalue().strip(), """\
class f(torch.nn.Module):
def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]):
def forward(self, a_1: f32[s0, s1], b_1: f32[s2, s1]):
# No stacktrace found for following nodes
sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None
sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
sym_size: Sym(s0) = torch.ops.aten.sym_size(a_1, 0)
sym_size_1: Sym(s2) = torch.ops.aten.sym_size(b_1, 0)
add: Sym(s0 + s2) = sym_size + sym_size_1; sym_size = sym_size_1 = None
sym_size_2: Sym(s1) = torch.ops.aten.sym_size(a_1, 1)
sym_size_3: Sym(s1) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
add_1: Sym(2*s1) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
new_empty: f32[s0 + s2, 2*s1] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0]
getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None
getitem: f32[s0 + s2, 2*s1] = native_dropout[0]
getitem_1: b8[s0 + s2, 2*s1] = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""")  # noqa: B950

test/test_dynamic_shapes.py.bak (new file, 391 lines)
@@ -0,0 +1,391 @@
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: jit"]

from torch._C import _disabled_torch_function_impl
import torch.fx
import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo
import unittest
import torch
import operator
import itertools
import io
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch.ops.aten

try:
import sympy
HAS_SYMPY = True
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")


meta_funcs = {}


def register_meta(op):
def decorator(f):
def add_func(op):
meta_funcs[op] = f
tree_map(add_func, op)
return f
return decorator


@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
return a.new_empty(a.shape)


@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
concat_length = 0
shape = tensors[0].shape
for tensor in tensors:
for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
if idx == dim:
concat_length = concat_length + length
else:
assert length == common_length
new_shape = list(shape)
new_shape[dim] = concat_length
return tensors[0].new_empty(new_shape)


@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
shape = []
for i, x in enumerate(a.shape):
if i == dim:
shape.append(length)
else:
shape.append(x)
return a.new_empty(tuple(shape))


@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
return a.new_empty(size)


def create_contiguous(shape):
strides = [1]
for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1])
return list(reversed(strides))


class FakeSymbolicTensor(torch.Tensor):
@staticmethod
def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
# TODO: this is wrong in general
sym_stride = create_contiguous(sym_shape)
r = torch.Tensor._make_wrapper_subclass(
cls, sym_shape,
sym_stride, storage_offset,
dtype=dtype, layout=layout, requires_grad=requires_grad,
device=device,
)
return r

__torch_function__ = _disabled_torch_function_impl

def new_empty(self, shape):
return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)

@classmethod
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
if func_overload in meta_funcs:
return meta_funcs[func_overload](*args, **kwargs)

if func_overload == torch.ops.aten.new_empty.default:
self = args[0]
shape = args[1]
return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)

raise RuntimeError(f"operator {func_overload} not supported")


def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)


CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))

def create_symint(shape_env, i):
return shape_env.create_symintnode(shape_env.create_symbol(i))

@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase):

@skipIfNoSympy
def test_arith_ops(self):
shape_env = ShapeEnv()
symints = []
for i in range(2, 5):
symints.append((i, create_symint(shape_env, i)))

ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]

for op in ops:
for args in itertools.permutations(symints, 2):
if not isinstance(args[0][1], int) and ((op != operator.mod or op != operator.floordiv) and args[1][0] != 0):
self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]))


@skipIfNoSympy
def test_reverse_arith_ops(self):
shape_env = ShapeEnv()

a = create_symint(shape_env, 2)
self.assertTrue(5 // a == 5 // 2)

a = create_symint(shape_env, 2)
self.assertTrue(5 * a == 5 * 2)


@skipIfNoSympy
def test_roundtrip(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

self.assertTrue(not isinstance(x.shape[0], PySymInt))
self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))

self.assertTrue(x.shape[0] == 5)
self.assertTrue(x.shape[1] == 4)
self.assertTrue(x.shape[2], 3)

self.assertTrue(x.size()[0], 5)
self.assertTrue(x.size()[1], 4)
self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
self.assertTrue(x.size()[2] == 3)

self.assertTrue(x.size(0) == 5)
self.assertTrue(x.size(1) == 4)
self.assertTrue(x.size(2) == 3)
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))

offset = create_symint(shape_env, 2)
y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
self.assertTrue(y.storage_offset() == 2)

offset = 2
z = create_symbolic_tensor("z", torch.randn(5, 4, 3), shape_env, offset)
self.assertTrue(isinstance(z.storage_offset(), int))
self.assertTrue(z.storage_offset() == 2)

@skipIfNoSympy
def test_binary(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

z = x + y
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)

# broadcasting
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
z = x + y
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)

@skipIfNoSympy
def test_symint_args(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
LAST_DIM = 2
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == y.shape[2])

# arithmetic expr with two symints
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == 2)

# arithmetic expr with a symint and python int
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
self.assertTrue(z.shape[2] == 2)

@skipIfNoSympy
def test_symint_vargs(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

# varargs
z = y.expand(x.shape[0], y.shape[1], x.shape[2])
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)

# shape list
z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)

# mixed python symints and ints
z = y.expand(x.shape[0], y.shape[1], 3)
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)

# mixed python symints and ints in a list
z = y.expand((x.shape[0], y.shape[1], 3))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)

# mixed python symints and ints
z = y.expand(5, y.shape[1], x.shape[2])
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)

# mixed python ints and symints in a list
z = y.expand((5, y.shape[1], x.shape[2]))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)

z = y.expand((y.shape[1],))
z = y.expand(y.shape[1])

@skipIfNoSympy
def test_stride(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS)

@skipIfNoSympy
def test_size_expressions(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
expand_x = x.expand(x.shape[0], x.shape[0])
if expand_x.shape[0] > 3:
result = expand_x + expand_x
else:
result = expand_x + expand_x

gt_op = shape_env.guards[0][0]
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))

@skipIfNoSympy
def test_int_to_float(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
r = sym_float(x.shape[0])
self.assertTrue(isinstance(r, torch.SymFloatNode))

@skipIfNoSympy
def test_aten_ops(self):

shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])

shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])

def test_fx_trace_intlist(self):
class CustomModule(torch.nn.Module):
def forward(self, x):
bs, c, h, w = x.shape
return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))

m = CustomModule()
x = torch.rand(1, 3, 4, 4)
# should not TypeError: pad(): argument 'pad' (position 2) must be
# tuple of ints, not tuple
torch.fx.symbolic_trace(m)

@skipIfNoSympy
def test_meta_symint(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
r = torch.empty(a0, device='meta')
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)

@skipIfNoSympy
def test_guard_int(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertEqual(a0.guard_int(), 2)
self.assertEqual(str(shape_env.guards[0][0]), "s0")
self.assertEqual(shape_env.guards[0][1], 2)

@skipIfNoSympy
def test_int_conversion(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))

@skipIfNoSympy
def test_symint_as_scalar(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)

sym_int_encountered = False

class TestSymInt(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
assert func == torch.ops.aten.add.Tensor

nonlocal sym_int_encountered
sym_int_encountered = kwargs["alpha"] is a0
kwargs["alpha"] = 0
return func(*args)

x = torch.rand([4, 4])
with TestSymInt():
y = torch.add(x, x, alpha=a0)

self.assertTrue(sym_int_encountered)

@skipIfNoSympy
@unittest.mock.patch('sys.stdout', new_callable=io.StringIO)
def test_print_readable_with_symints(self, mock_stdout):
def f(a, b):
dim0 = a.shape[0] + b.shape[0]
dim1 = a.shape[1] + b.shape[1]
d = a.new_empty(dim0, dim1)
d = torch.ops.aten.native_dropout(d, 0.5, train=True)
return d

fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
fx_g.print_readable()

self.assertExpectedInline(mock_stdout.getvalue().strip(), """\
class f(torch.nn.Module):
def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]):
# No stacktrace found for following nodes
sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None
sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0]
getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""")  # noqa: B950


if __name__ == '__main__':
run_tests()

@@ -1065,6 +1065,7 @@ symbolic_tensor_failures = {
xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
xfail('chunk', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('column_stack', ''),  # Tensors of type TensorImpl do not have numel
xfail('combinations', ''),
xfail('count_nonzero', ''),  # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba...
xfail('cross', ''),  # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
xfail('cummax', ''),  # aten.cummax.default - couldn't find symbolic meta function/decomposition

@@ -1290,6 +1291,7 @@ symbolic_tensor_failures = {
xfail('unbind', ''),  # aten.unbind.int - couldn't find symbolic meta function/decomposition
}
symbolic_tensor_segfaults = {
skip('nn.functional.batch_norm')  # Segfault??
}

symbolic_tensor_failures.update(symbolic_tensor_segfaults)
@@ -622,6 +622,7 @@ class FakeTensorMode(TorchDispatchMode):
allow_fallback_kernels=True,
allow_meta=False,
throw_on_data_dependent_ops=True,
shape_env=None,
):
self.allow_fallback_kernels = allow_fallback_kernels
self.fake_tensor_converter = FakeTensorConverter()

@@ -642,6 +643,8 @@ class FakeTensorMode(TorchDispatchMode):
# the device property
self.in_kernel_invocation = False

self.shape_env = shape_env

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}
@@ -920,8 +923,10 @@ class FakeTensorMode(TorchDispatchMode):
):
self.fake_tensor_converter.invalidate_constant_aliases(v.constant)

def from_tensor(self, tensor, shape_env=None):
return self.fake_tensor_converter(self, tensor, shape_env=shape_env)
def from_tensor(self, tensor, static_shapes=False):
if static_shapes:
return self.fake_tensor_converter(self, tensor)
return self.fake_tensor_converter(self, tensor, shape_env=self.shape_env)


# NB: returns fake tensors
@@ -142,39 +142,18 @@ class MetaConverter:
arg_cnt = self.arg_cnt
self.arg_cnt += 1

# Don't make parameters have symbolic shapes; they are assumed to stay
# constant size across training runs
make_symbolic = shape_env is not None and not isinstance(t, torch.nn.Parameter)
make_symbolic = shape_env is not None

def sym(name, x):
def sym(x):
if make_symbolic:
return shape_env.create_symint(f"t{arg_cnt}.{name}()", x)
return shape_env.create_symbol(x)
else:
return x

def sym_list(name, xs):
def sym_sizes_strides(t):
if make_symbolic:
return [
shape_env.create_symint(f"t{arg_cnt}.{name}({i})", x)
for i, x in enumerate(xs)
]
else:
return xs

def sym_size(t):
return sym_list("size", t.size())

def sym_stride(t):
return sym_list("stride", t.stride())

# NB: Although sym_stride variables initially have no correlation
# with size, we will immediately introduce guards based on contiguity.
# Thus, if the input tensor is contiguous, the stride variables
# will typically immediately get reexpressed in terms of the size
# variables.

def sym_storage_offset(t):
return sym("storage_offset", t.storage_offset())
return shape_env.create_symbolic_sizes_strides(t)
return (t.size(), t.stride())

# see expired-storages
self.check_expired_count += 1
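The note above says stride variables are immediately re-expressed in terms of the size variables via contiguity guards. A small sketch of what that looks like with the new ShapeEnv API (the printed names are illustrative):

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
t = torch.randn(5, 4, 3)                            # contiguous: stride == (12, 3, 1)
sizes, strides = shape_env.create_symbolic_sizes_strides(t)
print(sizes)    # e.g. [s0, s1, s2]
print(strides)  # e.g. [s1*s2, s2, 1] -- strides reuse the size symbols, no new variables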
@@ -231,9 +210,8 @@ class MetaConverter:
base = base.view(t.dtype)

with torch.enable_grad():
r = base.as_strided(
sym_size(t), sym_stride(t), sym_storage_offset(t)
)
sizes, strides = sym_sizes_strides(t)
r = base.as_strided(sizes, strides, sym(t.storage_offset()))
else:
is_leaf = safe_is_leaf(t)
# Fake up some autograd history.

@@ -257,8 +235,9 @@ class MetaConverter:
# meta storage
s = self.meta_storage(t.storage())
with no_dispatch():
sizes, strides = sym_sizes_strides(t)
with torch.no_grad():
r.set_(s, sym_storage_offset(t), sym_size(t), sym_stride(t))
r.set_(s, sym(t.storage_offset()), sizes, strides)

torch._C._set_conj(r, t.is_conj())
torch._C._set_neg(r, t.is_neg())
@@ -612,7 +612,8 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
elif tracing_mode == "fake":
fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
elif tracing_mode == "symbolic":
fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
shape_env = ShapeEnv()
fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env)
else:
raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

@@ -628,15 +629,12 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):

return x

shape_env = None
if tracing_mode == "symbolic":
shape_env = ShapeEnv()
sym_mode = proxy_mode.sym_mode

# todo: Figure out a more informative name for symints
def wrap_fake_symbolic(x):
if isinstance(x, torch.Tensor):
return fake_tensor_mode.from_tensor(x, shape_env=shape_env)
return fake_tensor_mode.from_tensor(x)
return x

wrap_fn_map = {
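A minimal end-to-end sketch of the symbolic tracing mode wired up here, mirroring the test shown earlier in this diff:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b):
    return a.new_empty(a.shape[0] + b.shape[0], a.shape[1] + b.shape[1])

# "symbolic" mode now builds one ShapeEnv-backed FakeTensorMode for the whole trace,
# so the printed graph annotates the inputs as f32[s0, s1] and f32[s2, s1].
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
gm.print_readable()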
@@ -1,6 +1,6 @@
import torch
import torch.utils._pytree as pytree
from typing import Set, Dict, List, Type, Optional, cast
from typing import Set, Dict, List, Type, Optional, cast, Union
import operator
import math
import functools

@@ -331,7 +331,7 @@ class ShapeEnv(object):
self.divisible: Set["sympy.Expr"] = set()
# Duck-shaping says that if two input tensors have the same size,
# they get assigned the same symbolic variable
self.val_to_symint: Dict[int, torch.SymIntNode] = {}
self.val_to_var: Dict[int, "sympy.Expr"] = {0: sympy.Integer(0), 1: sympy.Integer(1)}

def _get_key(self):
"""
@@ -340,28 +340,68 @@ class ShapeEnv(object):
"""
return (len(self.replacements), len(self.divisible))

# NB: This is only called for input symbolic sizes; intermediate symbolic
# sizes are allocated via a different mechanism
def create_symint(self, name, val):
assert val >= 0
def create_symbolic_sizes_strides(self, ex: torch.Tensor):
"""
Returns a list of symbolic sizes and strides for the given tensor.
We try our best to express stride in terms of the sizes, so as to not
introduce new symbolic variables.
"""

size = [self.create_symbol(i) for i in ex.size()]
stride: List[Optional[sympy.Expr]] = [None] * len(size)
for i, val in enumerate(ex.stride()):
if val in (0, 1):
stride[i] = sympy.Integer(val)
while any(x is None for x in stride):
candidates = {
ex.size(i) * ex.stride()[i]: size[i] * stride[i]
for i in range(len(size))
if stride[i] is not None and ex.stride()[i] >= 0
}
# iterate over unbound strides in sorted order
val_list = sorted(
[(ex.stride()[i], i) for i in range(len(stride)) if stride[i] is None]
)
for _, i in val_list:
if stride[i] is None and ex.stride()[i] in candidates:
stride[i] = candidates[ex.stride()[i]]
candidates[ex.size(i) * ex.stride()[i]] = size[i] * stride[i]
if any(x is None for x in stride):
# bind the smallest unbound stride to a new variable
val, i = sorted(
[
(ex.stride()[i], i)
for i in range(len(stride))
if stride[i] is None
]
)[0]
stride[i] = self.create_symbol(val)
assert all(x is not None for x in stride)
return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride]  # type: ignore[arg-type]

def create_symintnode(self, expr: Union["sympy.Expr", int]):
py_sym_int = PySymInt(expr, self)
cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int)  # type: ignore[attr-defined]
return cpp_sym_int

def create_symbol(self, val: int) -> "sympy.Expr":
if not HAS_SYMPY:
raise RuntimeError("Need sympy installed to create symbolic shapes")

# TODO: Put 0/1 specialization in guards
if val == 0 or val == 1:
return val
if val < 0:
# all sympy base variables must be positive and > 1
return -self.create_symbol(-val)
# This implements duck-shaping: input sizes that match are assigned
# the same symint
# TODO: Create a guard whenever this happens
# TODO: But how do I represent the guard in this case?
if val in self.val_to_symint:
return self.val_to_symint[val]
sympy_expr = sympy.Symbol(name, positive=True, integer=True)
py_sym_int = PySymInt(sympy_expr, self)
cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int)  # type: ignore[attr-defined]
# Note: val_to_var is also initialized with 0/1 mapping to constants, so
# this also ensures that all symbols are > 1
if val in self.val_to_var:
return self.val_to_var[val]
sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=True, integer=True)
self.var_to_val[sympy_expr] = sympy.Integer(val)
self.val_to_symint[val] = cpp_sym_int
return cpp_sym_int
self.val_to_var[val] = sympy_expr
return sympy_expr

def evaluate_guards_for_args(self, *args):
new_env = ShapeEnv()
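A short sketch of the duck-shaping behaviour implemented by create_symbol above (assumes sympy is installed): equal input sizes map to a single shared symbol, and 0/1 stay as plain Python ints.

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
s5_a = shape_env.create_symbol(5)
s5_b = shape_env.create_symbol(5)
assert s5_a == s5_b                      # duck shaping: the value 5 maps to one symbol (e.g. s0)
assert shape_env.create_symbol(1) == 1   # 0/1 are specialized to constants, no symbol allocated
a0 = shape_env.create_symintnode(s5_a)   # wrap the sympy expression as a SymIntNode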
@@ -475,7 +475,7 @@ class CodeGen(object):
body.append('\n# No stacktrace found for following nodes\n')

def stringify_shape(shape : torch.Size) -> str:
return f"[{','.join(str(x) for x in shape)}]"
return f"[{', '.join(str(x) for x in shape)}]"

def emit_node(node : Node):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'