# Owner(s): ["module: inductor"] """ Test the FX IR backend. """ import itertools import operator import unittest from collections.abc import Callable from typing import Optional import sympy import torch import torch._inductor.codegen.common as common import torch.utils._pytree as pytree from torch._dynamo.exc import BackendCompilerFailed from torch._dynamo.utils import same from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation from torch._inductor import config from torch._inductor.codegen.cpp import CppScheduling from torch._inductor.codegen.triton import TritonScheduling from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen from torch._inductor.test_case import TestCase as InductorTestCase from torch.export import Dim from torch.testing._internal.common_utils import ( DeterministicGuard, instantiate_parametrized_tests, parametrize, ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_GPU, requires_gpu, TRITON_HAS_CPU, ) if HAS_GPU: import triton import triton.language as tl from torch.testing._internal.triton_utils import add_kernel_2d_autotuned test_config = { "compile_threads": 1, "alignment_asserts": False, "size_asserts": False, "scalar_asserts": False, "nan_asserts": False, } @requires_gpu() @config.patch(test_config) @instantiate_parametrized_tests class FxirTestCase(InductorTestCase): device = GPU_TYPE def _count_ops(self, gm: torch.fx.GraphModule, target: Callable) -> int: return len(gm.graph.find_nodes(op="call_function", target=target)) def _run_and_capture_graphs(self, opt, args) -> torch.fx.GraphModule: gms = [] orig_generate = FxConverter.generate def generate(self) -> torch.fx.GraphModule: nonlocal gms gm = orig_generate(self) gms.append(gm) return gm with unittest.mock.patch.object( torch._inductor.codegen.wrapper_fxir.FxConverter, "generate", generate ): opt(*args) return gms def _compile_and_check( self, func, args, expected_num_triton_kernels: int = 1, metadata_only: bool = False, compile_kwargs: Optional[dict] = None, ): if compile_kwargs is None: compile_kwargs = {} opt = torch.compile(func, **compile_kwargs) # Get the FX graph from the backend. gms = self._run_and_capture_graphs(opt, args) # Check the code for triton kernels. num_kernels = sum( self._count_ops(gm, triton_kernel_wrapper_mutation) for gm in gms ) self.assertEqual(num_kernels, expected_num_triton_kernels) # Check accuracy. result = opt(*args) ref = func(*args) if metadata_only: # When we only want to check metadata, fill in zeros for tensor data. ref, result = tuple( pytree.tree_map(torch.zeros_like, x) for x in (ref, result) ) self.assertTrue(same(ref, result)) return gms @classmethod def setUpClass(cls): super().setUpClass() # Register the FX backend, storing the default for later. common.init_backend_registration() cls._default_backend = common.device_codegens[cls.device] common.register_backend_for_device( cls.device, TritonScheduling, WrapperFxCodegen ) @classmethod def tearDownClass(cls): super().tearDownClass() # Restore the default backend. 
        common.device_codegens[cls.device] = cls._default_backend

    def test_basic(self):
        args = [torch.randn(8, device=self.device) for _ in range(2)]
        self._compile_and_check(torch.add, args)

    def test_multiple_kernels(self):
        def foo(x, y):
            return x.sum() + y.sum()

        args = [torch.randn(length, device=self.device) for length in [517, 1029]]
        self._compile_and_check(foo, args, expected_num_triton_kernels=2)

    def test_free(self):
        """
        Test a program that frees a buffer which is no longer in use.
        """

        def foo(x, y, z):
            w = x.sum() + y
            return z.sum() + w.sum()

        args = [torch.randn(length, device=self.device) for length in [517, 1029, 123]]
        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=3)

        # Check the generated code for frees.
        num_frees = gm.code.count("= None")
        self.assertGreater(num_frees, 0)

    def test_extern(self):
        """
        Test a program that calls an extern kernel.
        """

        def foo(x, y):
            return x @ y + y.sum()

        args = [
            torch.randn(size, device=self.device) for size in [(129, 129), (129, 1)]
        ]
        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)

        # Check for the extern kernel.
        num_extern = self._count_ops(gm, torch.ops.aten.addmm.out)
        self.assertEqual(num_extern, 1)

    def test_fallback(self):
        """
        Test a program that calls aten fallbacks.
        """

        def foo(x):
            batch1 = torch.randn(2, 3, 5, device=self.device)
            batch2 = torch.randn(2, 5, 4, device=self.device)
            return torch.addbmm(x, batch1, batch2)

        args = (torch.randn(3, 4, device=self.device),)

        # Since the program has a random output, just check metadata.
        # Don't check for an exact value.
        (gm,) = self._compile_and_check(
            foo, args, expected_num_triton_kernels=2, metadata_only=True
        )

        # Check for the fallback kernels.
        num_fallback = self._count_ops(
            gm, torch.ops.aten.randint.low_out
        ) + self._count_ops(gm, torch.ops.aten.addbmm.default)
        self.assertEqual(num_fallback, 2)

    def test_cat_inputs(self):
        """
        Test concatenation of graph inputs.
        """

        def foo(x, y):
            return torch.cat((x, y)) + 1

        args = [torch.randn(8, device=self.device) for _ in range(2)]
        self._compile_and_check(foo, args, expected_num_triton_kernels=1)

    def test_cat_views(self):
        """
        Test concatenation with multiple kernels writing to the same buffer.
        """

        def foo(x, y):
            a = x - 2
            b = y.sum(0, keepdim=True)
            c = torch.cat((a, b)).clone()
            return a, b, c

        args = [torch.randn(8, device=self.device) for _ in range(2)]
        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2)

        def get_offset(node: torch.fx.Node) -> int:
            (input_, shape, stride, offset) = node.args
            assert isinstance(offset, int)
            return offset

        # Check for 2 views, one of which is offset.
        as_strided_nodes = list(
            gm.graph.find_nodes(op="call_function", target=torch.as_strided)
        )
        self.assertEqual(len(as_strided_nodes), 2)
        num_offset_views = sum(get_offset(node) > 0 for node in as_strided_nodes)
        self.assertEqual(num_offset_views, 1)

    def test_cat_to_alloc(self):
        """
        Test concatenation that's optimized out to an allocation.
        """
        length = 8

        def foo(x):
            y, z = tuple(
                torch.arange(length // 2, device=self.device) for _ in range(2)
            )
            return x + torch.cat((y, z))

        args = [torch.randn(length, device=self.device)]
        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)

        # Expect a single allocation, even though eager mode would use 2.
        num_allocs = self._count_ops(gm, torch.empty_strided)
        self.assertEqual(num_allocs, 1)

    def test_cat_reinterpret_view(self):
        """
        Test torch.cat using ReinterpretView.
""" length = 8 def foo(x): y, z = tuple(torch.randn(length // 2, device=self.device) for _ in range(2)) return x + torch.cat((y, z)) args = [torch.randn(length, device=self.device)] # Since this test generates random numbers, check metadata only. (gm,) = self._compile_and_check( foo, args, expected_num_triton_kernels=3, metadata_only=True ) # Check for as_strided. We map ReinterpretView to this. num_as_strided = self._count_ops(gm, torch.as_strided) self.assertEqual(num_as_strided, 2) def test_reshape_output(self): """ Test reshaping the output, which maps to a ReinterpretView. """ def foo(x, y): return torch.reshape(x + y, (8,)) args = [torch.randn((2, 4), device=self.device) for _ in range(2)] (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1) # Check for as_strided. We map ReinterpretView to this. num_as_strided = self._count_ops(gm, torch.as_strided) self.assertEqual(num_as_strided, 1) def test_extern_multi_output(self): """ Test an extern kernel with multiple outputs. Also test a graph with multiple outputs. """ def foo(x): top, idx = torch.topk(x, 2) return top + 1, idx * 2 args = [torch.randn(8, device=self.device)] (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2) # Check for multiple kernel outputs via getitems. num_getitems = self._count_ops(gm, operator.getitem) self.assertEqual(num_getitems, 2) # Check for multiple graph outputs. output_node = gm.graph.find_nodes(op="output")[0] self.assertEqual(len(output_node.args[0]), 2) def test_duplicate_input(self): """ Test duplicated inputs. This will collapse into a single input in the GM. """ args = [torch.randn(4, device=self.device)] * 2 (gm,) = self._compile_and_check(torch.add, args, expected_num_triton_kernels=1) num_placeholders = len(gm.graph.find_nodes(op="placeholder")) self.assertEqual(num_placeholders, 1) def test_backward(self): """ Test a program with a backward pass. """ x = torch.ones(5, device=self.device) # input tensor y = torch.zeros(3, device=self.device) # expected output w = torch.randn(5, 3, requires_grad=True, device=self.device) b = torch.randn(3, requires_grad=True, device=self.device) def foo(x, y): z = torch.matmul(x, w) + b loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y) loss.backward() return w.grad, b.grad # Expect separate forward and backward graphs. (forward_gm, backward_gm) = self._compile_and_check( foo, (x, y), expected_num_triton_kernels=3 ) def test_custom_compiler(self): """ Test a derived backend with a custom compiler. """ offset = 1 class CustomWrapperCodegen(WrapperFxCodegen): def compile_graph(self, gm): def compiled_fn(*args): # Adds an offset to the program's outputs. outputs = gm(*args) return pytree.tree_map(lambda x: x + 1, outputs) return compiled_fn args = [torch.randn(8, device=self.device) for _ in range(2)] custom_backend = common.DeviceCodegen( TritonScheduling, CustomWrapperCodegen, None ) with unittest.mock.patch.dict( common.device_codegens, {self.device: custom_backend} ): func = torch.add opt = torch.compile(func) result = opt(*args) # Check the output is offset from eager mode. ref = func(*args) self.assertFalse(same(result, ref)) self.assertNotEqual(offset, 0) self.assertTrue(same(result - offset, ref)) def test_dynamic_shapes_and_strides(self): """ Test a graph with dynamic shapes and strides. 
""" static_dims = (8, 8) def get_input(): full_size = (16, 8) full = torch.randn(full_size, device=self.device) view = torch.as_strided(full, static_dims, full.stride()) return view func = torch.add args = [get_input() for _ in range(2)] (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True}) # Check for a symbolic output shape. (empty_strided,) = gm.graph.find_nodes( op="call_function", target=torch.empty_strided ) example_tensor = empty_strided.meta["val"] symbolic_dims = example_tensor.shape self.assertEqual(len(symbolic_dims), len(static_dims)) # Check for symbolic output strides. (stride, one) = example_tensor.stride() self.assertEqual(one, sympy.S.One) # Find the size symbols, and check for a corresponding placeholders defining them. for symbol in itertools.chain(symbolic_dims, [stride]): self.assertTrue(isinstance(symbol, torch.SymInt)) (placeholder,) = [ node for node in gm.graph.find_nodes(op="placeholder") if node.name == str(symbol) ] self.assertEqual(placeholder.meta["val"], symbol) @parametrize( "shape", [ (20,), (50, 30), (50, 30, 40), ], ) @torch._inductor.config.patch( { "pad_dynamic_shapes": True, "comprehensive_padding": True, "padding_alignment_bytes": 32, "pad_outputs": True, } ) def test_dynamic_shapes_with_padding(self, shape): """ Test a graph with dynamic shapes with padding. """ def get_input(shape): pad_size = list(shape) pad_size[-1] = ((shape[-1] + 7) // 8) * 8 pad = torch.randn(pad_size, dtype=torch.float32, device=self.device) view = torch.as_strided(pad, shape, pad.stride()) return view args = [get_input(shape) for _ in range(2)] (gm,) = self._compile_and_check( torch.add, args, compile_kwargs={"dynamic": True} ) # Check for a symbolic output shape. (empty_strided,) = gm.graph.find_nodes( op="call_function", target=torch.empty_strided ) example_tensor = empty_strided.meta["val"] symbolic_dims = example_tensor.shape symbolic_strides = example_tensor.stride() align_elems = 32 // args[0].dtype.itemsize expected_strides = [1 for _ in range(len(shape))] for i in range(len(shape) - 1, 0, -1): expected_strides[i - 1] = align_elems * ( ((expected_strides[i] * symbolic_dims[i]) + align_elems - 1) // align_elems ) for i, j in zip(symbolic_strides, expected_strides): self.assertEqual(i, j) def test_dynamic_shapes_precomputed_size(self): """ Test dynamic shapes where a kernel's size arg is precomputed. """ func = torch.add args = [ torch.randn(shape, device=self.device) for shape in [(7, 12, 9), (7, 1, 1)] ] (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True}) # Check for the precomputed size arg. (triton_node,) = gm.graph.find_nodes( op="call_function", target=triton_kernel_wrapper_mutation ) self.assertIn("ks0", triton_node.kwargs["kwargs"]) def test_dynamic_launch_grid_calc_python(self): """ Test the dyanmic launch grid calculation for Triton kernel wrapper using python mode """ func = torch.add args = [torch.randn(shape, device=self.device) for shape in [(7, 12), (7, 1)]] (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True}) # Check for the precomputed size arg. 
        (triton_node,) = gm.graph.find_nodes(
            op="call_function", target=triton_kernel_wrapper_mutation
        )
        self.assertIn("grid", triton_node.kwargs)
        self.assertIn("xnumel", triton_node.kwargs["kwargs"])
        self.assertIn("XBLOCK", triton_node.kwargs["kwargs"])

        grid = triton_node.kwargs["grid"][0]
        xnumel = triton_node.kwargs["kwargs"]["xnumel"].meta["val"]
        xblock = triton_node.kwargs["kwargs"]["XBLOCK"]
        self.assertEqual(grid[0].meta["val"], -(-xnumel // xblock))
        self.assertEqual(grid[1], 1)
        self.assertEqual(grid[2], 1)

    def test_dynamic_launch_grid_calc_python_slow(self):
        """
        Test the dynamic launch grid calculation for the Triton kernel wrapper
        using "python_slow" mode.
        """
        from torch._inductor.runtime.triton_heuristics import GridExpr

        # Mock GridExpr.from_meta to use "python_slow" mode explicitly.
        original_from_meta = GridExpr.from_meta

        def mocked_from_meta(inductor_meta, cfg, mode="python"):
            return original_from_meta(inductor_meta, cfg, mode="python_slow")

        with unittest.mock.patch.object(GridExpr, "from_meta", mocked_from_meta):
            func = torch.add
            args = [
                torch.randn(shape, device=self.device) for shape in [(7, 12), (7, 1)]
            ]
            (gm,) = self._compile_and_check(
                func, args, compile_kwargs={"dynamic": True}
            )

        # Check the launch grid calculation.
        (triton_node,) = gm.graph.find_nodes(
            op="call_function", target=triton_kernel_wrapper_mutation
        )
        self.assertIn("grid", triton_node.kwargs)
        self.assertIn("xnumel", triton_node.kwargs["kwargs"])
        self.assertIn("XBLOCK", triton_node.kwargs["kwargs"])

        grid = triton_node.kwargs["grid"][0]
        xnumel = triton_node.kwargs["kwargs"]["xnumel"].meta["val"]
        xblock = triton_node.kwargs["kwargs"]["XBLOCK"]
        self.assertEqual(grid[0].meta["val"], ((xnumel + xblock - 1) // xblock))
        self.assertEqual(grid[1], 1)
        self.assertEqual(grid[2], 1)

    @config.patch({"trace.enabled": True})
    @unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
    def test_debug(self, mock_output_code):
        # Compile in debug mode.
        args = [torch.randn(11, device=self.device) for _ in range(2)]
        self._compile_and_check(torch.sub, args)

        # Check the output code for a Triton kernel call.
        mock_output_code.assert_called_once()
        (output_filename,) = mock_output_code.call_args.args
        with open(output_filename) as f:
            output_code = f.read()
        self.assertIn("triton_kernel_wrapper_mutation", output_code)

    @parametrize(
        "const",
        (1, 1.5),
    )
    def test_export_const_placeholder(self, const):
        """
        Test that we can compile a graph coming from torch.export with a constant input.
        """

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                return x - y

        args = (torch.randn(8, device=self.device), const)
        mod = TestModule()
        export_gm = torch.export.export(mod, args).module()

        def compile_module(*inps):
            torch._inductor.compile(export_gm, inps)

        (inductor_gm,) = self._run_and_capture_graphs(compile_module, args)

        result = inductor_gm(*args)
        ref = mod(*args)
        self.assertTrue(same(ref, result))

    def test_scatter_fallback_scalar_src(self):
        """
        Test a special case where ScatterFallback takes a scalar 'src' argument.
        """

        def foo(input_):
            dim = 0
            src = 1.5
            return torch.ops.aten.scatter(input_, dim, index, src)

        length = 8
        index = torch.randint(length, (length,), device=self.device)
        input_ = torch.randn(length, device=self.device)
        with DeterministicGuard(True):
            (gm,) = self._compile_and_check(
                foo,
                (input_,),
            )

        # Check for the fallback op.
        num_fallback = self._count_ops(gm, torch.ops.aten.scatter_.value)
        self.assertEqual(num_fallback, 1)

    def test_index_put_fallback(self):
        """
        Test the deterministic fallback for index_put.
""" length = 8 out, values = [torch.randn(length, device=self.device) for _ in range(2)] indices = (torch.randint(length, (length,), device=self.device),) accumulate = True with DeterministicGuard(True): (gm,) = self._compile_and_check( torch.index_put, (out, indices, values, accumulate), expected_num_triton_kernels=1, ) # Check for the fallback op. self.assertEqual(self._count_ops(gm, torch.ops.aten.index_put_.default), 1) def test_scatter_reduce_fallback(self): """ Test the customized wrapper codegen for ScatterFallback ops. """ fallback_op = torch.ops.aten.scatter_reduce_.two def foo(out, index, src): dim = 0 out = fallback_op(out, dim, index, src, reduce="amax", include_self=False) return out + 1 length = 8 out, src = [torch.randn(length, device=self.device) for _ in range(2)] index = torch.randint(length, (length,), device=self.device) (gm,) = self._compile_and_check( foo, (out, index, src), expected_num_triton_kernels=2 ) # Check for the fallback. self.assertEqual(self._count_ops(gm, fallback_op), 1) @parametrize("pred", (False, True)) def test_cond_subgraph(self, pred: bool): """ Test a model with subgraphs. """ def foo(pred, x): return torch.cond(pred, torch.cos, torch.sin, [x]) + 1 x = torch.randn((2, 3), device=self.device) pred_tensor = torch.tensor([pred], device=self.device) gm = self._compile_and_check( foo, [pred_tensor, x], expected_num_triton_kernels=3 )[-1] # Check for subgraphs. subgm_getattrs = list(gm.graph.find_nodes(op="get_attr")) self.assertEqual(len(subgm_getattrs), 2) for subgm_getattr in subgm_getattrs: target = subgm_getattr.name self.assertTrue(isinstance(getattr(gm, target), torch.fx.GraphModule)) @parametrize("pred", (False, True)) def test_cond_no_operands(self, pred: bool): """ Test torch.cond when the subgraphs take no inputs. """ length = 8 def true_fn(): return torch.zeros(length, device=self.device) def false_fn(): return true_fn() + 5 def foo(pred): return torch.cond(pred, true_fn, false_fn, ()) pred_tensor = torch.tensor([pred], device=self.device) self._compile_and_check(foo, [pred_tensor], expected_num_triton_kernels=2) def test_cpp_raises(self): """ Test the C++ CPU backend. C++ kernels are not yet supported, so for now check that we get the expected exception. """ def foo(x, y): return x + y * 5 device = torch.device("cpu") args = [torch.randn(5, device=device) for _ in range(2)] cpp_backend = common.DeviceCodegen(CppScheduling, WrapperFxCodegen, None) with ( unittest.mock.patch.dict( common.device_codegens, {device.type: cpp_backend} ), self.assertRaisesRegex(BackendCompilerFailed, "Triton"), ): self._compile_and_check(foo, args) @parametrize("enable_tuning", (False, True)) @parametrize("use_dynamic_shapes", (False, True)) def test_autotune(self, use_dynamic_shapes: bool, enable_tuning: bool): orig_run = torch._inductor.runtime.triton_heuristics.CachingAutotuner.run called = False def run(*args, **kwargs): nonlocal called called = True return orig_run(*args, **kwargs) args = [torch.randn(8, device=self.device) for _ in range(2)] with ( config.patch("triton.autotune_at_compile_time", enable_tuning), unittest.mock.patch.object( torch._inductor.runtime.triton_heuristics.CachingAutotuner, "run", run ), ): # Compile and check that the tuner was called. self.assertFalse(called) (gm,) = self._compile_and_check( torch.mul, args, compile_kwargs={"dynamic": use_dynamic_shapes} ) self.assertEqual(called, enable_tuning) # Check for a symbolic output shape. 
        (empty_strided,) = gm.graph.find_nodes(
            op="call_function", target=torch.empty_strided
        )
        (shape, stride) = empty_strided.args
        if use_dynamic_shapes:
            self.assertEqual(type(shape[0]), torch.fx.Node)

    def test_custom_triton(self):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: tl.constexpr,
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output

        args = [torch.randn(32, device=self.device) for _ in range(2)]
        self._compile_and_check(add, args)

    def test_output_slice_view(self):
        """
        Test when the output is a view of the input. The sliced strides create a
        TensorBox in the output IR.
        """

        def foo(x):
            return x[0:2:2].T[3:].squeeze(0)

        args = [torch.rand([4, 4, 4, 4], device=self.device)]
        self._compile_and_check(foo, args, expected_num_triton_kernels=0)


@instantiate_parametrized_tests
class AOTFxirTestCase(InductorTestCase):
    device = GPU_TYPE

    def check(
        self, model, inp, dynamic_shapes=None, strict=False
    ) -> torch.fx.GraphModule:
        if self.device == "xpu":
            raise unittest.SkipTest(
                "The AOTFxir feature is not currently ready for XPU"
            )

        with torch.no_grad():
            ep = torch.export.export(
                model, inp, dynamic_shapes=dynamic_shapes, strict=strict
            )
            gm = torch._inductor.aot_compile(
                ep.module(), inp, options={"fx_wrapper": True, **test_config}
            )
            self.assertTrue(same(model(*inp), gm(*inp)))

            for node in gm.graph.nodes:
                if (
                    node.op == "call_function"
                    and node.target != triton_kernel_wrapper_mutation
                ):
                    self.assertTrue(node.meta.get("val", None) is not None)

        return gm

    def test_aoti_fx_add(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
        self.check(M(), inp)

    def test_aoti_fx_const(self):
        class M(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.device = device
                self.a = torch.nn.Parameter(torch.ones(3, device=self.device))
                self.b = torch.ones(3, device=self.device)

            def forward(self, x, y):
                return x + y + self.a + self.b + torch.tensor(3, device=self.device)

        inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
        self.check(M(self.device), inp)

    def test_aoti_fx_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                return self.linear(x)

        inp = (torch.ones(3, 3, device=self.device),)
        self.check(M().to(self.device), inp)

    def test_aoti_fx_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
        self.check(
            M().to(device=self.device),
            inp,
            dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
        )

    def test_custom_triton_autotune_dynamic(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                output = torch.zeros_like(x)
                x_elements = output.size()[0]
                y_elements = output.size()[1]

                def grid(meta):
                    return (
                        triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                        triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                    )

                add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements)
                return output

        num_dims = 2
        dims = [10] * num_dims
        x = torch.randn(*dims, device=self.device)
        y = torch.randn(*dims, device=self.device)
        dim0_x = Dim("dim0_x", min=1, max=10)
        dim0_y = Dim("dim0_y", min=1, max=10)
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
        self.check(
            Model().to(device=self.device),
            (x, y),
            dynamic_shapes=dynamic_shapes,
            strict=True,
        )

    def test_custom_backend(self):
        """
        Test registering a custom FX backend.
        """
        called = False

        class CustomWrapperCodegen(WrapperFxCodegen):
            def compile_graph(self, gm):
                """
                Simply records whether this override was called.
                """
                nonlocal called
                called = True
                return super().compile_graph(gm)

        class M(torch.nn.Module):
            def forward(self, x):
                return x + 1

        # Register a custom FX backend.
        custom_backend = common.DeviceCodegen(
            TritonScheduling,
            PythonWrapperCodegen,
            fx_wrapper_codegen=CustomWrapperCodegen,
        )
        with unittest.mock.patch.dict(
            common.device_codegens, {self.device: custom_backend}
        ):
            # The backend should not have been called yet.
            self.assertFalse(called)

            inp = (torch.randn(8, device=self.device),)
            self.check(M().to(self.device), inp)

            # Now the backend should have been called.
            self.assertTrue(called)

    @parametrize(
        "expr",
        [
            (2 * Dim("x") + 1),
            (Dim("x", min=3) - 3),
        ],
    )
    def test_dynamic_input_expr(self, expr: sympy.Expr):
        """
        Test dynamic shapes with a nontrivial input expression.
        """

        class M(torch.nn.Module):
            def forward(self, x):
                return x.reshape(x.shape[0] * x.shape[1]) + x.shape[1]

        dynamic_shapes = {"x": {0: expr}}
        inp = (torch.randn((5, 4), device=self.device),)
        gm = self.check(M().to(self.device), inp, dynamic_shapes=dynamic_shapes)

        # Check for dynamic size ops.
        self.assertEqual(
            len(
                gm.graph.find_nodes(
                    op="call_function", target=torch.ops.aten.sym_size.int
                )
            ),
            1,
        )

    @parametrize("pred", (False, True))
    def test_cond_multi_inputs_and_outputs(self, pred):
        """
        Test torch.cond and check the output graphs.
        """

        class M(torch.nn.Module):
            def forward(self, pred, x, y):
                def true_fn(x, y):
                    return torch.tanh(x), torch.relu(y)

                def false_fn(x, y):
                    return tuple(t / 2 for t in true_fn(x, y))

                return torch.cond(pred, true_fn, false_fn, (x, y))

        pred = torch.tensor([pred], device=self.device)
        (x, y) = [torch.randn(8, device=self.device) for _ in range(2)]
        gm = self.check(M(), (pred, x, y))

        # Check the graph.
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(arg0_1, true_graph_0, false_graph_0, (arg1_1, arg2_1)); arg0_1 = true_graph_0 = false_graph_0 = arg1_1 = arg2_1 = None
    buf1 = cond[0]
    buf2 = cond[1]; cond = None
    return [buf1, buf2]""",  # noqa: B950
        )

    @parametrize("length", (4, 8))
    def test_cond_dynamic_shape_pred_scalar_closure(self, length: int):
        """
        Test cond using a predicate computed from dynamic shapes.
        Also test a dynamic scalar computed outside the branches.
        """

        class M(torch.nn.Module):
            def forward(self, x, y):
                z = x.reshape(-1)
                a = y.shape[0]

                def true_fn(x):
                    return x + a

                def false_fn(x):
                    return true_fn(x) / 2

                return torch.cond(x.shape[0] > 5, true_fn, false_fn, (z,))

        (x, y) = [
            torch.randn(shape, device=self.device)
            for shape in [(length // 2,) * 2, (length,)]
        ]
        dynamic_shapes = {
            "x": {0: Dim.DYNAMIC},
            "y": {0: Dim.DYNAMIC},
        }
        self.check(M(), (x, y), dynamic_shapes=dynamic_shapes)

    def test_dynamic_scalar_output(self):
        """
        Test an output scalar from dynamic shapes.
""" class M(torch.nn.Module): def forward(self, x): return x.shape[0] * 3 x = torch.randn(7, device=self.device) self.check(M(), (x,), dynamic_shapes=({0: Dim.DYNAMIC},)) if __name__ == "__main__": from torch._inductor.test_case import run_tests if HAS_GPU or TRITON_HAS_CPU: run_tests(needs="filelock")