Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Fixes #108629

1. Add the following to their modules' `__all__` so that pyright considers them to be publicly exported:
   * [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast)
   * [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler)
   * [`torch.cuda.amp.autocast`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast)
   * [`torch.cuda.amp.custom_fwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd)
   * [`torch.cuda.amp.custom_bwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_bwd)
2. Add `overload`s for `torch.cuda.amp.GradScaler.scale` to differentiate when a `torch.Tensor` is returned vs. an `Iterable[torch.Tensor]` is returned, based on the type of the `outputs` parameter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108630
Approved by: https://github.com/ezyang
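A minimal sketch of the `__all__` mechanism behind item 1, using a hypothetical in-memory module `toy_amp` (not a real PyTorch module; the names `autocast`, `partial`, and `internal_util` are illustrative only). At runtime, `__all__` decides exactly which names `from toy_amp import *` binds, including names the module merely imported from elsewhere.

```python
import sys
import textwrap
import types

# Hypothetical package __init__ source; every name here is illustrative.
SOURCE = textwrap.dedent(
    """
    from functools import partial   # imported name, re-exported via __all__

    def autocast():                  # defined here and listed in __all__
        return "autocast"

    def internal_util():             # public-looking, but NOT in __all__
        return "internal"

    __all__ = ["autocast", "partial"]
    """
)

# Build the module in memory and register it so it can be imported by name.
toy_amp = types.ModuleType("toy_amp")
exec(SOURCE, toy_amp.__dict__)
sys.modules["toy_amp"] = toy_amp

# A star-import binds exactly the names listed in __all__.
namespace: dict = {}
exec("from toy_amp import *", namespace)
exported = sorted(name for name in namespace if name != "__builtins__")
print(exported)  # ['autocast', 'partial']
```

Pyright applies the same list statically: in a `py.typed` package, a symbol imported into a module is treated as private unless it is re-exported, and listing it in `__all__` is one of the accepted ways to mark it public.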
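A minimal sketch of the overload pattern described in item 2, assuming a hypothetical `ToyScaler` class and using plain `float` values in place of `torch.Tensor` so it runs without PyTorch. It is not the actual `GradScaler` code: the two `@overload` stubs tell a type checker that a single value maps to a single value and an iterable maps to an iterable, while one runtime implementation handles both cases.

```python
from typing import Iterable, Union, overload


class ToyScaler:
    """Hypothetical stand-in for a gradient scaler; floats replace torch.Tensor."""

    def __init__(self, scale_factor: float = 2.0 ** 16) -> None:
        self.scale_factor = scale_factor

    # Overload 1: a single value in -> a single value out.
    @overload
    def scale(self, outputs: float) -> float: ...

    # Overload 2: an iterable of values in -> an iterable of values out.
    @overload
    def scale(self, outputs: Iterable[float]) -> Iterable[float]: ...

    # The single runtime implementation; type checkers only see the overloads above.
    def scale(
        self, outputs: Union[float, Iterable[float]]
    ) -> Union[float, Iterable[float]]:
        if isinstance(outputs, (int, float)):
            return outputs * self.scale_factor
        # Scale each element (recursing into nested iterables).
        scaled = map(self.scale, outputs)
        # Rebuild plain lists and tuples as their own type; other iterables stay lazy.
        if isinstance(outputs, (list, tuple)):
            return type(outputs)(scaled)
        return scaled


scaler = ToyScaler(scale_factor=4.0)
single = scaler.scale(0.5)         # inferred as float
many = scaler.scale([0.5, 1.0])    # inferred as Iterable[float]
print(single, many)  # 2.0 [2.0, 4.0]
```

Without the overloads, a checker would only see the union return type and force callers to narrow it, even when they passed a single value.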
| File |
|---|
| `__init__.py` |
| `autocast_mode.py` |
| `common.py` |
| `grad_scaler.py` |