diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp
index 775e47ca99d..bffafc59168 100644
--- a/c10/core/DispatchKey.cpp
+++ b/c10/core/DispatchKey.cpp
@@ -206,7 +206,7 @@ const char* toString(DispatchKey t) {
     switch (bc) {                                  \
       C10_FORALL_BACKEND_COMPONENTS(ENTRY, prefix) \
       default:                                     \
-        return #prefix "Unknown";                  \
+        return #prefix "Undefined";                \
     }

   C10_FORALL_FUNCTIONALITY_KEYS(FORALL_BC)
diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 5294399dbca..de8a00c68f6 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -278,6 +278,7 @@ class AOTConfig:
     bw_compiler: Callable
     partition_fn: Callable
     decompositions: Dict[Callable, Callable]
+    num_params_buffers: int


 def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
@@ -491,6 +492,11 @@ def create_aot_dispatcher_function(
     The resulting compiled forward and backward graphs are then wrapped up in a
     ``torch.autograd.Function`` object.
+
+    The calling convention here is that the first aot_config.num_params_buffers
+    inputs in flat_args are parameters and buffers, and the rest are inputs.
+
+    We use this to assume that parameters/buffer's shapes don't change.
     """

     # This is the main entry point.
@@ -514,19 +520,26 @@ def create_aot_dispatcher_function(
     # coordinate flags
     config.use_fake_tensor = False

-    fake_mode = FakeTensorMode() if config.use_fake_tensor else nullcontext()
+    if config.use_dynamic_shapes:
+        assert config.use_fake_tensor, "Dynamic shapes only works with fake tensor"
+
+    shape_env = ShapeEnv() if config.use_dynamic_shapes else None
+    fake_mode = FakeTensorMode(shape_env=shape_env) if config.use_fake_tensor else nullcontext()
     cross_ref = CrossRefFakeMode() if config.debug_fake_cross_ref else nullcontext()
     python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext()
-    shape_env = ShapeEnv() if config.use_dynamic_shapes else None

     with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:

         def process_inputs(flat_args):
             if config.use_fake_tensor:
-                def convert(x):
-                    return fake_mode.from_tensor(x, shape_env=shape_env)
+                def convert(idx, x):
+                    if not isinstance(x, torch.Tensor):
+                        return x
+                    if idx < aot_config.num_params_buffers and config.static_weight_shapes:
+                        return fake_mode.from_tensor(x, static_shapes=True)
+                    return fake_mode.from_tensor(x, static_shapes=False)

-                return pytree.tree_map_only(Tensor, convert, flat_args)
+                return [convert(idx, x) for idx, x in enumerate(flat_args)]
             else:
                 return flat_args
@@ -587,6 +600,7 @@ def aot_function(
     bw_compiler: Optional[Callable] = None,
     partition_fn: Callable = default_partition,
     decompositions: Optional[Dict] = None,
+    num_params_buffers: int = 0,
     hasher_type=None,  # deprecated
     static_argnums: Optional[Tuple[int]] = None,  # deprecated
 ) -> Callable:
@@ -650,6 +664,7 @@ def aot_function(
         bw_compiler=bw_compiler,
         partition_fn=partition_fn,
         decompositions=decompositions,
+        num_params_buffers=num_params_buffers,
     )

     cached_res = None
@@ -734,7 +749,10 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
         params_and_buffers = {**named_params, **named_buffers}
         return stateless.functional_call(mod, params_and_buffers, args, kwargs)

-    compiled_f = aot_function(functional_call, *args, **kwargs)
+    named_params = dict(_named_parameters(mod, remove_duplicate=False))
+    named_buffers = dict(_named_buffers(mod, remove_duplicate=False))
+    num_params_buffers = len(named_params) + len(named_buffers)
+    compiled_f = aot_function(functional_call, num_params_buffers=num_params_buffers, *args, **kwargs)

     class AOTModule(nn.Module):
         def __init__(self):
@@ -743,8 +761,8 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:

         def forward(self, *args, **kwargs):
             return compiled_f(
-                dict(_named_parameters(mod, remove_duplicate=False)),
-                dict(_named_buffers(mod, remove_duplicate=False)),
+                named_params,
+                named_buffers,
                 *args,
                 **kwargs,
             )
@@ -812,6 +830,7 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
         bw_compiler=bw_compiler,
         partition_fn=partition_fn,
         decompositions=decompositions,
+        num_params_buffers=params_len,
     )

     compiled_fn = None
diff --git a/functorch/_src/config.py b/functorch/_src/config.py
index e473d1129ea..2dacdd38fa3 100644
--- a/functorch/_src/config.py
+++ b/functorch/_src/config.py
@@ -23,3 +23,5 @@ debug_graphs = os.environ.get('AOT_FX_GRAPHS', False)
 debug_joint = os.environ.get('AOT_FX_GRAPHS_JOINT', False)

 use_dynamic_shapes = os.getenv('AOT_DYNAMIC_SHAPES', False)
+
+static_weight_shapes = True
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 8d1c0dba701..df25f90e55f 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -1060,6 +1060,7 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.avg_pool1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.avg_pool2d', ''),  # aten.avg_pool2d.default - couldn't find symbolic meta function/...
     xfail('nn.functional.avg_pool3d', ''),  # aten.avg_pool3d.default - couldn't find symbolic meta function/...
+    skip('nn.functional.batch_norm', ''),  # '0 is not tracked with proxy for
+        if expand_x.shape[0] > 3:
+            result = expand_x + expand_x
+        else:
+            result = expand_x + expand_x
+
+        gt_op = shape_env.guards[0][0]
+        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
+        self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
+        self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
+        self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
+
+    @skipIfNoSympy
+    def test_int_to_float(self):
+        shape_env = ShapeEnv()
+        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
+        r = sym_float(x.shape[0])
+        self.assertTrue(isinstance(r, torch.SymFloatNode))
+
+    @skipIfNoSympy
+    def test_aten_ops(self):
+
+        shape_env = ShapeEnv()
+        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
+        torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
+
+        shape_env = ShapeEnv()
+        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
+        torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
+
+    def test_fx_trace_intlist(self):
+        class CustomModule(torch.nn.Module):
+            def forward(self, x):
+                bs, c, h, w = x.shape
+                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
+
+        m = CustomModule()
+        x = torch.rand(1, 3, 4, 4)
+        # should not TypeError: pad(): argument 'pad' (position 2) must be
+        # tuple of ints, not tuple
+        torch.fx.symbolic_trace(m)
+
+    @skipIfNoSympy
+    def test_meta_symint(self):
+        shape_env = ShapeEnv()
+        a0 = create_symint(shape_env, 2)
+        r = torch.empty(a0, device='meta')
+        self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
+
+    @skipIfNoSympy
+    def test_guard_int(self):
+        shape_env = ShapeEnv()
+        a0 = create_symint(shape_env, 2)
+        self.assertEqual(a0.guard_int(), 2)
+        self.assertEqual(str(shape_env.guards[0][0]), "s0")
+        self.assertEqual(shape_env.guards[0][1], 2)
+
+    @skipIfNoSympy
+    def test_int_conversion(self):
+        shape_env = ShapeEnv()
+        a0 = create_symint(shape_env, 2)
+        self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))
+
+    @skipIfNoSympy
+    def test_symint_as_scalar(self):
+        shape_env = ShapeEnv()
+        a0 = create_symint(shape_env, 2)
+
+        sym_int_encountered = False
+
+        class TestSymInt(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                assert func == torch.ops.aten.add.Tensor
+
+                nonlocal sym_int_encountered
+                sym_int_encountered = kwargs["alpha"] is a0
+                kwargs["alpha"] = 0
+                return func(*args)
+
+        x = torch.rand([4, 4])
+        with TestSymInt():
+            y = torch.add(x, x, alpha=a0)
+
+        self.assertTrue(sym_int_encountered)
+
+    @skipIfNoSympy
+    @unittest.mock.patch('sys.stdout', new_callable=io.StringIO)
+    def test_print_readable_with_symints(self, mock_stdout):
+        def f(a, b):
+            dim0 = a.shape[0] + b.shape[0]
+            dim1 = a.shape[1] + b.shape[1]
+            d = a.new_empty(dim0, dim1)
+            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
+            return d
+
+        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
+        fx_g.print_readable()
+
+        self.assertExpectedInline(mock_stdout.getvalue().strip(), """\
+class f(torch.nn.Module):
+    def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]):
+        # No stacktrace found for following nodes
+        sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
+        sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
+        add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None
+        sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
+        sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
+        add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
+        new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
+        native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
+        getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0]
+        getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None
+        return (getitem, getitem_1)""")  # noqa: B950
+
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index d736a2c453a..4544171f6ef 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1065,6 +1065,7 @@ symbolic_tensor_failures = {
     xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
     xfail('chunk', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('column_stack', ''),  # Tensors of type TensorImpl do not have numel
+    xfail('combinations', ''),
     xfail('count_nonzero', ''),  # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba...
     xfail('cross', ''),  # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
     xfail('cummax', ''),  # aten.cummax.default - couldn't find symbolic meta function/decomposition
@@ -1290,6 +1291,7 @@ symbolic_tensor_failures = {
     xfail('unbind', ''),  # aten.unbind.int - couldn't find symbolic meta function/decomposition
 }
 symbolic_tensor_segfaults = {
+    skip('nn.functional.batch_norm')  # Segfault??
 }

 symbolic_tensor_failures.update(symbolic_tensor_segfaults)
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index b9e1c867612..e5f70da1f19 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -622,6 +622,7 @@ class FakeTensorMode(TorchDispatchMode):
         allow_fallback_kernels=True,
         allow_meta=False,
         throw_on_data_dependent_ops=True,
+        shape_env=None,
    ):
         self.allow_fallback_kernels = allow_fallback_kernels
         self.fake_tensor_converter = FakeTensorConverter()
@@ -642,6 +643,8 @@ class FakeTensorMode(TorchDispatchMode):
         # the device property
         self.in_kernel_invocation = False

+        self.shape_env = shape_env
+
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}

@@ -920,8 +923,10 @@ class FakeTensorMode(TorchDispatchMode):
             ):
                 self.fake_tensor_converter.invalidate_constant_aliases(v.constant)

-    def from_tensor(self, tensor, shape_env=None):
-        return self.fake_tensor_converter(self, tensor, shape_env=shape_env)
+    def from_tensor(self, tensor, static_shapes=False):
+        if static_shapes:
+            return self.fake_tensor_converter(self, tensor)
+        return self.fake_tensor_converter(self, tensor, shape_env=self.shape_env)


 # NB: returns fake tensors
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 9d641ba458e..80723f12463 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -142,39 +142,18 @@ class MetaConverter:
         arg_cnt = self.arg_cnt
         self.arg_cnt += 1

-        # Don't make parameters have symbolic shapes; they are assumed to stay
-        # constant size across training runs
-        make_symbolic = shape_env is not None and not isinstance(t, torch.nn.Parameter)
+        make_symbolic = shape_env is not None

-        def sym(name, x):
+        def sym(x):
             if make_symbolic:
-                return shape_env.create_symint(f"t{arg_cnt}.{name}()", x)
+                return shape_env.create_symbol(x)
             else:
                 return x

-        def sym_list(name, xs):
+        def sym_sizes_strides(t):
             if make_symbolic:
-                return [
-                    shape_env.create_symint(f"t{arg_cnt}.{name}({i})", x)
-                    for i, x in enumerate(xs)
-                ]
-            else:
-                return xs
-
-        def sym_size(t):
-            return sym_list("size", t.size())
-
-        def sym_stride(t):
-            return sym_list("stride", t.stride())
-
-        # NB: Although sym_stride variables initially have no correlation
-        # with size, we will immediately introduce guards based on contiguity.
-        # Thus, if the input tensor is contiguous, the stride variables
-        # will typically immediately get reexpressed in terms of the size
-        # variables.
-
-        def sym_storage_offset(t):
-            return sym("storage_offset", t.storage_offset())
+                return shape_env.create_symbolic_sizes_strides(t)
+            return (t.size(), t.stride())

         # see expired-storages
         self.check_expired_count += 1
@@ -231,9 +210,8 @@ class MetaConverter:
                             base = base.view(t.dtype)
                         with torch.enable_grad():
-                            r = base.as_strided(
-                                sym_size(t), sym_stride(t), sym_storage_offset(t)
-                            )
+                            sizes, strides = sym_sizes_strides(t)
+                            r = base.as_strided(sizes, strides, sym(t.storage_offset()))
                 else:
                     is_leaf = safe_is_leaf(t)
                     # Fake up some autograd history.
@@ -257,8 +235,9 @@ class MetaConverter:
                 # meta storage
                 s = self.meta_storage(t.storage())
                 with no_dispatch():
+                    sizes, strides = sym_sizes_strides(t)
                     with torch.no_grad():
-                        r.set_(s, sym_storage_offset(t), sym_size(t), sym_stride(t))
+                        r.set_(s, sym(t.storage_offset()), sizes, strides)

         torch._C._set_conj(r, t.is_conj())
         torch._C._set_neg(r, t.is_neg())
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index de2cdaeb5b6..f623c4cfd65 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -612,7 +612,8 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
     elif tracing_mode == "fake":
         fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
     elif tracing_mode == "symbolic":
-        fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
+        shape_env = ShapeEnv()
+        fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env)
     else:
         raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

@@ -628,15 +629,12 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
         return x

-    shape_env = None
-    if tracing_mode == "symbolic":
-        shape_env = ShapeEnv()
     sym_mode = proxy_mode.sym_mode

     # todo: Figure out a more informative name for symints
     def wrap_fake_symbolic(x):
         if isinstance(x, torch.Tensor):
-            return fake_tensor_mode.from_tensor(x, shape_env=shape_env)
+            return fake_tensor_mode.from_tensor(x)
         return x

     wrap_fn_map = {
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 110831ddb9c..b42cfcfb109 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1,6 +1,6 @@
 import torch
 import torch.utils._pytree as pytree
-from typing import Set, Dict, List, Type, Optional, cast
+from typing import Set, Dict, List, Type, Optional, cast, Union
 import operator
 import math
 import functools
@@ -331,7 +331,7 @@ class ShapeEnv(object):
         self.divisible: Set["sympy.Expr"] = set()
         # Duck-shaping says that if two input tensors have the same size,
         # they get assigned the same symbolic variable
-        self.val_to_symint: Dict[int, torch.SymIntNode] = {}
+        self.val_to_var: Dict[int, "sympy.Expr"] = {0: sympy.Integer(0), 1: sympy.Integer(1)}

     def _get_key(self):
         """
@@ -340,28 +340,68 @@ class ShapeEnv(object):
         """
         return (len(self.replacements), len(self.divisible))

-    # NB: This is only called for input symbolic sizes; intermediate symbolic
-    # sizes are allocated via a different mechanism
-    def create_symint(self, name, val):
-        assert val >= 0
+    def create_symbolic_sizes_strides(self, ex: torch.Tensor):
+        """
+        Returns a list of symbolic sizes and strides for the given tensor.
+        We try our best to express stride in terms of the sizes, so as to not
+        introduce new symbolic variables.
+        """
+
+        size = [self.create_symbol(i) for i in ex.size()]
+        stride: List[Optional[sympy.Expr]] = [None] * len(size)
+        for i, val in enumerate(ex.stride()):
+            if val in (0, 1):
+                stride[i] = sympy.Integer(val)
+        while any(x is None for x in stride):
+            candidates = {
+                ex.size(i) * ex.stride()[i]: size[i] * stride[i]
+                for i in range(len(size))
+                if stride[i] is not None and ex.stride()[i] >= 0
+            }
+            # iterate over unbound strides in sorted order
+            val_list = sorted(
+                [(ex.stride()[i], i) for i in range(len(stride)) if stride[i] is None]
+            )
+            for _, i in val_list:
+                if stride[i] is None and ex.stride()[i] in candidates:
+                    stride[i] = candidates[ex.stride()[i]]
+                    candidates[ex.size(i) * ex.stride()[i]] = size[i] * stride[i]
+            if any(x is None for x in stride):
+                # bind the smallest unbound stride to a new variable
+                val, i = sorted(
+                    [
+                        (ex.stride()[i], i)
+                        for i in range(len(stride))
+                        if stride[i] is None
+                    ]
+                )[0]
+                stride[i] = self.create_symbol(val)
+        assert all(x is not None for x in stride)
+        return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride]  # type: ignore[arg-type]
+
+    def create_symintnode(self, expr: Union["sympy.Expr", int]):
+        py_sym_int = PySymInt(expr, self)
+        cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int)  # type: ignore[attr-defined]
+        return cpp_sym_int
+
+    def create_symbol(self, val: int) -> "sympy.Expr":
         if not HAS_SYMPY:
             raise RuntimeError("Need sympy installed to create symbolic shapes")
-
-        # TODO: Put 0/1 specialization in guards
-        if val == 0 or val == 1:
-            return val
+        if val < 0:
+            # all sympy base variables must be positive and > 1
+            return -self.create_symbol(-val)

         # This implements duck-shaping: input sizes that match are assigned
         # the same symint
         # TODO: Create a guard whenever this happens
         # TODO: But how do I represent the guard in this case?
-        if val in self.val_to_symint:
-            return self.val_to_symint[val]
-        sympy_expr = sympy.Symbol(name, positive=True, integer=True)
-        py_sym_int = PySymInt(sympy_expr, self)
-        cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int)  # type: ignore[attr-defined]
+        # Note: val_to_var is also initialized with 0/1 mapping to constants, so
+        # this also ensures that all symbols are > 1
+        if val in self.val_to_var:
+            return self.val_to_var[val]
+        sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=True, integer=True)
         self.var_to_val[sympy_expr] = sympy.Integer(val)
-        self.val_to_symint[val] = cpp_sym_int
-        return cpp_sym_int
+        self.val_to_var[val] = sympy_expr
+        return sympy_expr

     def evaluate_guards_for_args(self, *args):
         new_env = ShapeEnv()
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index b5e21f1c3de..4aa6897196d 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -475,7 +475,7 @@ class CodeGen(object):
                 body.append('\n# No stacktrace found for following nodes\n')

         def stringify_shape(shape : torch.Size) -> str:
-            return f"[{','.join(str(x) for x in shape)}]"
+            return f"[{', '.join(str(x) for x in shape)}]"

         def emit_node(node : Node):
             maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
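
The ShapeEnv changes above combine two ideas: `create_symbol` duck-shapes (every distinct input size value gets one sympy symbol, with 0 and 1 kept as constants, so equal dimensions share a symbol), and `create_symbolic_sizes_strides` tries to express each stride as `size[j] * stride[j]` of an already-bound dimension before minting a new symbol, so a contiguous tensor contributes only size symbols. The sketch below re-implements that strategy with plain sympy so it runs without this branch; `ToyShapeEnv` and its method names are invented for illustration and are not the PyTorch API, and the fallback step mirrors the loop above closely but not line for line.

```python
import sympy


class ToyShapeEnv:
    """Toy stand-in (not PyTorch code) for the duck-shaping ShapeEnv above."""

    def __init__(self):
        self.var_to_val = {}  # symbol -> concrete value it was created from
        # 0 and 1 map to constants, so no symbol ever stands for 0 or 1
        self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)}

    def create_symbol(self, val):
        # Duck-shaping: the same concrete value always yields the same symbol
        if val in self.val_to_var:
            return self.val_to_var[val]
        s = sympy.Symbol(f"s{len(self.var_to_val)}", positive=True, integer=True)
        self.var_to_val[s] = sympy.Integer(val)
        self.val_to_var[val] = s
        return s

    def symbolic_sizes_strides(self, sizes, strides):
        """Bind strides to products of already-bound sizes/strides where the
        concrete values allow it; only mint fresh symbols as a last resort."""
        sym_sizes = [self.create_symbol(v) for v in sizes]
        sym_strides = [None] * len(strides)
        for i, v in enumerate(strides):
            if v in (0, 1):
                sym_strides[i] = sympy.Integer(v)
        while any(x is None for x in sym_strides):
            # concrete size*stride of each bound dim -> its symbolic counterpart
            candidates = {
                sizes[i] * strides[i]: sym_sizes[i] * sym_strides[i]
                for i in range(len(sizes))
                if sym_strides[i] is not None and strides[i] >= 0
            }
            for v, i in sorted((strides[i], i) for i in range(len(strides))
                               if sym_strides[i] is None):
                if v in candidates:
                    sym_strides[i] = candidates[v]
                    candidates[sizes[i] * strides[i]] = sym_sizes[i] * sym_strides[i]
            if any(x is None for x in sym_strides):
                # smallest still-unbound stride gets its own (duck-shaped) symbol
                v, i = min((strides[i], i) for i in range(len(strides))
                           if sym_strides[i] is None)
                sym_strides[i] = self.create_symbol(v)
        return sym_sizes, sym_strides


env = ToyShapeEnv()
# Contiguous 2x3x4 tensor: sizes (2, 3, 4), strides (12, 4, 1)
print(env.symbolic_sizes_strides((2, 3, 4), (12, 4, 1)))
# -> ([s0, s1, s2], [s1*s2, s2, 1]): no stride symbols were needed
print(env.symbolic_sizes_strides((3, 3), (3, 1)))
# -> ([s1, s1], [s1, 1]): both dims duck-shape onto the existing symbol for 3
```

This pairs with the AOTAutograd change at the top of the diff: in `convert(idx, x)`, inputs with index below `aot_config.num_params_buffers` are faked with `static_shapes=True`, so parameter and buffer shapes never enter the ShapeEnv, while ordinary inputs receive duck-shaped symbols and guards.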