swa avoid stream sync (#157705)
Summary:
When AveragedModel.update_parameters runs, it evaluates self.n_averaged == 0 for every parameter, where n_averaged is a buffer on the GPU, so each check forces a device-to-host sync. Moving the check before the loop performs the sync only once.
This improves update_parameters from 74 ms to 57 ms, a ~22% improvement.
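To illustrate the pattern, here is a minimal sketch (not the PyTorch source): hoisting a GPU-tensor condition out of a loop so the device-to-host sync happens once instead of once per parameter. The names n_averaged, params, and the loop body are stand-ins for AveragedModel's buffer and parameter lists.

```python
import torch

# Minimal sketch of the optimization applied in this diff; sizes and loop
# body are illustrative, not the real AveragedModel code.
device = "cuda" if torch.cuda.is_available() else "cpu"
n_averaged = torch.tensor(0, dtype=torch.long, device=device)
params = [torch.randn(1024, device=device) for _ in range(100)]

# Before: `if n_averaged == 0:` inside the loop calls Tensor.__bool__() on a
# device tensor every iteration, and each call blocks on the GPU stream
# (one sync per parameter).
for p in params:
    if n_averaged == 0:
        pass  # copy path

# After: convert the condition to a plain Python bool once, before the loop,
# so only a single device-to-host sync happens per update.
copy_param = bool(n_averaged == 0)
for p in params:
    if copy_param:  # pure Python check, no GPU work, no sync
        pass  # copy path
```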
Test Plan:
CI
Rollback Plan:
Differential Revision: D77723025
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157705
Approved by: https://github.com/albanD, https://github.com/Skylion007, https://github.com/janeyx99
This commit is contained in:
parent c2510fcd86
commit 2efa5eaa65
@@ -259,11 +259,12 @@ class AveragedModel(Module):
         )
         self_param_detached: list[Optional[Tensor]] = []
         model_param_detached: list[Optional[Tensor]] = []
+        copy_param = bool(self.n_averaged == 0)
         for p_averaged, p_model in zip(self_param, model_param):
             p_model_ = p_model.detach().to(p_averaged.device)
             self_param_detached.append(p_averaged.detach())
             model_param_detached.append(p_model_)
-            if self.n_averaged == 0:
+            if copy_param:
                 p_averaged.detach().copy_(p_model_)
 
         if self.n_averaged > 0:
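For context, a minimal illustrative use of AveragedModel; update_parameters is the call whose per-parameter sync this diff removes. The model, optimizer, and data below are placeholders, not from the source.

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(10, 2)
swa_model = AveragedModel(model)  # holds the n_averaged buffer on the model's device
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for _ in range(5):  # placeholder training loop
    inputs = torch.randn(8, 10)
    loss = model(inputs).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    swa_model.update_parameters(model)  # the method sped up by this change
```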