diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py
index cc3b211676f..60c6b1e9028 100644
--- a/test/inductor/test_control_flow.py
+++ b/test/inductor/test_control_flow.py
@@ -221,6 +221,42 @@ class CondTests(TestCase):
             dynamic=dynamic,
         )
 
+    @requires_gpu
+    def test_cond_control_flow_with_precomputed_size(self):
+        class TestModel(torch.nn.Module):
+            def __init__(
+                self,
+            ):
+                super().__init__()
+                self.conv2d = torch.nn.Conv2d(
+                    512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
+                )
+                self.threshold = 20
+
+            def forward(self, x: torch.Tensor, index) -> torch.Tensor:
+                def true_fn(x: torch.Tensor):
+                    return self.conv2d(x)
+
+                def false_fn(x: torch.Tensor):
+                    return self.conv2d(x)
+
+                return torch.cond(
+                    index < self.threshold and index >= 0, true_fn, false_fn, (x,)
+                )
+
+        main_model = TestModel().cuda()
+        x1 = torch.rand(2, 512, 128, 72).cuda()
+        x2 = torch.rand(2, 512, 96, 96).cuda()
+
+        opt_model = torch.compile(main_model)
+        out1 = main_model(x1, 1)
+        opt_out1 = opt_model(x1, 1)
+        self.assertTrue(torch.allclose(out1, opt_out1, atol=1e-5))
+
+        out2 = main_model(x2, 30)
+        opt_out2 = opt_model(x2, 30)
+        self.assertTrue(torch.allclose(out2, opt_out2, atol=1e-5))
+
     @requires_gpu
     @parametrize("device", ["cpu", GPU_TYPE])
     @parametrize("dynamic", [False, True])
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index a0aad9a2af8..868952dfba6 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -269,6 +269,9 @@ class EnterSubgraphLine(WrapperLine):
     wrapper: WrapperCodeGen
     graph: GraphLowering
 
+    def __post_init__(self) -> None:
+        self.wrapper.push_computed_sizes(self.wrapper.computed_sizes)
+
     def codegen(self, code: IndentedBuffer) -> None:
         self.wrapper.push_codegened_graph(self.graph)
         code.do_indent()
@@ -278,6 +281,9 @@ class ExitSubgraphLine(WrapperLine):
     wrapper: WrapperCodeGen
 
+    def __post_init__(self) -> None:
+        self.wrapper.computed_sizes = self.wrapper.pop_computed_sizes()
+
     def codegen(self, code: IndentedBuffer) -> None:
         self.wrapper.pop_codegened_graph()
         code.do_unindent()
 
@@ -488,6 +494,7 @@ class WrapperCodeGen(CodeGen):
         # including the graph instance into a cache key to avoid cross-graph
         # caching during lowering of nested subgraphs
         self.codegened_graph_stack = []
+        self.computed_sizes_stack = []
 
         self.write_header()
         self.write_prefix()
@@ -680,6 +687,14 @@ class WrapperCodeGen(CodeGen):
     def pop_codegened_graph(self):
        return self.codegened_graph_stack.pop()
 
+    def push_computed_sizes(self, computed_sizes):
+        from copy import deepcopy
+
+        return self.computed_sizes_stack.append(deepcopy(computed_sizes))
+
+    def pop_computed_sizes(self):
+        return self.computed_sizes_stack.pop()
+
     def next_kernel_suffix(self) -> str:
         return f"{next(self._names_iter)}"
 