[PyTorch/d2go] fix optim _multi_tensor (#73215)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73215

Fixing an issue in optimizers from _multi_tensor, for `sgd_mt` introduced in 2cb03e926f

Reviewed By: mikaylagawarecki

Differential Revision: D34389034

fbshipit-source-id: ede153d52dca15909c6c022853589707f18dc8d1
This commit is contained in:
Jan Zikes 2022-02-23 02:07:29 -08:00 committed by Facebook GitHub Bot
parent a43ea6d9fc
commit cc8a58e584

View File

@ -4,17 +4,25 @@ Most commonly used methods are already supported, and the interface is general
enough, so that more sophisticated ones can be also easily integrated in the
future.
"""
from functools import partial
from functools import partialmethod
from torch import optim
Adam = partial(optim.Adam, foreach=True)
AdamW = partial(optim.AdamW, foreach=True)
NAdam = partial(optim.NAdam, foreach=True)
SGD = partial(optim.SGD, foreach=True)
RAdam = partial(optim.RAdam, foreach=True)
RMSprop = partial(optim.RMSprop, foreach=True)
Rprop = partial(optim.Rprop, foreach=True)
ASGD = partial(optim.ASGD, foreach=True)
Adamax = partial(optim.Adamax, foreach=True)
Adadelta = partial(optim.Adadelta, foreach=True)
Adagrad = partial(optim.Adagrad, foreach=True)
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = partialmethod(cls.__init__, *args, **kwargs)
return NewCls
Adam = partialclass(optim.Adam, foreach=True)
AdamW = partialclass(optim.AdamW, foreach=True)
NAdam = partialclass(optim.NAdam, foreach=True)
SGD = partialclass(optim.SGD, foreach=True)
RAdam = partialclass(optim.RAdam, foreach=True)
RMSprop = partialclass(optim.RMSprop, foreach=True)
Rprop = partialclass(optim.Rprop, foreach=True)
ASGD = partialclass(optim.ASGD, foreach=True)
Adamax = partialclass(optim.Adamax, foreach=True)
Adadelta = partialclass(optim.Adadelta, foreach=True)
Adagrad = partialclass(optim.Adagrad, foreach=True)