[inductor] Insert triton barrier before storing to inplace buffers (#100769)

The linked issue demonstrates a triton bug where a load that is broadcast
over multiple warps may observe the result of a store that happens later
in the triton program. The workaround is to add a barrier before the
store, which ensures that all warps have finished reading the data.

e.g. in `test_embedding_var_mean` we now generate:
```python
    tl.debug_barrier()
    tl.store(in_out_ptr1 + (tl.broadcast_to(x0, [XBLOCK, 1])), tmp17, None)
```
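
To make the failure mode concrete, here is a minimal standalone sketch (not taken from the PR; the kernel, names, and sizes are made up) of the same pattern: a broadcasted read from a buffer followed by an in-place store into it, with the barrier placed as inductor now emits it:
```python
import torch
import triton
import triton.language as tl

@triton.jit
def center_by_first_kernel(in_out_ptr, XBLOCK: tl.constexpr):
    xindex = tl.arange(0, XBLOCK)
    # Broadcasted load: element 0 is read once and its value is shared by
    # every warp in the block.
    first = tl.load(in_out_ptr)
    vals = tl.load(in_out_ptr + xindex)
    # Workaround for openai/triton#1615: all warps must finish the reads
    # above before any warp stores into the same (in-place) buffer.
    tl.debug_barrier()
    tl.store(in_out_ptr + xindex, vals - first)

buf = torch.randn(1024, device="cuda")
center_by_first_kernel[(1,)](buf, XBLOCK=1024)
```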

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100769
Approved by: https://github.com/jansel, https://github.com/ngimel
Peter Bell 2023-05-12 18:19:00 +00:00 committed by PyTorch MergeBot
parent 05077f2ac3
commit 5fe834afc1
2 changed files with 14 additions and 8 deletions


```diff
@@ -1172,6 +1172,19 @@ class TritonKernel(Kernel):
         ):
             self.gen_assert_indirect_indexing(self.stores, original_index, mask)
 
+        # Guard against write-after-read corruption in triton.
+        # See https://github.com/openai/triton/issues/1615
+        # This triton bug means that a load which is broadcasted over multiple
+        # warps may see the result of a store that happens later in the triton
+        # program. The workaround is to add a barrier before storing, which
+        # enforces that all warps have already read the data.
+        is_inplace = name in self.args.inplace_buffers
+        is_broadcasted = not set.issubset(
+            set(self.range_tree_nodes.keys()), original_index.free_symbols
+        )
+        if is_inplace and is_broadcasted:
+            self.stores.writeline(DeferredLine(name, "tl.debug_barrier()"))
+
         if mode is None:
             line = f"tl.store({var} + ({index}), {value}, {mask})"
         elif mode == "atomic_add":
```
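
For context, here is a hedged sketch of what the `is_broadcasted` check computes, using sympy directly with a made-up stand-in for `self.range_tree_nodes`: a store index counts as broadcasted when its free symbols do not cover every range variable of the kernel, i.e. the same address is written for multiple positions of some dimension.
```python
import sympy

x0, r0 = sympy.symbols("x0 r0")
# Hypothetical stand-in for self.range_tree_nodes: one pointwise dim (x0)
# and one reduction dim (r0).
range_tree_nodes = {x0: "xindex", r0: "rindex"}

original_index = 2 * x0  # index omits r0 -> same address for every r0

is_broadcasted = not set.issubset(
    set(range_tree_nodes.keys()), original_index.free_symbols
)
print(is_broadcasted)  # True: the barrier is needed if the buffer is in-place
```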


```diff
@@ -1235,14 +1235,7 @@ class Scheduler:
             elif node.is_extern():
                 self.codegen_extern_call(node)
             elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
-                with config.patch(
-                    inplace_buffers=(
-                        config.inplace_buffers
-                        # workaround https://github.com/openai/triton/issues/1615
-                        and not (ir.is_triton(device) and node.is_reduction())
-                    )
-                ):
-                    self.get_backend(device).codegen_nodes(node.get_nodes())
+                self.get_backend(device).codegen_nodes(node.get_nodes())
             else:
                 assert isinstance(node, NopKernelSchedulerNode)
                 node.allocate()
```
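
As a quick check, one might confirm the barrier shows up in generated code with something like the following sketch (assuming the `run_and_get_triton_code` test helper from `torch._inductor.utils`; the function and shapes are illustrative, not the actual test body):
```python
import torch
from torch._inductor.utils import run_and_get_triton_code

def fn(weight, idx):
    # Roughly the shape of test_embedding_var_mean.
    return torch.var_mean(torch.nn.functional.embedding(idx, weight))

weight = torch.randn(64, 32, device="cuda")
idx = torch.randint(0, 64, (8,), device="cuda")

code = run_and_get_triton_code(torch.compile(fn), weight, idx)
assert "tl.debug_barrier()" in code  # barrier precedes the in-place store
```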