This is an attempt to fix a memory allocation issue when using `torch.compile` with a custom layernorm kernel in vLLM:
```C++
// In-place fused Add and RMS Normalization.
ops.def(
    "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
    "float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
```
With this op enabled, we observed abnormal extra memory allocations under `torch.compile`:
<img width="738" alt="{374E9FCF-FB46-4750-8B60-D31E3ADCE00A}" src="https://github.com/user-attachments/assets/6c45e1aa-ccde-4c56-99dc-bf4776d699d5" />
and without this op:
<img width="738" alt="{9BB08EFE-FFE3-4D06-82C0-C70BBE6ADD56}" src="https://github.com/user-attachments/assets/56e2ee43-ab87-492d-834c-69e9cafbb0df" />
After investigation, we found that the compiler considers the buffers of the two mutated inputs, `Tensor input` and `Tensor residual`, to share the same dependency (user) list, which prevents it from reusing the buffer of `Tensor input`:
```
buf1.users = [
NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
buf16.users = [
NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
```
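Dumps of this form can be printed via PyTorch's logging artifacts; assuming the current artifact names (`ir_pre_fusion`/`ir_post_fusion`), for example:

```python
import torch

# Assumption: "ir_post_fusion" is the logging artifact that prints scheduler
# nodes (writes, unmet/met dependencies, outputs, users) in the format above.
# Equivalent environment-variable form: TORCH_LOGS="ir_post_fusion"
torch._logging.set_logs(ir_post_fusion=True)
```

The scheduler node for the fused op itself shows where the shared list comes from: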
```
op13: ExternKernelSchedulerNode(FallbackKernel)
op13.writes =
[ StarDep(name='buf17', mode=None),
StarDep(name='buf18', mode=None),
StarDep(name='buf19', mode=None)]
op13.unmet_dependencies =
[ StarDep(name='buf13', mode=None),
StarDep(name='buf16', mode=None),
WeakDep(name='buf11', mutating_buf='buf18'),
WeakDep(name='buf12', mutating_buf='buf18'),
WeakDep(name='buf13', mutating_buf='buf18'),
WeakDep(name='buf2', mutating_buf='buf18'),
WeakDep(name='buf3', mutating_buf='buf18')]
op13.met_dependencies = [StarDep(name='arg11_1', mode=None)]
op13.outputs = [
buf17: FallbackKernel
buf17.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
buf17.aliases = ['buf16', 'buf1']
buf17.users = [
NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
buf18: MutationOutput
buf18.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
buf18.mutations = ['buf16']
buf18.users = [
NodeUser(node=ExternKernelSchedulerNode(name='op14'), can_inplace=False, is_weak=False),
NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=True),
NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=True),
]
buf19: MutationOutput
buf19.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
buf19.mutations = ['buf1']
buf19.users = [NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False)]
]
op13.node.kernel = torch.ops._C.fused_add_rms_norm.default
```
Here we can see that `buf16` shares the same user list as `buf1`, because both `buf16` and `buf1` appear in the aliases list of `buf17`. This is incorrect, since they are two separate tensors, and it prevents the compiler from reusing `buf16` for subsequent ops.
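Conceptually, the merge happens because users recorded on the aliasing output are fanned out to every aliased buffer. A simplified, self-contained model of that behaviour (not inductor's actual code):

```python
# Every user of the aliasing output buf17 is copied onto all of its aliases,
# so the unrelated mutated inputs buf1 and buf16 end up with identical user
# lists and neither buffer can be retired early for reuse.
aliases = {"buf17": ["buf16", "buf1"]}
users: dict[str, list[str]] = {}

def add_user(buf: str, op: str) -> None:
    users.setdefault(buf, []).append(op)
    for alias in aliases.get(buf, []):
        users.setdefault(alias, []).append(op)

for op in ("op2", "op9", "op13", "op20"):
    add_user("buf17", op)

assert users["buf1"] == users["buf16"]  # merged lists -> no reuse of buf16
```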
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157133
Approved by: https://github.com/jansel