pytorch/torch/autograd/function.py
Sam Gross 980300b381 Combine autograd.Leaf and autograd.Variable (#52)
Prior to this change, there was a circular reference between Leaf and
Variable. This means that the objects (and referenced Tensors) are not
collected as soon as they go out of scope, which leads to higher memory
usage and out-of-memory errors.
2016-09-25 20:21:14 -04:00
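To illustrate the problem the commit message describes (a standalone sketch using a hypothetical Node class, not the actual Leaf/Variable code): CPython's reference counting cannot free objects that refer to each other, so a cycle and everything it references stays alive until the cyclic garbage collector runs.

import gc
import weakref


class Node(object):
    def __init__(self, payload):
        self.payload = payload       # stands in for a referenced Tensor
        self.partner = None


gc.disable()                         # rule out the cyclic collector for a moment
a, b = Node('x'), Node('y')
a.partner, b.partner = b, a          # reference cycle, analogous to Leaf <-> Variable
probe = weakref.ref(a)

del a, b                             # out of scope, but refcounts never reach zero
assert probe() is not None           # the cycle (and its payload) is still alive
gc.enable()
gc.collect()                         # only the cyclic garbage collector frees it
assert probe() is None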

from collections import OrderedDict
from itertools import chain

from .variable import Variable


class Function(object):

    def __init__(self):
        self.previous_functions = None
        self.output_ids = None
        self.needs_input_grad = None
        self.saved_variables = None
        self.to_save = None
        self.non_differentiable = None
        self.backward_hooks = OrderedDict()

    def __call__(self, *input):
        return self._do_forward(*input)

    def save_for_backward(self, *tensors):
        self.to_save = tensors

    def mark_dirty(self, *args):
        dirty_set = set(args)
        for var in self.input:
            if var.data in dirty_set:
                var.mark_dirty()

    def mark_non_differentiable(self, *args):
        self.non_differentiable = set(args)

    @property
    def saved_tensors(self):
        return tuple(arg.data for arg in self.saved_variables)

    def _do_forward(self, *input):
        unpacked_input = tuple(arg.data for arg in input)
        is_volatile = any(arg.volatile for arg in input)

        # Save the input, so mark_dirty and the save_for_backward machinery
        # below can access it
        self.input = input
        if not is_volatile:
            self.needs_input_grad = tuple(arg.requires_grad for arg in input)
            self.requires_grad = any(self.needs_input_grad)
            self.previous_functions = [(arg.creator or arg, id(arg)) for arg in input]

        raw_output = self.forward(*unpacked_input)
        if not isinstance(raw_output, tuple):
            raw_output = (raw_output,)

        if is_volatile:
            output = tuple(Variable(tensor, volatile=True)
                           for tensor in raw_output)
        else:
            output = tuple(Variable(tensor, self, requires_grad=self.requires_grad)
                           for tensor in raw_output)
        self.output_ids = {id(var): i for i, var in enumerate(output)}

        if self.to_save:
            # output has to be chained after input, so if the same tensor
            # appears both in the input and output (happens for in-place
            # functions), we save the clean output variable.
            #
            # Some variables might have been changed in-place, so accessing
            # their .data will throw. If they also occur in the output,
            # these references will be overwritten by clean variables;
            # if not, they'll raise an error on backward.
            t2var = {var._data: var for var in chain(input, output)}
            self.saved_variables = tuple(t2var[t] for t in self.to_save)
            del self.to_save
        if self.non_differentiable is not None:
            for var in output:
                if var.data in self.non_differentiable:
                    var.requires_grad = False

        del self.input  # Remove unnecessary references to input
        del self.non_differentiable  # and output

        if len(output) == 1:
            output = output[0]
        return output

    def _do_backward(self, grad_output, retain_variables):
        grad_input = self.backward(*grad_output)
        if not isinstance(grad_input, tuple):
            grad_input = (grad_input,)
        assert len(grad_input) == len(self.previous_functions), \
            self.__class__.__name__ + ' returned an invalid number of gradient tensors'

        self._call_hooks(grad_input, grad_output)
        if not retain_variables:
            del self.saved_variables

        return grad_input

    def _call_hooks(self, grad_input, grad_output):
        for hook in self.backward_hooks.values():
            hook(grad_input, grad_output)

    def register_hook(self, name, hook):
        assert name not in self.backward_hooks, \
            "Trying to register a second hook with name {}".format(name)
        self.backward_hooks[name] = hook

    def remove_hook(self, name):
        assert name in self.backward_hooks, \
            "Trying to remove a nonexistent hook with name {}".format(name)
        del self.backward_hooks[name]

    def forward(self, *input):
        raise NotImplementedError

    def backward(self, *grad_output):
        raise NotImplementedError


class InplaceFunction(Function):

    def __init__(self, inplace=False):
        super(InplaceFunction, self).__init__()
        self.inplace = inplace
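

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original file): a Function
# subclass written against the API above, where forward/backward receive and
# return raw tensors, save_for_backward/saved_tensors carry state between the
# two passes, and __call__ wraps the outputs in Variables.  The Square class,
# the tensor values, the hook, and the keyword form of the Variable
# constructor are illustrative assumptions, not code from this revision.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    class Square(Function):

        def forward(self, input):
            # `input` is the raw tensor unpacked by _do_forward
            self.save_for_backward(input)
            return input * input

        def backward(self, grad_output):
            # saved_tensors unwraps the Variables matched up in _do_forward
            input, = self.saved_tensors
            return grad_output * input * 2

    def report(grad_input, grad_output):
        print('backward hook: {} gradient(s)'.format(len(grad_input)))

    fn = Square()
    fn.register_hook('report', report)                    # invoked via _call_hooks
    x = Variable(torch.ones(3) * 3, requires_grad=True)   # requires_grad kwarg assumed
    y = fn(x)                                  # __call__ -> _do_forward -> forward
    grad_x, = fn._do_backward((torch.ones(3),), retain_variables=False)
    print(grad_x)                              # d(x*x)/dx * 1 == 2 * x, i.e. all sixes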