from collections import OrderedDict
from itertools import chain
from .variable import Variable


class Function(object):

    def __init__(self):
        self.previous_functions = None
        self.output_ids = None
        self.needs_input_grad = None
        self.saved_variables = None
        self.to_save = None
        self.non_differentiable = None
        self.backward_hooks = OrderedDict()

    def __call__(self, *input):
        return self._do_forward(*input)

    def save_for_backward(self, *tensors):
        self.to_save = tensors

    def mark_dirty(self, *args):
        dirty_set = set(args)
        for var in self.input:
            if var.data in dirty_set:
                var.mark_dirty()

    def mark_non_differentiable(self, *args):
        self.non_differentiable = set(args)

    @property
    def saved_tensors(self):
        return tuple(arg.data for arg in self.saved_variables)

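    # Wrapper around the user-defined forward(): it unpacks the Variable
    # arguments into raw tensors, records the graph metadata needed for
    # backpropagation (unless every input is volatile), and wraps the raw
    # outputs in Variables that reference this Function as their creator.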
    def _do_forward(self, *input):
        for i in input:
            if not isinstance(i, Variable):
                raise RuntimeError("expected a Variable argument, but got " +
                                   type(i).__name__)
        unpacked_input = tuple(arg.data for arg in input)
        is_volatile = any(arg.volatile for arg in input)
        # Save the input, so save_for_backward can access it
        self.input = input
        if not is_volatile:
            self.needs_input_grad = tuple(arg.requires_grad for arg in input)
            self.requires_grad = any(self.needs_input_grad)
            self.previous_functions = [(arg.creator or arg, id(arg)) for arg in input]

        raw_output = self.forward(*unpacked_input)
        if not isinstance(raw_output, tuple):
            raw_output = (raw_output,)

        if is_volatile:
            output = tuple(Variable(tensor, volatile=True)
                           for tensor in raw_output)
        else:
            output = tuple(Variable(tensor, self, requires_grad=self.requires_grad)
                           for tensor in raw_output)
            self.output_ids = {id(var): i for i, var in enumerate(output)}
            if self.to_save:
                # Output has to be chained after input, so if the same tensor
                # appears both in the input and the output (which happens for
                # in-place functions), we save the clean output variable.
                #
                # Some variables might have been changed in-place, so accessing
                # their .data will throw. If they also occur in the output,
                # these references will be overwritten by clean variables;
                # if not, they'll raise an error on backward.
                t2var = {var._data: var for var in chain(input, output)}
                self.saved_variables = tuple(t2var[t] for t in self.to_save)
                del self.to_save
            if self.non_differentiable is not None:
                for var in output:
                    if var.data in self.non_differentiable:
                        var.requires_grad = False

        del self.input  # Remove unnecessary references to input
        del self.non_differentiable  # and output
        if len(output) == 1:
            output = output[0]
        return output

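    # Wrapper around the user-defined backward(): it checks that one gradient
    # is returned per input, runs any registered backward hooks, and frees the
    # saved variables unless retain_variables is set.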
    def _do_backward(self, grad_output, retain_variables):
        if not hasattr(self, 'saved_variables'):
            raise RuntimeError("Trying to backward through the graph a second "
                               "time, but the buffers have already been freed. Please "
                               "specify retain_variables=True when calling backward for "
                               "the first time.")
        grad_input = self.backward(*grad_output)
        if not isinstance(grad_input, tuple):
            grad_input = (grad_input,)
        assert len(grad_input) == len(self.previous_functions), \
            self.__class__.__name__ + ' returned an invalid number of gradient tensors'

        self._call_hooks(grad_input, grad_output)
        if not retain_variables:
            del self.saved_variables
        return grad_input

    def _call_hooks(self, grad_input, grad_output):
        for hook in self.backward_hooks.values():
            hook(grad_input, grad_output)

    def register_hook(self, name, hook):
        assert name not in self.backward_hooks, \
            "Trying to register a second hook with name {}".format(name)
        self.backward_hooks[name] = hook

    def remove_hook(self, name):
        assert name in self.backward_hooks, \
            "Trying to remove a nonexistent hook with name {}".format(name)
        del self.backward_hooks[name]

    def forward(self, *input):
        raise NotImplementedError

    def backward(self, *grad_output):
        raise NotImplementedError


class InplaceFunction(Function):

    def __init__(self, inplace=False):
        super(InplaceFunction, self).__init__()
        self.inplace = inplace
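

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): one way a custom op
# could be written against the Function API above, assuming `import torch` is
# available at the call site. The name `Square` and the sample input are
# hypothetical.
#
#     class Square(Function):
#
#         def forward(self, x):
#             # forward() receives and returns raw tensors
#             self.save_for_backward(x)
#             return x * x
#
#         def backward(self, grad_output):
#             # backward() receives and returns raw gradient tensors
#             x, = self.saved_tensors
#             return 2 * x * grad_output
#
#     y = Square()(Variable(torch.randn(5)))
# ---------------------------------------------------------------------------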