Skip L1 cache for single-use buffers (#143115)

### 1. Synopsis

Adds an optional `cache_modifier='.cg'` argument to `tl.load` instructions in the Inductor-generated Triton code for selected buffers.

This makes the `tl.load` instruction skip the L1 cache for short-lived / non-reused data: the `.cg` ("cache global") cache operator caches the load in L2 and below, bypassing L1 (see the PTX cache-operator reference below).
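
For illustration only (a hand-written sketch, not the Inductor-generated kernel), a Triton load carrying the modifier looks like this:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel_cg(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    # '.cg' caches the loads at the global level (L2 and below), bypassing L1;
    # the inputs of a pointwise add are read exactly once, so L1 gains nothing.
    x = tl.load(x_ptr + offs, mask=mask, cache_modifier=".cg")
    y = tl.load(y_ptr + offs, mask=mask, cache_modifier=".cg")
    tl.store(out_ptr + offs, x + y, mask=mask)


# Example launch (requires a GPU supported by Triton):
x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
add_kernel_cg[(triton.cdiv(x.numel(), 1024),)](x, y, out, x.numel(), BLOCK=1024)
```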

### 2. Using the feature

This feature is experimental and disabled by default. It can be enabled by setting the environment variable `TORCHINDUCTOR_SKIP_L1` to `1`.
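
A minimal sketch of enabling it, either via the environment variable or via the `triton.skip_l1_cache` config option this PR adds (see the `config.py` hunk below):

```python
import os

# The flag is read when torch._inductor.config is imported, so set it early,
# before the first torch.compile call in the process.
os.environ["TORCHINDUCTOR_SKIP_L1"] = "1"

import torch

# Alternatively, toggle the config option directly (this is what the unit
# test below does via config.patch):
# torch._inductor.config.triton.skip_l1_cache = True
```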

### 3. Results

For a simple pointwise addition kernel:
```python
@torch.compile
def add_dummy(x: torch.Tensor, y: torch.Tensor):
    return x+y
```
we get the following results (bandwidth in GB/s):

(a) feature DISABLED:
![image](https://github.com/user-attachments/assets/6caaf775-f083-4943-a61f-8a1bcb154387)

(b) feature ENABLED:
![image](https://github.com/user-attachments/assets/9286be7d-c6ff-4a33-a023-77cb5cc87ff6)
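
For context, a rough sketch of how such GB/s numbers can be measured (the exact sizes and iteration counts behind the screenshots above are not specified in this PR):

```python
import torch

@torch.compile
def add_dummy(x: torch.Tensor, y: torch.Tensor):
    return x + y

def bandwidth_gbs(n: int, iters: int = 100) -> float:
    x = torch.randn(n, device="cuda")
    y = torch.randn(n, device="cuda")
    add_dummy(x, y)  # warm-up / trigger compilation for this size
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        add_dummy(x, y)
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1e3 / iters
    bytes_moved = 3 * n * x.element_size()  # two reads + one write per element
    return bytes_moved / seconds / 1e9

for n in (2**20, 2**24, 2**26):
    print(f"N={n}: {bandwidth_gbs(n):.1f} GB/s")
```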

### 4. Caveats

The performance benefit is only visible when using
```python
torch._dynamo.config.cache_size_limit = 64  # or any other sufficiently large number
torch._dynamo.config.automatic_dynamic_shapes = False  # use static shapes
```
When using (the default) dynamic shapes, only one or two Triton kernels are generated, with non-optimal block sizes for
*all* cases (vector sizes), which hides any performance benefit from skipping the L1 cache.

In the static case, an optimal block size is generated for each vector size, so the performance benefit of skipping the L1 cache becomes visible.

This block-size optimization issue is a broader problem in PyTorch Inductor and is outside the scope of this feature.
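
To verify that the modifier is actually emitted for a given setup (assuming a CUDA device), the generated kernel can be inspected with `run_and_get_triton_code`, the same helper the new test below uses; a sketch:

```python
import torch
import torch._dynamo.config
import torch._inductor.config
from torch._inductor.utils import run_and_get_triton_code

# Same knobs as above: static shapes plus the (experimental) skip-L1 feature.
torch._dynamo.config.automatic_dynamic_shapes = False
torch._inductor.config.triton.skip_l1_cache = True

@torch.compile
def add_dummy(x: torch.Tensor, y: torch.Tensor):
    return x + y

x = torch.randn(512, device="cuda")
y = torch.randn(512, device="cuda")

code = run_and_get_triton_code(add_dummy, x, y)
# Both input loads should carry cache_modifier='.cg', as asserted by the new
# test_skip_l1_cache test in the diff below.
print([line for line in code.splitlines() if "tl.load" in line])
```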

### 5. References

- [tl.load](https://triton-lang.org/main/python-api/generated/triton.language.load.html#triton.language.load)
- [cache operators](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143115
Approved by: https://github.com/jansel
Commit `0aa74d0ab9` (parent `355b0bc7e3`), authored by Sampsa on 2025-01-07 19:35:37 +00:00, committed by PyTorch MergeBot.
5 changed files with 54 additions and 2 deletions.

```diff
@@ -12790,6 +12790,23 @@ if HAS_GPU and not TEST_WITH_ASAN:
 tmp1 = tl.load(in_ptr1 + (x3 + 262144*r0_2), r0_mask, eviction_policy='evict_first', other=0.0)""",
             )
 
+        @config.patch("triton.skip_l1_cache", True)
+        def test_skip_l1_cache(self):
+            @torch.compile
+            def f(a, b):
+                return a + b
+
+            N = 512
+            inps = (torch.randn(N, device=GPU_TYPE), torch.randn(N, device=GPU_TYPE))
+            code = run_and_get_triton_code(f, *inps)
+            lines = [line for line in code.split("\n") if "tl.load" in line]
+            self.assertExpectedInline(
+                "\n".join(lines),
+                """\
+tmp0 = tl.load(in_ptr0 + (x0), xmask, cache_modifier='.cg')
+tmp1 = tl.load(in_ptr1 + (x0), xmask, cache_modifier='.cg')""",
+            )
+
         @config.patch("triton.use_block_ptr", True)
         def test_evict_last_non_coalesced_loads_block_ptr(self):
             @torch.compile
```

```diff
@@ -306,6 +306,7 @@ class CommonTemplate:
         result, (triton_code,) = run_and_compare(self, foo, x, y)
 
     @parametrize("prefer_nd_tiling", [False, True])
+    @config.patch("triton.skip_l1_cache", False)
    def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool):
        """
        Test that we emit tl.broadcast_to instead of using strides of 0.
```

```diff
@@ -154,6 +154,18 @@ class SIMDKernelFeatures:
             reduction_hint_val = ReductionHint.DEFAULT
         return reduction_hint_val
 
+    @cache_on_self
+    def buffer_read_counts(self) -> Dict[str, int]:
+        """Counts how many times each buffer is read within the kernel"""
+        read_counts: Dict[str, int] = collections.defaultdict(int)
+        for node in self.scheduler_nodes():
+            # node.read_writes.reads contains MemoryDep objects for each read
+            for read_dep in node.read_writes.reads:
+                read_counts[read_dep.name] += 1
+        return dict(read_counts)  # Convert defaultdict to regular dict
+
     def has_non_contiguous_pw_in_reduction_kernel(self) -> bool:
         pointwise_nodes = [
             n
```

```diff
@@ -2099,6 +2099,26 @@ class TritonKernel(SIMDKernel):
         else:
             other = ""
 
+        """Check if the buffer we're about to load, has
+        more than one read dependency
+        NOTE: enabled with env variable TORCHINDUCTOR_SKIP_L1
+        """
+        has_read_deps = True
+        if config.triton.skip_l1_cache:
+            buffer_read_counts = self.features.buffer_read_counts()
+            has_read_deps = buffer_read_counts[name] > 1
+
+        """Skip L1 cache if we're (pretty?) sure the data is used only once
+        """
+        skip_l1_cache = (
+            not self.is_broadcasted(original_index)
+            and not self.inside_reduction
+            and not has_read_deps
+            and is_coalesced  # for indirect loads is_coalesced is False?
+        )
+        cachemod = ""
+        if skip_l1_cache:
+            cachemod = ", cache_modifier='.cg'"
+
         append_broadcast = None
         dtype = V.graph.get_dtype(name)
@@ -2107,7 +2127,7 @@
         else:
             if isinstance(indexing, BlockPtrOptions):
                 block_ptr, other = self.codegen_block_ptr(name, var, indexing, other)
-                line = f"tl.load({block_ptr}{other}{ep})"
+                line = f"tl.load({block_ptr}{other}{ep}{cachemod})"
                 line = indexing.codegen_broadcast_and_reshape(
                     line, indexing.block_shape, indexing.final_shape, True
                 )
@@ -2116,7 +2136,7 @@
                 line = f"tl.load({var} + ({original_index}))"
                 append_broadcast = indexing.expand_str
             else:
-                line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})"
+                line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other}{cachemod})"
             if (
                 dtype in (torch.float16, torch.bfloat16)
```

```diff
@@ -1044,6 +1044,8 @@ class triton:
     enable_persistent_tma_matmul = (
         os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1"
     )
+    # Skip L1 cache for buffers that are used only once. Disabled by default
+    skip_l1_cache = os.environ.get("TORCHINDUCTOR_SKIP_L1", "0") == "1"
 
 
 class aot_inductor:
```