mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Insert triton barrier before storing to inplace buffers (#100769)
The linked issue demonstrates a triton bug where a load broadcasted
over multiple warps may see the result of a store that happens later
in the triton program. The workaround is to add a barrier before
storing, which enforces that all warps have already read the data.
e.g. in `test_embedding_var_mean` we now generate:
```python
tl.debug_barrier()
tl.store(in_out_ptr1 + (tl.broadcast_to(x0, [XBLOCK, 1])), tmp17, None)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100769
Approved by: https://github.com/jansel, https://github.com/ngimel
This commit is contained in:
parent
05077f2ac3
commit
5fe834afc1
|
|
@ -1172,6 +1172,19 @@ class TritonKernel(Kernel):
|
|||
):
|
||||
self.gen_assert_indirect_indexing(self.stores, original_index, mask)
|
||||
|
||||
# Guard against write-after-read corruption in triton.
|
||||
# See https://github.com/openai/triton/issues/1615
|
||||
# This triton bug means that a load which is broadcasted over multiple
|
||||
# warps may see the result of a store that happens later in the triton
|
||||
# program. The workaround is to add a barrier before storing, which
|
||||
# enforces that all warps have already read the data.
|
||||
is_inplace = name in self.args.inplace_buffers
|
||||
is_broadcasted = not set.issubset(
|
||||
set(self.range_tree_nodes.keys()), original_index.free_symbols
|
||||
)
|
||||
if is_inplace and is_broadcasted:
|
||||
self.stores.writeline(DeferredLine(name, "tl.debug_barrier()"))
|
||||
|
||||
if mode is None:
|
||||
line = f"tl.store({var} + ({index}), {value}, {mask})"
|
||||
elif mode == "atomic_add":
|
||||
|
|
|
|||
|
|
@ -1235,14 +1235,7 @@ class Scheduler:
|
|||
elif node.is_extern():
|
||||
self.codegen_extern_call(node)
|
||||
elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
|
||||
with config.patch(
|
||||
inplace_buffers=(
|
||||
config.inplace_buffers
|
||||
# workaround https://github.com/openai/triton/issues/1615
|
||||
and not (ir.is_triton(device) and node.is_reduction())
|
||||
)
|
||||
):
|
||||
self.get_backend(device).codegen_nodes(node.get_nodes())
|
||||
self.get_backend(device).codegen_nodes(node.get_nodes())
|
||||
else:
|
||||
assert isinstance(node, NopKernelSchedulerNode)
|
||||
node.allocate()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user