mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Improves autograd performance by more than 2x and fixes a couple of bugs. All core functions have been moved to C.
46 lines
1.2 KiB
Python
import torch._C as _C
|
|
from collections import OrderedDict
|
|
from itertools import chain
|
|
|
|
|
|
class Function(_C._FunctionBase):
    """Base class for user-defined autograd operations.

    Most of the machinery lives in the C base class ``_C._FunctionBase``.
    This Python layer only records metadata as instance attributes: saved
    tensors, dirty/shared/non-differentiable markings and backward hooks.
    Subclasses must override :meth:`forward` and :meth:`backward`.

    NOTE(review): the attribute names written below (``to_save``,
    ``dirty_tensors``, ``shared_pairs``, ``non_differentiable``,
    ``backward_hooks``) are presumably read by the C implementation.
    Confirm against ``_C._FunctionBase`` before renaming any of them.
    """

    # Calling an instance dispatches directly to the C-implemented
    # ``_do_forward`` instead of a Python-level ``__call__``.
    __call__ = _C._FunctionBase._do_forward

    def save_for_backward(self, *tensors):
        """Store ``tensors`` (as a tuple) in ``self.to_save``.

        Each call replaces any tuple saved by a previous call.
        """
        self.to_save = tensors

    def mark_dirty(self, *args):
        """Store the given tensors (as a tuple) in ``self.dirty_tensors``.

        Presumably these are the inputs that ``forward`` modified in place.
        TODO confirm against the C backend.
        """
        self.dirty_tensors = args

    def mark_shared_storage(self, *pairs):
        """Store ``pairs`` (as a tuple) in ``self.shared_pairs``.

        Presumably each element is an (input, output) pair that shares
        storage. TODO confirm the expected pair layout.
        """
        self.shared_pairs = pairs

    def mark_non_differentiable(self, *args):
        """Store the given tensors (as a tuple) in ``self.non_differentiable``."""
        self.non_differentiable = args

    def register_hook(self, name, hook):
        """Register ``hook`` under ``name`` in ``self.backward_hooks``.

        The hook container is an ``OrderedDict`` created lazily on first
        use, so hooks are kept in registration order.

        Raises:
            AssertionError: if a hook named ``name`` is already registered.
        """
        # ``backward_hooks`` is read before this class ever assigns it, so it
        # is presumably initialised (likely to None) by _C._FunctionBase.
        self.backward_hooks = self.backward_hooks or OrderedDict()
        # Raise explicitly rather than using ``assert``: the check must still
        # run under ``python -O``. AssertionError is kept so callers see the
        # same exception type as before.
        if name in self.backward_hooks:
            raise AssertionError(
                "Trying to register a second hook with name {}".format(name))
        self.backward_hooks[name] = hook

    def remove_hook(self, name):
        """Remove the hook registered under ``name``.

        Raises:
            AssertionError: if no hooks are registered, or none is
                registered under ``name``.
        """
        # Explicit raise instead of ``assert`` so that, under ``python -O``,
        # the caller still gets this message and not a confusing
        # TypeError/KeyError from the ``del`` below.
        if not self.backward_hooks or name not in self.backward_hooks:
            raise AssertionError(
                "Trying to remove an inexistent hook with name {}".format(name))
        del self.backward_hooks[name]

    def forward(self, *input):
        """Compute the operation's outputs; must be overridden by subclasses."""
        raise NotImplementedError

    def backward(self, *grad_output):
        """Compute gradients w.r.t. the inputs; must be overridden by subclasses."""
        raise NotImplementedError
|
|
|
|
|
|
class InplaceFunction(Function):
    """A :class:`Function` that carries an ``inplace`` flag.

    This class only stores the flag. Presumably concrete subclasses read
    ``self.inplace`` to decide whether to write results into their input.
    TODO confirm against the subclasses.
    """

    def __init__(self, inplace=False):
        # The base-class initialiser must run before any attribute is
        # attached. The flag is then stored exactly as passed (no coercion).
        super(InplaceFunction, self).__init__()
        self.inplace = inplace
|
|
|