pytorch/torch/testing
rzou d534a49767 Reinplace auto_functionalized (#120829)
Fixes https://github.com/pytorch/pytorch/issues/120441

We follow how triton_kernel_wrapper_functional gets reinplaced:
- If we see auto_functionalized, we first compute which inputs we
  actually need to clone ("tensors_to_clone") and fix up the graph. This happens in
  `reinplace_and_refine_tensors_to_clone`, which I have refactored out
  of the triton_kernel_wrapper_functional reinplacing code.
- Later, after the reinplacing pass, a decomposition pass runs for
  auto_functionalized. That pass uses the "tensors_to_clone" info and
  clones only those inputs in the decomposition.
- We shepherd "tensors_to_clone" from the first step to the second step
  by storing it in the .meta dict of the auto_functionalized node.
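The two-step handoff described above can be illustrated with a minimal sketch. It is written in plain Python with a toy `Node` class rather than torch.fx. Every name in it is hypothetical and not the PyTorch implementation: `Node`, `refine_tensors_to_clone`, `reinplace_pass`, `decompose`, the `can_reinplace` predicate, and the meta key "tensors_to_clone". The real pass's rules for deciding when an input may be mutated in place are not modeled. Falling back to cloning every mutated input when nothing was recorded is also an assumption.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Node:
    """Toy stand-in for a graph node (hypothetical; NOT torch.fx.Node)."""
    target: str
    mutated_args: List[str]  # names of inputs the wrapped op mutates
    meta: Dict[str, object] = field(default_factory=dict)


def refine_tensors_to_clone(candidates: List[str],
                            can_reinplace: Callable[[str], bool]) -> List[str]:
    """Hypothetical analogue of reinplace_and_refine_tensors_to_clone:
    keep only the inputs that cannot be mutated in place."""
    return [name for name in candidates if not can_reinplace(name)]


def reinplace_pass(nodes: List[Node], can_reinplace: Callable[[str], bool]) -> None:
    # Step 1: for each auto_functionalized node, compute which inputs
    # still need a clone and stash the result in node.meta.
    for node in nodes:
        if node.target == "auto_functionalized":
            node.meta["tensors_to_clone"] = refine_tensors_to_clone(
                node.mutated_args, can_reinplace)


def decompose(node: Node) -> List[str]:
    # Step 2: the later decomposition reads node.meta and clones only the
    # recorded inputs; all other inputs are passed to the op directly.
    # Assumption: if nothing was recorded, conservatively clone everything.
    to_clone = node.meta.get("tensors_to_clone", node.mutated_args)
    ops: List[str] = []
    call_args: List[str] = []
    for name in node.mutated_args:
        if name in to_clone:
            ops.append(f"{name}_clone = clone({name})")
            call_args.append(f"{name}_clone")
        else:
            call_args.append(name)
    ops.append(f"mutable_op({', '.join(call_args)})")
    return ops


# Usage: suppose "x" is still read after the op (so it must be cloned)
# while "y" is dead afterwards and can be mutated in place.
node = Node("auto_functionalized", ["x", "y"])
reinplace_pass([node], can_reinplace=lambda name: name == "y")
print(decompose(node))
```

Because the decision is recorded on the node itself, the decomposition does not need to repeat the aliasing analysis. It only reads back what the earlier pass decided.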

Test Plan:
- existing tests
- tested locally by reading the output of TORCH_LOGS="post_grad_graphs"
- added assertExpectedInline tests for the post_grad_graphs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120829
Approved by: https://github.com/oulgen
2024-03-01 00:55:19 +00:00
_internal Reinplace auto_functionalized (#120829) 2024-03-01 00:55:19 +00:00
__init__.py
_comparison.py [BE]: Apply RUF025 dict.fromkeys preview rule (#118637) 2024-01-30 20:46:54 +00:00
_creation.py additional support for float8_e4m3fnuz and _e5m2fnuz (#115214) 2024-01-22 18:33:41 +00:00