[Inductor] Fix conditional codegen (#129492)
Summary:
We have a cache to guarantee that each `sym` is codegenned only once; see the following code:
```
def ensure_size_computed(self, sym: sympy.Symbol):
    if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE):
        if sym in self.computed_sizes:
            return
        self.computed_sizes.add(sym)
        expr = V.graph.sizevars.inv_precomputed_replacements[sym]
        self.writeline(
            f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}"
        )
```
However, this cache does not account for the case where the same `sym` needs to be codegenned in both branches of a conditional (the true branch and the false branch), which caused `undefined symbol` errors: P1441378833
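For intuition, the generated wrapper ends up with roughly the following shape (a hand-written sketch, not actual Inductor output; `pred`, `s0`, `ps0`, and the kernel names are made up). The true branch declares the precomputed size, the cache marks it as computed, and the false branch is then emitted without the declaration:
```
# Hand-written sketch of the broken wrapper output (illustrative only).
if pred:
    ps0 = 30 * s0                  # declared here; cache marks ps0 as computed
    buf0 = true_kernel(x, ps0)
else:
    buf0 = false_kernel(x, ps0)    # ps0 was never declared on this path
```
The false path thus references a symbol that was never defined.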
To fix the issue, we use a stack: we capture the cache state before codegenning each branch of the conditional and restore it afterwards, so that each branch independently declares the symbols it needs.
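Below is a minimal standalone sketch of that save/restore pattern (the `CodegenState` class is hypothetical; the actual change lives in `EnterSubgraphLine`, `ExitSubgraphLine`, and `WrapperCodeGen` in the diff below):
```
from copy import deepcopy


class CodegenState:
    """Toy stand-in for the wrapper codegen's computed-sizes cache."""

    def __init__(self):
        self.computed_sizes = set()  # symbols already declared
        self._stack = []             # snapshots for nested subgraphs

    def enter_subgraph(self):
        # Snapshot the cache so declarations made inside one branch
        # cannot suppress declarations in a sibling branch.
        self._stack.append(deepcopy(self.computed_sizes))

    def exit_subgraph(self):
        # Restore the snapshot taken at subgraph entry.
        self.computed_sizes = self._stack.pop()
```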
Test Plan:
TORCH_LOGS="+inductor" buck2 run mode/dev-nosan -c fbcode.nvcc_arch=h100 -c fbcode.enable_gpu_sections=true --config 'cxx.extra_cxxflags=-g1' -c fbcode.platform010_cuda_version=12 //scripts/hhh:repro_cond_torch_compile
PYTORCH_TEST_FBCODE=1 TORCH_COMPILE_DEBUG=1 buck2 run mode/opt -c=python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true //caffe2/test/inductor:control_flow -- -r test_cond_control_flow_with_precomputed_size
Differential Revision: D58973730
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129492
Approved by: https://github.com/aakhundov
This commit is contained in:
parent c5c9dbece1
commit 31bb65de19
```
@@ -221,6 +221,42 @@ class CondTests(TestCase):
             dynamic=dynamic,
         )
 
+    @requires_gpu
+    def test_cond_control_flow_with_precomputed_size(self):
+        class TestModel(torch.nn.Module):
+            def __init__(
+                self,
+            ):
+                super().__init__()
+                self.conv2d = torch.nn.Conv2d(
+                    512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
+                )
+                self.threshold = 20
+
+            def forward(self, x: torch.Tensor, index) -> torch.Tensor:
+                def true_fn(x: torch.Tensor):
+                    return self.conv2d(x)
+
+                def false_fn(x: torch.Tensor):
+                    return self.conv2d(x)
+
+                return torch.cond(
+                    index < self.threshold and index >= 0, true_fn, false_fn, (x,)
+                )
+
+        main_model = TestModel().cuda()
+        x1 = torch.rand(2, 512, 128, 72).cuda()
+        x2 = torch.rand(2, 512, 96, 96).cuda()
+
+        opt_model = torch.compile(main_model)
+        out1 = main_model(x1, 1)
+        opt_out1 = opt_model(x1, 1)
+        self.assertTrue(torch.allclose(out1, opt_out1, atol=1e-5))
+
+        out2 = main_model(x2, 30)
+        opt_out2 = opt_model(x2, 30)
+        self.assertTrue(torch.allclose(out2, opt_out2, atol=1e-5))
+
     @requires_gpu
     @parametrize("device", ["cpu", GPU_TYPE])
     @parametrize("dynamic", [False, True])
```
```
@@ -269,6 +269,9 @@ class EnterSubgraphLine(WrapperLine):
     wrapper: WrapperCodeGen
     graph: GraphLowering
 
+    def __post_init__(self) -> None:
+        self.wrapper.push_computed_sizes(self.wrapper.computed_sizes)
+
     def codegen(self, code: IndentedBuffer) -> None:
         self.wrapper.push_codegened_graph(self.graph)
         code.do_indent()
```
```
@@ -278,6 +281,9 @@ class EnterSubgraphLine(WrapperLine):
 class ExitSubgraphLine(WrapperLine):
     wrapper: WrapperCodeGen
 
+    def __post_init__(self) -> None:
+        self.wrapper.computed_sizes = self.wrapper.pop_computed_sizes()
+
     def codegen(self, code: IndentedBuffer) -> None:
         self.wrapper.pop_codegened_graph()
         code.do_unindent()
```
```
@@ -488,6 +494,7 @@ class WrapperCodeGen(CodeGen):
         # including the graph instance into a cache key to avoid cross-graph
         # caching during lowering of nested subgraphs
         self.codegened_graph_stack = []
+        self.computed_sizes_stack = []
 
         self.write_header()
         self.write_prefix()
```
```
@@ -680,6 +687,14 @@ class WrapperCodeGen(CodeGen):
     def pop_codegened_graph(self):
         return self.codegened_graph_stack.pop()
 
+    def push_computed_sizes(self, computed_sizes):
+        from copy import deepcopy
+
+        return self.computed_sizes_stack.append(deepcopy(computed_sizes))
+
+    def pop_computed_sizes(self):
+        return self.computed_sizes_stack.pop()
+
     def next_kernel_suffix(self) -> str:
         return f"{next(self._names_iter)}"
 
```
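One note on the `deepcopy` in `push_computed_sizes`: pushing the live set object would merely alias it, so symbols added while codegenning the subgraph would mutate the saved snapshot too, and the later restore would be a no-op. A minimal sketch of the difference, using plain strings in place of sympy symbols:
```
from copy import deepcopy

live = {"ps0"}
aliased = live              # wrong: same object as the live cache
snapshot = deepcopy(live)   # right: independent snapshot

live.add("ps1")             # a symbol declared inside the subgraph
assert aliased == {"ps0", "ps1"}   # the alias was silently mutated
assert snapshot == {"ps0"}         # the snapshot is preserved
```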