pytorch/torch/_logging
charlifu 02724b5f64 [Bugfix][Inductor] Fix dependency list merged incorrectly for a custom op with multiple mutated inputs and None return type. (#157133)
This is an attempt to fix a memory allocation issue when using `torch.compile` with a custom layernorm kernel in vllm:
```C++
  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
```
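To make the shape of this op concrete, here is a small sketch that reads the schema string above and lists which arguments are marked as mutated (`Tensor!`) and whether the op returns nothing. The helper names `mutated_args` and `returns_nothing` are hypothetical and only handle the shorthand syntax shown here; they are not part of PyTorch.

```python
# Schema copied from the ops.def call above, joined into one string.
SCHEMA = (
    "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
    "float epsilon) -> ()"
)


def mutated_args(schema: str) -> list[str]:
    """Return argument names whose type carries the `!` mutation marker.

    Hypothetical helper: assumes each argument is written as `<type> <name>`.
    """
    arg_text = schema[schema.index("(") + 1 : schema.index(")")]
    names = []
    for arg in arg_text.split(","):
        type_, name = arg.split()
        if type_.endswith("!"):
            names.append(name)
    return names


def returns_nothing(schema: str) -> bool:
    """True when the schema's return annotation is the empty tuple `()`."""
    return schema.split("->")[1].strip() == "()"


mutated = mutated_args(SCHEMA)
no_return = returns_nothing(SCHEMA)
```

The op therefore mutates two distinct inputs and returns nothing, which is the combination the fix targets.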
We observed abnormal extra memory allocations with this op enabled using `torch.compile`:
<img width="738" alt="{374E9FCF-FB46-4750-8B60-D31E3ADCE00A}" src="https://github.com/user-attachments/assets/6c45e1aa-ccde-4c56-99dc-bf4776d699d5" />
and without this op:
<img width="738" alt="{9BB08EFE-FFE3-4D06-82C0-C70BBE6ADD56}" src="https://github.com/user-attachments/assets/56e2ee43-ab87-492d-834c-69e9cafbb0df" />

After investigation, we found that the compiler treats the buffers of the two mutated inputs, `Tensor input` and `Tensor residual`, as sharing the same dependency list, which prevents it from reusing the buffer of `Tensor input`.
```
buf1.users = [
        NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
    ]
buf16.users = [
        NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
    ]
```
```
op13: ExternKernelSchedulerNode(FallbackKernel)
op13.writes =
    [   StarDep(name='buf17', mode=None),
        StarDep(name='buf18', mode=None),
        StarDep(name='buf19', mode=None)]
op13.unmet_dependencies =
    [   StarDep(name='buf13', mode=None),
        StarDep(name='buf16', mode=None),
        WeakDep(name='buf11', mutating_buf='buf18'),
        WeakDep(name='buf12', mutating_buf='buf18'),
        WeakDep(name='buf13', mutating_buf='buf18'),
        WeakDep(name='buf2', mutating_buf='buf18'),
        WeakDep(name='buf3', mutating_buf='buf18')]
op13.met_dependencies = [StarDep(name='arg11_1', mode=None)]
op13.outputs = [
    buf17: FallbackKernel
    buf17.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf17.aliases = ['buf16', 'buf1']
    buf17.users = [
        NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
    ]
    buf18: MutationOutput
    buf18.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf18.mutations = ['buf16']
    buf18.users = [
        NodeUser(node=ExternKernelSchedulerNode(name='op14'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=True),
    ]
    buf19: MutationOutput
    buf19.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf19.mutations = ['buf1']
    buf19.users = [NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False)]
]
op13.node.kernel = torch.ops._C.fused_add_rms_norm.default
```
Here we can see that `buf16` shares the same dependency list as `buf1` because both `buf16` and `buf1` are in the aliases list of `buf17`. This is incorrect, since they are two separate tensors, and it prevents the compiler from reusing `buf16` for subsequent ops.
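The effect can be reproduced with a toy model. The sketch below is a simplified illustration, not the actual Inductor scheduler code: `share_user_lists` is a hypothetical function that makes an output buffer and every buffer in its aliases list point at one combined user list. The buffer and op names are taken from the dump above, but the starting user lists are made up for the example.

```python
def share_user_lists(users: dict, output: str, aliases: list) -> dict:
    """Make `output` and each alias share a single user list (toy model)."""
    for alias in aliases:
        if output in users and alias in users:
            first, second = users[output], users[alias]
            combined = first + [u for u in second if u not in first]
            # Re-point every buffer that used either list at the merged one.
            for key, value in users.items():
                if value is first or value is second:
                    users[key] = combined
        elif output in users:
            users[alias] = users[output]
        else:
            users[output] = users[alias]
    return users


# Two independent tensors, each with its own (made-up) user.
users = {"buf1": ["op2"], "buf16": ["op9"]}

# A None-returning op whose output lists both mutated inputs as aliases.
share_user_lists(users, "buf17", ["buf16", "buf1"])

# Both buffers now share one list, so each appears to be used by every op.
merged = users["buf1"] is users["buf16"]
```

In this model, `buf16` inherits the users of `buf1`, so a liveness check based on the last user would keep `buf16` alive for as long as `buf1` is in use, even though the two tensors are unrelated.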

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157133
Approved by: https://github.com/jansel
2025-07-11 09:06:31 +00:00
__init__.py [draft-export] Include guards for constraint violation errors (#138748) 2024-10-30 00:24:17 +00:00
_internal.py [Bugfix][Inductor] Fix dependency list merged incorrectly for a custom op with multiple mutated inputs and None return type. (#157133) 2025-07-11 09:06:31 +00:00
_registrations.py [Bugfix][Inductor] Fix dependency list merged incorrectly for a custom op with multiple mutated inputs and None return type. (#157133) 2025-07-11 09:06:31 +00:00
scribe.py PEP585 update - mostly toplevels (#145178) 2025-01-22 02:21:14 +00:00
structured.py [BE][PYFMT] migrate PYFMT for torch/_[a-h]*/ to ruff format (#144551) 2025-06-25 06:16:06 +00:00