Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Fixes #108629

1. Add the following to their modules' `__all__` so that pyright considers them to be publicly exported:
   * [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast)
   * [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler)
   * [`torch.cuda.amp.autocast`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast)
   * [`torch.cuda.amp.custom_fwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd)
   * [`torch.cuda.amp.custom_bwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_bwd)
2. Add `overload`s for `torch.cuda.amp.GradScaler.scale` so that the return type is `torch.Tensor` or `Iterable[torch.Tensor]`, depending on the type of the `outputs` parameter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108630

Approved by: https://github.com/ezyang
| Name |
|---|
| amp |
| __init__.py |
| _memory_viz.py |
| _sanitizer.py |
| _utils.py |
| comm.py |
| error.py |
| graphs.py |
| jiterator.py |
| memory.py |
| nccl.py |
| nvtx.py |
| profiler.py |
| random.py |
| sparse.py |
| streams.py |