Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Fixes https://github.com/pytorch/pytorch/issues/120441

We follow how triton_kernel_wrapper_functional gets re-inplaced:

- If we see auto_functionalized, we first compute which inputs we actually need to clone ("tensors_to_clone") and fix up the graph. This happens in `reinplace_and_refine_tensors_to_clone`, which I have refactored out of the triton_kernel_wrapper_functional reinplacing code.
- Later, after the reinplacing pass, a decomposition pass runs for auto_functionalized. That pass uses the "tensor_to_clone" info and clones only those inputs in the decomposition.
- We shepherd "tensor_to_clone" from the first step to the second by setting the .meta field on the auto_functionalized node.

Test Plan:

- existing tests
- tested locally by reading the output of TORCH_LOGS="post_grad_graphs"
- added assertExpectedInline tests for the post_grad_graphs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120829
Approved by: https://github.com/oulgen
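The two-step flow described in this commit (decide what to clone, stash it in the node's `.meta`, then have the decomposition clone only those inputs) can be illustrated with a toy, self-contained sketch. This is not PyTorch's implementation: `Node`, `refine_tensors_to_clone`, `decompose_auto_functionalized`, the `"mutating_op"` op name, and the meta key `"tensors_to_clone"` are all hypothetical. The cloning rule used here ("clone a mutated input only if a later node still reads it") is a simplifying assumption; the real pass has to reason about more cases, such as aliasing and graph inputs.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    # Hypothetical toy graph node; NOT torch.fx.Node.
    name: str
    op: str
    inputs: list
    mutates: list = field(default_factory=list)  # inputs this op writes to
    meta: dict = field(default_factory=dict)


def refine_tensors_to_clone(graph):
    """Step 1 (hypothetical): record on each auto_functionalized node
    which mutated inputs still need a clone. Assumption: an input needs
    a clone only if some later node reads it; otherwise it can be
    mutated in place."""
    for i, node in enumerate(graph):
        if node.op != "auto_functionalized":
            continue
        later_reads = {name for later in graph[i + 1:] for name in later.inputs}
        node.meta["tensors_to_clone"] = [
            name for name in node.mutates if name in later_reads
        ]


def decompose_auto_functionalized(graph):
    """Step 2 (hypothetical): expand each auto_functionalized node into
    clone nodes plus an in-place call, cloning only what step 1 recorded.
    If step 1 never ran, conservatively clone every mutated input."""
    out = []
    for node in graph:
        if node.op != "auto_functionalized":
            out.append(node)
            continue
        to_clone = set(node.meta.get("tensors_to_clone", node.mutates))
        renamed = {}
        for name in node.mutates:
            if name in to_clone:
                clone_name = f"{name}_clone"
                out.append(Node(clone_name, "clone", [name]))
                renamed[name] = clone_name
        # The in-place op writes to the clone where one exists,
        # otherwise directly to the original input.
        new_inputs = [renamed.get(name, name) for name in node.inputs]
        out.append(Node(node.name, "mutating_op", new_inputs))
    return out


# Usage: "x" is read after the op, so it must be cloned; "y" is not.
graph = [
    Node("x", "placeholder", []),
    Node("y", "placeholder", []),
    Node("af", "auto_functionalized", ["x", "y"], mutates=["x", "y"]),
    Node("out", "add", ["x", "af"]),
]
refine_tensors_to_clone(graph)
print(graph[2].meta)  # {'tensors_to_clone': ['x']}
new = decompose_auto_functionalized(graph)
print([(n.name, n.op, n.inputs) for n in new])
```

Storing the decision in the node's metadata keeps the two passes decoupled in this sketch. The decomposition does not need to redo the liveness analysis, and it falls back to cloning everything when the metadata is absent.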
| Name |
|---|
| _internal |
| __init__.py |
| _comparison.py |
| _creation.py |