mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Prior to this change, there was a circular reference between Leaf and Variable. This meant that the objects (and the Tensors they reference) were not collected as soon as they went out of scope, which led to higher memory usage and out-of-memory errors.
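The failure mode described in the commit message is standard CPython behavior: reference counting alone cannot reclaim objects that point at each other, so a cycle survives scope exit until the cyclic garbage collector eventually runs. A minimal, generic sketch of that behavior (the Node class below is purely illustrative, not the actual Leaf/Variable code):

import gc
import weakref


class Node(object):
    """Stand-in for any Python object that can hold a reference to another."""


def make_cycle():
    a, b = Node(), Node()
    a.partner, b.partner = b, a   # a <-> b form a reference cycle
    return weakref.ref(a)         # observe a's lifetime without keeping it alive


ref = make_cycle()
print(ref() is None)   # False: refcounting cannot free the cycle at scope exit
gc.collect()           # only the cyclic collector reclaims it, some time later
print(ref() is None)   # True

Breaking the cycle, as this change does, lets reference counting free the Variables (and the Tensors they hold) as soon as they go out of scope, instead of waiting for a collection pass.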
118 lines
4.3 KiB
Python
from collections import OrderedDict
from itertools import chain

from .variable import Variable


class Function(object):

    def __init__(self):
        self.previous_functions = None
        self.output_ids = None
        self.needs_input_grad = None
        self.saved_variables = None
        self.to_save = None
        self.non_differentiable = None
        self.backward_hooks = OrderedDict()

    def __call__(self, *input):
        return self._do_forward(*input)

    def save_for_backward(self, *tensors):
        # Tensors stashed here are matched back to their Variables after
        # forward() returns (see _do_forward) and exposed via saved_tensors.
        self.to_save = tensors

    def mark_dirty(self, *args):
        # Mark input Variables whose tensors were modified in-place by forward().
        dirty_set = set(args)
        for var in self.input:
            if var.data in dirty_set:
                var.mark_dirty()

    def mark_non_differentiable(self, *args):
        self.non_differentiable = set(args)

    @property
    def saved_tensors(self):
        return tuple(arg.data for arg in self.saved_variables)

    def _do_forward(self, *input):
        unpacked_input = tuple(arg.data for arg in input)
        is_volatile = any(arg.volatile for arg in input)
        # Save the input, so helpers called from forward() (e.g. mark_dirty)
        # can access it
        self.input = input
        if not is_volatile:
            self.needs_input_grad = tuple(arg.requires_grad for arg in input)
            self.requires_grad = any(self.needs_input_grad)
            # Record the graph edges: each input's creator, or the leaf
            # Variable itself if it has no creator.
            self.previous_functions = [(arg.creator or arg, id(arg)) for arg in input]

        raw_output = self.forward(*unpacked_input)
        if not isinstance(raw_output, tuple):
            raw_output = (raw_output,)

        if is_volatile:
            output = tuple(Variable(tensor, volatile=True)
                           for tensor in raw_output)
        else:
            output = tuple(Variable(tensor, self, requires_grad=self.requires_grad)
                           for tensor in raw_output)
            self.output_ids = {id(var): i for i, var in enumerate(output)}
            if self.to_save:
                # output has to be chained after input, so if the same tensor
                # appears both in the input and the output (happens for in-place
                # functions), we save the clean output variable.
                #
                # Some variables might have been changed in-place, so accessing
                # their .data will throw. If they also occur in the output,
                # these references will be overwritten by clean variables;
                # if not, they'll raise an error on backward.
                t2var = {var._data: var for var in chain(input, output)}
                self.saved_variables = tuple(t2var[t] for t in self.to_save)
                del self.to_save
            if self.non_differentiable is not None:
                for var in output:
                    if var.data in self.non_differentiable:
                        var.requires_grad = False

        del self.input  # Remove unnecessary references to input
        del self.non_differentiable  # and output
        if len(output) == 1:
            output = output[0]
        return output

    def _do_backward(self, grad_output, retain_variables):
        grad_input = self.backward(*grad_output)
        if not isinstance(grad_input, tuple):
            grad_input = (grad_input,)
        assert len(grad_input) == len(self.previous_functions), \
            self.__class__.__name__ + ' returned an invalid number of gradient tensors'

        self._call_hooks(grad_input, grad_output)
        if not retain_variables:
            del self.saved_variables
        return grad_input

    def _call_hooks(self, grad_input, grad_output):
        for hook in self.backward_hooks.values():
            hook(grad_input, grad_output)

    def register_hook(self, name, hook):
        assert name not in self.backward_hooks, \
            "Trying to register a second hook with name {}".format(name)
        self.backward_hooks[name] = hook

    def remove_hook(self, name):
        assert name in self.backward_hooks, \
            "Trying to remove a nonexistent hook with name {}".format(name)
        del self.backward_hooks[name]

    def forward(self, *input):
        raise NotImplementedError

    def backward(self, *grad_output):
        raise NotImplementedError


class InplaceFunction(Function):

    def __init__(self, inplace=False):
        super(InplaceFunction, self).__init__()
        self.inplace = inplace
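For context, here is a minimal sketch of how this class is meant to be subclassed, assuming the historical torch.autograd API of this era (one Function instance per use, with forward/backward operating on raw tensors). The Square class, the import path, and the explicit gradient passed to backward() are illustrative assumptions, not part of this file:

import torch
from torch.autograd import Function, Variable   # assumed import path for this era


class Square(Function):
    """Illustrative subclass: y = x * x, using save_for_backward/saved_tensors."""

    def forward(self, tensor):
        self.save_for_backward(tensor)   # stashed via self.to_save, then matched
        return tensor * tensor           # back to its Variable in _do_forward

    def backward(self, grad_output):
        tensor, = self.saved_tensors     # the tensor saved in forward()
        return grad_output * tensor * 2  # d(x*x)/dx = 2x


x = Variable(torch.ones(3) * 3, requires_grad=True)
y = Square()(x)              # __call__ -> _do_forward records the graph edge
y.backward(torch.ones(3))    # the engine calls Square._do_backward, which reads
                             # saved_tensors and (without retain_variables) clears them
# x.grad should now hold the gradient 2 * x

Note how the division of labor falls out of _do_forward above: the subclass only manipulates tensors, while the base class wraps the results in Variables, records previous_functions, and resolves save_for_backward arguments to clean Variables so they can be checked and freed during backward.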