The `non_blocking` arg here is useless if the values are all eagerly consumed, so revert the change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124013
Approved by: https://github.com/ezyang
This commit is contained in:
parent fea1b99d89
commit 9c4fc5fa34
```diff
@@ -426,8 +426,10 @@ class GradScaler:
             found_inf = cast(
                 torch.Tensor,
                 sum(
-                    t.to(scaler.device, non_blocking=True)
-                    for t in optimizer_state["found_inf_per_device"].values()
+                    [  # noqa: C419
+                        t.to(scaler.device, non_blocking=True)
+                        for t in optimizer_state["found_inf_per_device"].values()
+                    ]
                 ),
             )
             optimizer.grad_scale = (  # type: ignore[attr-defined]
```
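Why the list comprehension matters: when `sum` iterates a generator, each `t.to(scaler.device, non_blocking=True)` result is handed to the running sum before the next copy is even issued, so the `non_blocking` flag buys nothing. Materializing the list first issues every copy before the first addition, giving the transfers a chance to overlap. Below is a minimal sketch of the two patterns, not code from the PR: `device` and the `found_inf_per_device` dict are illustrative stand-ins for GradScaler's internal state.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for optimizer_state["found_inf_per_device"]:
# one scalar flag per device, kept on the host.
found_inf_per_device = {f"dev{i}": torch.zeros(()) for i in range(4)}
if torch.cuda.is_available():
    # non_blocking host-to-device copies only truly run asynchronously
    # from pinned host memory.
    found_inf_per_device = {
        k: v.pin_memory() for k, v in found_inf_per_device.items()
    }

# Generator form (what the commit removes): sum() pulls one tensor at a
# time, so each non_blocking copy is consumed by an add before the next
# copy is issued; the copies cannot all be enqueued up front.
eager = sum(
    t.to(device, non_blocking=True) for t in found_inf_per_device.values()
)

# List form (what the commit restores): the comprehension launches every
# copy before the first addition runs, so transfers can be in flight
# concurrently.
overlapped = sum(
    [  # noqa: C419
        t.to(device, non_blocking=True)
        for t in found_inf_per_device.values()
    ]
)

assert torch.equal(eager, overlapped)  # same value either way
```

Either form computes the same value; the commit only restores the issue order that lets `non_blocking=True` do useful work, at the cost of a `# noqa: C419` to silence the lint rule that prefers generators over list comprehensions inside `sum`.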