pytorch/torch
soulitzer dca73982c5 Support setting grad_dtype on leaf tensors (#162815)
`grad_dtype` is a new attribute on Tensor to control gradient dtype:
- Accessing or setting it is allowed only on leaf tensors.
- `grad_dtype` is respected (1) when assigning to `.grad`, and (2) in the engine, after the previous node produces incoming gradients for AccumulateGrad (see the table below for details).
- Leaving `grad_dtype` unset preserves the current behavior; accessing it then returns `t.dtype`.
- `grad_dtype` cannot be set when there is already a `.grad` present and the dtypes conflict.

| `grad_dtype` setting | Setting `.grad` manually | Incoming gradient from autograd engine |
|-----------------------|--------------------------|-----------------------------------------|
| **Default (tensor’s dtype)** | `.grad` must match tensor’s dtype | Engine casts incoming grad to tensor’s dtype |
| **Set to specific dtype** | `.grad` must match that dtype | Engine casts incoming grad to the specified dtype |
| **Set to `None`** | `.grad` may be any dtype | Engine does not cast; accepts incoming grad dtype as-is |
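To make the table concrete, here is a minimal pure-Python sketch that models the dtype rules above. It is a toy model only: `GradDtypeModel`, `set_grad`, and `engine_cast` are hypothetical names, dtypes are plain strings, the raised exception types are assumptions, and the model tracks only the dtype of `.grad` (not its values or how gradients accumulate). It is not PyTorch's implementation and does not call the real `Tensor.grad_dtype` attribute.

```python
from typing import Optional

_UNSET = object()  # sentinel meaning "grad_dtype was never set"


class GradDtypeModel:
    """Hypothetical toy model of the grad_dtype rules for a leaf tensor.

    Dtypes are plain strings such as "float32". Only the dtype of .grad
    is tracked. This is NOT PyTorch's implementation.
    """

    def __init__(self, dtype: str, is_leaf: bool = True):
        self.dtype = dtype
        self.is_leaf = is_leaf
        self._grad_dtype = _UNSET
        self.grad: Optional[str] = None  # dtype of the current .grad, if any

    def _check_leaf(self) -> None:
        # Rule: accessing or setting grad_dtype is leaf-only.
        if not self.is_leaf:
            raise RuntimeError("grad_dtype is only available on leaf tensors")

    @property
    def grad_dtype(self) -> Optional[str]:
        self._check_leaf()
        # Unset -> fall back to the tensor's own dtype (current behavior).
        return self.dtype if self._grad_dtype is _UNSET else self._grad_dtype

    @grad_dtype.setter
    def grad_dtype(self, value: Optional[str]) -> None:
        self._check_leaf()
        # Rule: cannot set a grad_dtype that conflicts with an existing .grad.
        # None accepts any dtype, so it never conflicts.
        if self.grad is not None and value is not None and value != self.grad:
            raise RuntimeError("existing .grad conflicts with new grad_dtype")
        self._grad_dtype = value

    def set_grad(self, grad_dtype_of_value: str) -> None:
        """Model of manually assigning a gradient of the given dtype to .grad."""
        expected = self.grad_dtype
        if expected is not None and grad_dtype_of_value != expected:
            raise RuntimeError(f".grad must have dtype {expected}")
        self.grad = grad_dtype_of_value

    def engine_cast(self, incoming: str) -> str:
        """Dtype the modeled engine hands to AccumulateGrad for an incoming grad."""
        target = self.grad_dtype
        # None -> no cast; otherwise cast to the effective grad_dtype.
        return incoming if target is None else target
```

For example, `GradDtypeModel("bfloat16").engine_cast("float32")` yields `"bfloat16"` (default row), while after `t.grad_dtype = None` an incoming `"float16"` gradient is kept as `"float16"` (last row). Using a sentinel rather than `None` for "unset" mirrors the fact that `None` is itself a meaningful setting in the table.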

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162815
Approved by: https://github.com/albanD
2025-10-02 23:09:07 +00:00
_awaits
_C Support setting grad_dtype on leaf tensors (#162815) 2025-10-02 23:09:07 +00:00
_C_flatbuffer
_custom_op Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_decomp [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
_dispatch Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_dynamo Revert "C++-accessible Placements via pybind11 (#163030)" 2025-10-02 18:25:24 +00:00
_export [RELAND v2] Close some sources of fake tensors (#164372) 2025-10-02 18:58:52 +00:00
_functorch [RELAND v2] Close some sources of fake tensors (#164372) 2025-10-02 18:58:52 +00:00
_higher_order_ops [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
_inductor Improved support for autotuning in wrapper_fxir (#164132) 2025-10-02 22:54:22 +00:00
_lazy Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_library [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
_logging [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
_numpy Remove unnecessary list comprehensions (#164103) 2025-09-30 03:56:54 +00:00
_prims [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
_prims_common [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
_refs [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
_strobelight Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_subclasses [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
_vendor
accelerator
amp Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
ao [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
autograd [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
backends Add SVE128 ISA (#158932) 2025-09-29 14:49:19 +00:00
compiler [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
contrib
cpu Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
csrc Support setting grad_dtype on leaf tensors (#162815) 2025-10-02 23:09:07 +00:00
cuda [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
distributed Revert "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout (#163213)" 2025-10-02 22:22:26 +00:00
distributions [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
export [RELAND v2] Close some sources of fake tensors (#164372) 2025-10-02 18:58:52 +00:00
fft
func
futures [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
fx add tensor subclass printing support in fx/graph.py (#164403) 2025-10-02 20:06:12 +00:00
headeronly Migrate DeviceType to torch/headeronly (#163999) 2025-09-30 23:13:27 +00:00
jit [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
legacy
lib
linalg Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
masked [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
monitor
mps Add type annotations to MPS profiler utilities (#163486) 2025-09-27 23:00:53 +00:00
mtia Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
multiprocessing Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
nativert Add SVE128 ISA (#158932) 2025-09-29 14:49:19 +00:00
nested Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
nn [1/N] Fix ruff warnings (#164333) 2025-10-01 16:48:32 +00:00
numa Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
onnx [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
optim [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
package Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
profiler [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
quantization Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
signal Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
sparse [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
special
testing Stop parsing command line arguments every time common_utils is imported. (#156703) 2025-10-02 22:22:04 +00:00
utils Fix FloorDiv should not generate non integer rationals (due to sympy bug) (#164398) 2025-10-02 22:51:03 +00:00
xpu Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
__config__.py
__future__.py
__init__.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_appdirs.py
_classes.py
_compile.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
_custom_ops.py
_environment.py
_guards.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
_jit_internal.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_linalg_utils.py
_lobpcg.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_lowrank.py
_meta_registrations.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
_namedtensor_internals.py
_ops.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_python_dispatcher.py
_size_docs.py
_sources.py
_storage_docs.py
_streambase.py
_tensor_docs.py Support setting grad_dtype on leaf tensors (#162815) 2025-10-02 23:09:07 +00:00
_tensor_str.py [1/N] Fix ruff warnings (#164333) 2025-10-01 16:48:32 +00:00
_tensor.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_thread_safe_fork.py
_torch_docs.py Update documentation for torch.index_select (#163616) 2025-09-25 18:29:17 +00:00
_utils_internal.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_utils.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_VF.py
_vmap_internals.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
_weights_only_unpickler.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
CMakeLists.txt Revert "[RELAND] Always build USE_DISTRIBUTED (#160449) and Make distributed modules importable even when backend not built (#159889) (#162594)" 2025-09-25 13:47:46 +00:00
custom_class_detail.h
custom_class.h
extension.h
functional.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
header_only_apis.txt Migrate DeviceType to torch/headeronly (#163999) 2025-09-30 23:13:27 +00:00
hub.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
library.h
library.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
overrides.py Support setting grad_dtype on leaf tensors (#162815) 2025-10-02 23:09:07 +00:00
py.typed
quasirandom.py
random.py
return_types.py
script.h
serialization.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
storage.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
torch_version.py
types.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
version.py.tpl