Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
62 lines · 2.3 KiB · Python
from collections import OrderedDict

from .variable import Variable


class Function(object):
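    """Base class for differentiable operations in the autograd graph.

    Subclasses implement ``forward`` and ``backward`` on raw tensors; calling
    an instance unwraps the input Variables, runs ``forward``, and wraps the
    results in new Variables that record this Function as their creator.
    """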

    def __init__(self):
        self.previous_functions = None
        self.output_ids = None
        self.needs_input_grad = None
        self.backward_hooks = OrderedDict()

    def __call__(self, *input):
        return self._do_forward(*input)

    def _do_forward(self, *input):
        # Unwrap the input Variables into raw tensors for the user-defined forward.
        unpacked_input = tuple(arg.data for arg in input)
        is_volatile = any(arg.volatile for arg in input)
        raw_output = self.forward(*unpacked_input)
        if not isinstance(raw_output, tuple):
            raw_output = (raw_output,)

        if is_volatile:
            # Volatile inputs skip all graph bookkeeping.
            output = tuple(Variable(tensor, volatile=True) for tensor in raw_output)
        else:
            # Record the metadata _do_backward needs to route gradients.
            self.needs_input_grad = tuple(arg.creator.requires_grad for arg in input)
            self.requires_grad = any(self.needs_input_grad)
            self.previous_functions = [(arg.creator, id(arg)) for arg in input]
            output = tuple(Variable(tensor, self) for tensor in raw_output)
            self.output_ids = {id(var): i for i, var in enumerate(output)}

        return output

    def _do_backward(self, *grad_output):
        grad_input = self.backward(*grad_output)
        if not isinstance(grad_input, tuple):
            grad_input = (grad_input,)
        assert len(grad_input) == len(self.previous_functions), \
            self.__class__.__name__ + ' returned an invalid number of gradient tensors'

        # Call each hook with either the full grad_input tuple (idx is None) or
        # only the entry at its registered index.
        for hook, idx in self.backward_hooks.values():
            gi = grad_input if idx is None else grad_input[idx]
            hook(gi, grad_output)

        return grad_input

    def register_hook(self, name, hook, variable=None):
        assert name not in self.backward_hooks, \
            "Trying to register a second hook with name {}".format(name)
        # If a specific output Variable is given, remember its index so the hook
        # is called with just the matching gradient entry.
        variable_idx = self.output_ids[id(variable)] if variable is not None else None
        self.backward_hooks[name] = (hook, variable_idx)

    def remove_hook(self, name):
        assert name in self.backward_hooks, \
            "Trying to remove a non-existent hook with name {}".format(name)
        del self.backward_hooks[name]

    def forward(self, *input):
        """Compute the operation on unwrapped tensors. Must be overridden."""
        raise NotImplementedError

    def backward(self, *grad_output):
        """Compute gradients w.r.t. the inputs. Must be overridden."""
        raise NotImplementedError
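
# Illustrative sketch (not part of the original file): a subclass only needs to
# implement forward/backward on raw tensors; __call__ takes care of unwrapping
# the input Variables and recording the graph. The names Add, a and b below are
# hypothetical, and a, b are assumed to be Variables wrapping tensors.
#
#     class Add(Function):
#
#         def forward(self, a, b):
#             return a + b
#
#         def backward(self, grad_output):
#             # Addition passes the output gradient through to both inputs.
#             return grad_output, grad_output
#
#     c, = Add()(a, b)  # _do_forward always returns a tuple of output Variables
#     c.creator.register_hook('print_grad', lambda grad_input, grad_output: print(grad_output))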