import torch
import torch._C as _C
from collections import OrderedDict
from itertools import chain


class Function(_C._FunctionBase):
    # Old-style autograd function: calling an instance runs the C-level
    # _do_forward, which invokes forward() and records the operation in the
    # autograd graph.
    __call__ = _C._FunctionBase._do_forward

    def save_for_backward(self, *tensors):
        # Stash tensors needed by backward(); they are exposed again through
        # the saved_tensors attribute.
        self.to_save = tensors

    def mark_dirty(self, *args):
        # Declare inputs that forward() modified in-place.
        self.dirty_tensors = args

    def mark_shared_storage(self, *pairs):
        # Declare (input, output) pairs that share the same storage.
        self.shared_pairs = pairs

    def mark_non_differentiable(self, *args):
        # Declare outputs that should never receive a gradient.
        self.non_differentiable = args

    def register_hook(self, name, hook):
        self._backward_hooks = self._backward_hooks or OrderedDict()
        assert name not in self._backward_hooks, \
            "Trying to register a second hook with name {}".format(name)
        self._backward_hooks[name] = hook

    def remove_hook(self, name):
        assert self._backward_hooks and name in self._backward_hooks, \
            "Trying to remove a nonexistent hook with name {}".format(name)
        del self._backward_hooks[name]

    def forward(self, *input):
        raise NotImplementedError

    def backward(self, *grad_output):
        raise NotImplementedError

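
# Illustrative sketch (assumption, not part of the original file): a minimal
# old-style function. forward() and backward() operate on tensors and run on
# the instance, so state flows through save_for_backward(); an instance is
# applied by calling it, e.g. out = _ExampleMul()(a, b).
class _ExampleMul(Function):

    def forward(self, a, b):
        self.save_for_backward(a, b)
        return a.mul(b)

    def backward(self, grad_output):
        a, b = self.saved_tensors
        return grad_output.mul(b), grad_output.mul(a)
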

class InplaceFunction(Function):

    def __init__(self, inplace=False):
        super(InplaceFunction, self).__init__()
        self.inplace = inplace

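
# Illustrative sketch (assumption, not part of the original file): subclasses
# typically branch on self.inplace and must mark_dirty() anything they modify.
class _ExampleAddOne(InplaceFunction):

    def forward(self, x):
        if self.inplace:
            self.mark_dirty(x)
            return x.add_(1)
        return x.add(1)

    def backward(self, grad_output):
        return grad_output
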

def _nested_map(condition, fn):
    # Apply fn to every object matching condition, recursing through (possibly
    # nested) lists and tuples while preserving their structure.
    def _map(obj):
        if condition(obj):
            return fn(obj)
        elif obj is None:
            return None
        elif isinstance(obj, (list, tuple)):
            return type(obj)(_map(x) for x in obj)
        else:
            raise ValueError("NestedIOFunction doesn't know how to process "
                             "an input object of type " + torch.typename(obj))
    return _map


def _iter_filter(condition):
    # Yield every object matching condition from a (possibly nested) structure
    # of lists and tuples, i.e. flatten it; Nones are silently skipped.
    def _iter(obj):
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                for var in _iter(o):
                    yield var
        else:
            raise ValueError("NestedIOFunction doesn't know how to process "
                             "an input object of type " + torch.typename(obj))
    return _iter

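
def _example_nested_helpers():
    # Illustrative sketch (assumption, not part of the original file): both
    # helpers are generic over the condition; with ints standing in for
    # tensors, _nested_map rebuilds the structure and _iter_filter flattens it.
    double = _nested_map(lambda o: isinstance(o, int), lambda o: o * 2)
    assert double((1, [2, 3])) == (2, [4, 6])
    flat = _iter_filter(lambda o: isinstance(o, int))
    assert list(flat((1, [2, None, 3]))) == [1, 2, 3]
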

_iter_variables = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable))
_iter_tensors = _iter_filter(torch.is_tensor)
_iter_None_tensors = _iter_filter(lambda o: o is None or torch.is_tensor(o))
_map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable), lambda o: o.data)


def _map_tensor_fromiter(itr):
    # Build a mapper that walks a nested template and replaces each tensor
    # slot with the next value drawn from itr (used to re-nest a flat
    # sequence of results).
    return _nested_map(lambda o: torch.is_tensor(o), lambda o: next(itr))

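
def _example_unflatten():
    # Illustrative sketch (assumption, not part of the original file): pairing
    # a flat iterator with a nested template pours flat results back into the
    # original nesting; shown with strings standing in for tensors.
    it = iter([10, 20])
    unflatten = _nested_map(lambda o: isinstance(o, str), lambda o: next(it))
    assert unflatten(("a", ["b"])) == (10, [20])
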

class NestedIOFunction(Function):
    # Function variant whose inputs and outputs may be arbitrarily nested
    # lists/tuples of variables: everything is flattened before reaching the
    # base class and re-nested on the way out.

    def _do_forward(self, *input):
        self._nested_input = input
        flat_input = tuple(_iter_variables(input))
        flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
        nested_variables = _map_tensor_fromiter(iter(flat_output))(self._nested_output)
        return nested_variables

    def backward(self, *gradients):
        nested_gradients = _map_tensor_fromiter(iter(gradients))(self._nested_output)
        del self._nested_output
        result = self.backward_extended(*nested_gradients)
        del self._to_save_nested
        return tuple(_iter_None_tensors(result))

    __call__ = _do_forward

    def forward(self, *args):
        # Invoked by the base class with flat tensor args (ignored); the
        # nested input saved in _do_forward is unwrapped to tensors, handed
        # to forward_extended, and the nested result is flattened.
        nested_tensors = _map_variable_tensor(self._nested_input)
        result = self.forward_extended(*nested_tensors)
        del self._nested_input
        self._nested_output = result
        return tuple(_iter_tensors(result))

    def save_for_backward(self, *args):
        # Keep both the flat view (for the base class) and the nested one.
        self.to_save = tuple(_iter_tensors(args))
        self._to_save_nested = args

    @property
    def saved_tensors(self):
        flat_tensors = super(NestedIOFunction, self).saved_tensors
        return _map_tensor_fromiter(iter(flat_tensors))(self._to_save_nested)

    def mark_dirty(self, *args, **kwargs):
        # kwargs values are flattened along with the positional args.
        self.dirty_tensors = tuple(_iter_tensors(args + tuple(kwargs.values())))

    def mark_non_differentiable(self, *args, **kwargs):
        self.non_differentiable = tuple(_iter_tensors(args + tuple(kwargs.values())))

    def forward_extended(self, *input):
        raise NotImplementedError

    def backward_extended(self, *grad_output):
        raise NotImplementedError

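
# Illustrative sketch (assumption, not part of the original file): a subclass
# implements the *_extended hooks and may take and return nested structures;
# the machinery above flattens inputs/outputs and re-nests gradients.
class _ExamplePair(NestedIOFunction):

    def forward_extended(self, pair):
        a, b = pair
        self.save_for_backward(a, b)
        return a.add(b), [a.sub(b)]

    def backward_extended(self, grad_sum, grad_diffs):
        grad_diff, = grad_diffs
        # gradients w.r.t. the two tensors inside the input pair
        return grad_sum.add(grad_diff), grad_sum.sub(grad_diff)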