Skip L1 cache for single-use buffers (#143115)
### 1. Synopsis
Adds an optional `cache_modifier='.cg'` argument to the `tl.load` instructions in the Inductor-generated Triton code for selected buffers.
With this modifier, the `tl.load` instruction skips the L1 cache for short-lived / non-reused data.
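For illustration, here is a minimal hand-written Triton kernel that applies the same `.cg` modifier to its loads. It is a sketch of the pattern, not the exact kernel Inductor generates; the kernel name, block size, and launch helper are illustrative.

```python
# Sketch: a vector-add kernel whose loads bypass L1 via the '.cg' cache modifier.
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr0, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    # '.cg' caches only at L2, skipping L1, since each element is read exactly once.
    x = tl.load(in_ptr0 + offs, mask=mask, cache_modifier=".cg")
    y = tl.load(in_ptr1 + offs, mask=mask, cache_modifier=".cg")
    tl.store(out_ptr0 + offs, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    add_kernel[grid](x, y, out, n, BLOCK=1024)
    return out
```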
### 2. Using the feature
This feature is experimental and disabled by default. It can be enabled by setting the environment variable `TORCHINDUCTOR_SKIP_L1` to `1`.
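Concretely, either the environment variable or the Inductor config knob added by this PR can be used; note that the environment variable must be set before the Inductor config module is first imported. A short sketch:

```python
import os

# Option 1: environment variable, read once when the Inductor config is imported
os.environ["TORCHINDUCTOR_SKIP_L1"] = "1"

import torch
from torch._inductor import config as inductor_config

# Option 2: set the config knob added by this PR directly
inductor_config.triton.skip_l1_cache = True
```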
### 3. Results
For a simple pointwise addition kernel:
```python
@torch.compile
def add_dummy(x: torch.Tensor, y: torch.Tensor):
return x+y
```
we get the following bandwidth measurements (in GB/s; the benchmark tables are attached to the original PR and are not reproduced here):

(a) feature DISABLED

(b) feature ENABLED
### 4. Caveats
The performance boost is only visible when using
```python
torch._dynamo.config.cache_size_limit = 64 # or any other sufficiently big number..
torch._dynamo.config.automatic_dynamic_shapes = False # use static shapes
```
When using (the default) dynamic shapes, only one or two Triton kernels are generated, with non-optimal block sizes for *all* cases (vector sizes), which hides any perf benefit from skipping the L1 cache.
In the static case, an optimal block size is generated for each vector size, so the perf benefit of skipping the L1 cache becomes visible.
This block-size optimization issue is a broader problem in PyTorch Inductor and is outside the scope of this feature.
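For reference, a hypothetical micro-benchmark sketch along these lines, assuming a CUDA device; the sizes, iteration count, and bandwidth formula are illustrative and not taken from the PR.

```python
# Sketch: compile add_dummy with static shapes so each vector size gets its own
# specialized kernel, then measure achieved bandwidth per size.
import torch
import torch._dynamo

torch._dynamo.config.cache_size_limit = 64
torch._dynamo.config.automatic_dynamic_shapes = False  # use static shapes


@torch.compile
def add_dummy(x: torch.Tensor, y: torch.Tensor):
    return x + y


for n in (2**20, 2**22, 2**24, 2**26):
    x = torch.randn(n, device="cuda")
    y = torch.randn(n, device="cuda")
    add_dummy(x, y)  # warm-up / compile a kernel specialized for this size
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        add_dummy(x, y)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / 100
    gbps = 3 * n * x.element_size() / (ms * 1e-3) / 1e9  # 2 reads + 1 write
    print(f"N={n}: {gbps:.1f} GB/s")
```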
### 5. References
- [tl.load](https://triton-lang.org/main/python-api/generated/triton.language.load.html#triton.language.load)
- [cache operators](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143115
Approved by: https://github.com/jansel
This commit is contained in: parent 355b0bc7e3, commit 0aa74d0ab9
```diff
@@ -12790,6 +12790,23 @@ if HAS_GPU and not TEST_WITH_ASAN:
 tmp1 = tl.load(in_ptr1 + (x3 + 262144*r0_2), r0_mask, eviction_policy='evict_first', other=0.0)""",
             )
 
+        @config.patch("triton.skip_l1_cache", True)
+        def test_skip_l1_cache(self):
+            @torch.compile
+            def f(a, b):
+                return a + b
+
+            N = 512
+            inps = (torch.randn(N, device=GPU_TYPE), torch.randn(N, device=GPU_TYPE))
+            code = run_and_get_triton_code(f, *inps)
+            lines = [line for line in code.split("\n") if "tl.load" in line]
+            self.assertExpectedInline(
+                "\n".join(lines),
+                """\
+tmp0 = tl.load(in_ptr0 + (x0), xmask, cache_modifier='.cg')
+tmp1 = tl.load(in_ptr1 + (x0), xmask, cache_modifier='.cg')""",
+            )
+
         @config.patch("triton.use_block_ptr", True)
         def test_evict_last_non_coalesced_loads_block_ptr(self):
             @torch.compile
```
```diff
@@ -306,6 +306,7 @@ class CommonTemplate:
         result, (triton_code,) = run_and_compare(self, foo, x, y)
 
     @parametrize("prefer_nd_tiling", [False, True])
+    @config.patch("triton.skip_l1_cache", False)
     def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool):
         """
         Test that we emit tl.broadcast_to instead of using strides of 0.
```
```diff
@@ -154,6 +154,18 @@ class SIMDKernelFeatures:
             reduction_hint_val = ReductionHint.DEFAULT
         return reduction_hint_val
 
+    @cache_on_self
+    def buffer_read_counts(self) -> Dict[str, int]:
+        """Counts how many times each buffer is read within the kernel"""
+        read_counts: Dict[str, int] = collections.defaultdict(int)
+
+        for node in self.scheduler_nodes():
+            # node.read_writes.reads contains MemoryDep objects for each read
+            for read_dep in node.read_writes.reads:
+                read_counts[read_dep.name] += 1
+
+        return dict(read_counts)  # Convert defaultdict to regular dict
+
     def has_non_contiguous_pw_in_reduction_kernel(self) -> bool:
         pointwise_nodes = [
             n
```
```diff
@@ -2099,6 +2099,26 @@ class TritonKernel(SIMDKernel):
         else:
             other = ""
 
+        """Check if the buffer we're about to load, has
+           more than one read dependency
+           NOTE: enabled with env variable TORCHINDUCTOR_SKIP_L1
+        """
+        has_read_deps = True
+        if config.triton.skip_l1_cache:
+            buffer_read_counts = self.features.buffer_read_counts()
+            has_read_deps = buffer_read_counts[name] > 1
+        """Skip L1 cache if we're (pretty?) sure the data is used only once
+        """
+        skip_l1_cache = (
+            not self.is_broadcasted(original_index)
+            and not self.inside_reduction
+            and not has_read_deps
+            and is_coalesced  # for indirect loads is_coalesced is False?
+        )
+        cachemod = ""
+        if skip_l1_cache:
+            cachemod = ", cache_modifier='.cg'"
+
         append_broadcast = None
         dtype = V.graph.get_dtype(name)
```
```diff
@@ -2107,7 +2127,7 @@ class TritonKernel(SIMDKernel):
         else:
             if isinstance(indexing, BlockPtrOptions):
                 block_ptr, other = self.codegen_block_ptr(name, var, indexing, other)
-                line = f"tl.load({block_ptr}{other}{ep})"
+                line = f"tl.load({block_ptr}{other}{ep}{cachemod})"
                 line = indexing.codegen_broadcast_and_reshape(
                     line, indexing.block_shape, indexing.final_shape, True
                 )
```
```diff
@@ -2116,7 +2136,7 @@ class TritonKernel(SIMDKernel):
                 line = f"tl.load({var} + ({original_index}))"
                 append_broadcast = indexing.expand_str
             else:
-                line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})"
+                line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other}{cachemod})"
 
             if (
                 dtype in (torch.float16, torch.bfloat16)
```
```diff
@@ -1044,6 +1044,8 @@ class triton:
     enable_persistent_tma_matmul = (
         os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1"
     )
+    # Skip L1 cache for buffers that are used only once. Disabled by default
+    skip_l1_cache = os.environ.get("TORCHINDUCTOR_SKIP_L1", "0") == "1"
 
 
 class aot_inductor:
```