import torch
import torch._C as _C
from collections import OrderedDict


class Function(_C._FunctionBase):
    """Base class for differentiable operations.

    Subclasses implement forward() and backward(); calling an instance
    routes through the C-level _do_forward machinery.
    """

    __call__ = _C._FunctionBase._do_forward

    def save_for_backward(self, *tensors):
        # Stash tensors needed by backward(); retrieved via saved_tensors.
        self.to_save = tensors

    def mark_dirty(self, *args):
        # Declare inputs that forward() modified in-place.
        self.dirty_tensors = args

    def mark_shared_storage(self, *pairs):
        # Declare (input, output) pairs that share the same storage.
        self.shared_pairs = pairs

    def mark_non_differentiable(self, *args):
        # Declare outputs that should never require gradients.
        self.non_differentiable = args

    def register_hook(self, name, hook):
        self._backward_hooks = self._backward_hooks or OrderedDict()
        assert name not in self._backward_hooks, \
            "Trying to register a second hook with name {}".format(name)
        self._backward_hooks[name] = hook

    def remove_hook(self, name):
        assert self._backward_hooks and name in self._backward_hooks, \
            "Trying to remove a non-existent hook with name {}".format(name)
        del self._backward_hooks[name]

    def forward(self, *input):
        raise NotImplementedError

    def backward(self, *grad_output):
        raise NotImplementedError


class InplaceFunction(Function):
    """A Function that can optionally operate in-place on its inputs."""

    def __init__(self, inplace=False):
        super(InplaceFunction, self).__init__()
        self.inplace = inplace


def _nested_map(condition, fn):
    # Apply fn to every object satisfying condition inside an arbitrarily
    # nested structure of lists/tuples, preserving the structure.
    def _map(obj):
        if condition(obj):
            return fn(obj)
        elif obj is None:
            return None
        elif isinstance(obj, (list, tuple)):
            return type(obj)(_map(x) for x in obj)
        else:
            raise ValueError("NestedIOFunction doesn't know how to process "
                             "an input object of type " + torch.typename(obj))
    return _map


def _iter_filter(condition):
    # Yield, in order, every object satisfying condition inside an
    # arbitrarily nested structure of lists/tuples.
    def _iter(obj):
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                for var in _iter(o):
                    yield var
        else:
            raise ValueError("NestedIOFunction doesn't know how to process "
                             "an input object of type " + torch.typename(obj))
    return _iter


_iter_variables = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable))
_iter_tensors = _iter_filter(torch.is_tensor)
_iter_None_tensors = _iter_filter(lambda o: o is None or torch.is_tensor(o))
_map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable),
                                   lambda o: o.data)


def _map_tensor_fromiter(itr):
    # Rebuild a nested structure, filling each tensor slot with the next
    # value drawn from itr (used to re-nest flat outputs and gradients).
    return _nested_map(lambda o: torch.is_tensor(o), lambda o: next(itr))


class NestedIOFunction(Function):
    """A Function whose inputs and outputs may be nested lists/tuples.

    Arguments are flattened before being handed to the base machinery and
    re-nested afterwards; subclasses implement forward_extended() and
    backward_extended() in terms of the nested structures.
    """

    def _do_forward(self, *input):
        self._nested_input = input
        flat_input = tuple(_iter_variables(input))
        flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
        # Re-nest the flat output Variables into the structure that
        # forward() recorded in self._nested_output.
        nested_variables = _map_tensor_fromiter(iter(flat_output))(self._nested_output)
        return nested_variables

    def backward(self, *gradients):
        nested_gradients = _map_tensor_fromiter(iter(gradients))(self._nested_output)
        del self._nested_output
        result = self.backward_extended(*nested_gradients)
        del self._to_save_nested
        return tuple(_iter_None_tensors(result))

    __call__ = _do_forward

    def forward(self, *args):
        nested_tensors = _map_variable_tensor(self._nested_input)
        result = self.forward_extended(*nested_tensors)
        del self._nested_input
        self._nested_output = result
        return tuple(_iter_tensors(result))

    def save_for_backward(self, *args):
        self.to_save = tuple(_iter_tensors(args))
        self._to_save_nested = args

    @property
    def saved_tensors(self):
        flat_tensors = super(NestedIOFunction, self).saved_tensors
        return _map_tensor_fromiter(iter(flat_tensors))(self._to_save_nested)

    def mark_dirty(self, *args, **kwargs):
        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

    def mark_non_differentiable(self, *args, **kwargs):
        self.non_differentiable = tuple(_iter_tensors((args, kwargs)))

    def forward_extended(self, *input):
        raise NotImplementedError

    def backward_extended(self, *grad_output):
        raise NotImplementedError
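
# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal subclass of the legacy, instance-based Function API defined
# above: forward() and backward() operate on tensors, and intermediate
# state is stashed with save_for_backward(). The class name and the
# commented-out demo below are assumptions for illustration only.

class _ExampleMul(Function):

    def forward(self, a, b):
        # Keep the inputs for the backward pass; they come back through
        # the saved_tensors attribute provided by _FunctionBase.
        self.save_for_backward(a, b)
        return a * b

    def backward(self, grad_output):
        a, b = self.saved_tensors
        # d(a*b)/da = b and d(a*b)/db = a, chained with grad_output.
        return grad_output * b, grad_output * a

# Expected call pattern (assuming the legacy pre-0.2 autograd interface,
# where a fresh Function instance is created per application):
#
#   x = torch.autograd.Variable(torch.randn(3), requires_grad=True)
#   y = torch.autograd.Variable(torch.randn(3), requires_grad=True)
#   z = _ExampleMul()(x, y)   # __call__ routes through _do_forward
#   z.backward(torch.ones(3))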