mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
29 lines
672 B
Python
29 lines
672 B
Python
from copy import copy
|
|
from collections import defaultdict
|
|
|
|
class Optimizer(object):
    """Base class for optimizers that update the parameters of a model.

    Attributes:
        model: the object passed to the constructor; it must provide
            ``parameters()`` and ``zero_grad()``.
        state: a ``defaultdict(dict)``; looking up a missing key creates an
            empty dict for it.
        parameters: a list built from ``model.parameters()`` once, at
            construction time.

    Subclasses must override :meth:`step`.
    """

    def __init__(self, model):
        """Store ``model`` and snapshot its parameters into a list."""
        self.model = model
        self.state = defaultdict(dict)
        self.parameters = list(self.model.parameters())

    def __getstate__(self):
        """Return the ``state`` and ``parameters`` attributes as a dict.

        ``model`` is not part of the returned mapping.
        """
        return {
            'state': self.state,
            'parameters': self.parameters,
        }

    def state_dict(self):
        """Return the same mapping as :meth:`__getstate__`.

        The values are the optimizer's own objects, not copies.
        """
        return self.__getstate__()

    def _forward_backward(self, forward_closure):
        """Run one forward/backward pass and return the loss value.

        Calls ``model.zero_grad()``, then ``forward_closure()`` with no
        arguments, then ``backward()`` on the object the closure returned.

        Args:
            forward_closure: zero-argument callable returning the loss.

        Returns:
            ``loss.data[0]``; if that indexing raises ``IndexError`` and
            ``loss.data`` has an ``item()`` method, ``loss.data.item()``.

        Raises:
            IndexError: if ``loss.data[0]`` fails and ``loss.data`` has no
                ``item()`` method to fall back on.
        """
        self.model.zero_grad()
        loss = forward_closure()
        loss.backward()
        data = loss.data
        try:
            return data[0]
        except IndexError:
            # Newer PyTorch releases represent scalar losses as 0-dim
            # tensors, which reject ``[0]``; ``item()`` is the supported way
            # to extract the value there.
            to_scalar = getattr(data, 'item', None)
            if to_scalar is None:
                raise
            return to_scalar()

    def step(self, forward_closure):
        """Perform one optimization step; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
|
|