mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PyTorch/d2go] fix optim _multi_tensor (#73215)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73215
Fixing an issue in optimizers from _multi_tensor, for `sgd_mt` introduced in 2cb03e926f
Reviewed By: mikaylagawarecki
Differential Revision: D34389034
fbshipit-source-id: ede153d52dca15909c6c022853589707f18dc8d1
This commit is contained in:
parent
a43ea6d9fc
commit
cc8a58e584
|
|
@ -4,17 +4,25 @@ Most commonly used methods are already supported, and the interface is general
|
|||
enough, so that more sophisticated ones can be also easily integrated in the
|
||||
future.
|
||||
"""
|
||||
from functools import partial
|
||||
from functools import partialmethod
|
||||
from torch import optim
|
||||
|
||||
Adam = partial(optim.Adam, foreach=True)
|
||||
AdamW = partial(optim.AdamW, foreach=True)
|
||||
NAdam = partial(optim.NAdam, foreach=True)
|
||||
SGD = partial(optim.SGD, foreach=True)
|
||||
RAdam = partial(optim.RAdam, foreach=True)
|
||||
RMSprop = partial(optim.RMSprop, foreach=True)
|
||||
Rprop = partial(optim.Rprop, foreach=True)
|
||||
ASGD = partial(optim.ASGD, foreach=True)
|
||||
Adamax = partial(optim.Adamax, foreach=True)
|
||||
Adadelta = partial(optim.Adadelta, foreach=True)
|
||||
Adagrad = partial(optim.Adagrad, foreach=True)
|
||||
def partialclass(cls, *args, **kwargs):
|
||||
|
||||
class NewCls(cls):
|
||||
__init__ = partialmethod(cls.__init__, *args, **kwargs)
|
||||
|
||||
return NewCls
|
||||
|
||||
|
||||
Adam = partialclass(optim.Adam, foreach=True)
|
||||
AdamW = partialclass(optim.AdamW, foreach=True)
|
||||
NAdam = partialclass(optim.NAdam, foreach=True)
|
||||
SGD = partialclass(optim.SGD, foreach=True)
|
||||
RAdam = partialclass(optim.RAdam, foreach=True)
|
||||
RMSprop = partialclass(optim.RMSprop, foreach=True)
|
||||
Rprop = partialclass(optim.Rprop, foreach=True)
|
||||
ASGD = partialclass(optim.ASGD, foreach=True)
|
||||
Adamax = partialclass(optim.Adamax, foreach=True)
|
||||
Adadelta = partialclass(optim.Adadelta, foreach=True)
|
||||
Adagrad = partialclass(optim.Adagrad, foreach=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user