mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: According to pytorch/rfcs#3

From the goals in the RFC:

1. Support subclassing `torch.Tensor` in Python (done here)
2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them (done here)
3. Use the PyTorch API with `torch.Tensor`-like objects that are _not_ `torch.Tensor` subclasses (done in https://github.com/pytorch/pytorch/issues/30730)
4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods. (done here)
5. Propagating subclass instances correctly also with operators, using views/slices/indexing/etc. (done here)
6. Preserve subclass attributes when using methods or views/slices/indexing. (done here)
7. A way to insert code that operates on both functions and methods uniformly (so we can write a single function that overrides all operators). (done here)
8. The ability to give external libraries a way to also define functions/methods that follow the `__torch_function__` protocol. (will be addressed in a separate PR)

This PR makes the following changes:

1. Adds the `self` argument to the arg parser.
2. Dispatches on `self` as well if `self` is not `nullptr`.
3. Adds a `torch._C.DisableTorchFunction` context manager to disable `__torch_function__`.
4. Adds a `torch::torch_function_enabled()` and `torch._C._torch_function_enabled()` to check the state of `__torch_function__`.
5. Dispatches all `torch._C.TensorBase` and `torch.Tensor` methods via `__torch_function__`.

TODO:
- [x] Sequence Methods
- [x] Docs
- [x] Tests

Closes https://github.com/pytorch/pytorch/issues/28361

Benchmarks in https://github.com/pytorch/pytorch/pull/37091#issuecomment-633657778

Pull Request resolved: https://github.com/pytorch/pytorch/pull/37091
Reviewed By: ngimel
Differential Revision: D22765678
Pulled By: ezyang
fbshipit-source-id: 53f8aa17ddb8b1108c0997f6a7aa13cb5be73de0
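The subclass-preservation goals in the summary (items 2, 4 and 5) and the `torch._C.DisableTorchFunction` context manager can be illustrated with a small sketch. It assumes a PyTorch release that includes this change; `TaggedTensor` is a hypothetical subclass made up for illustration, and `torch._C.DisableTorchFunction` is a private API that may change between releases.

```python
import torch


class TaggedTensor(torch.Tensor):
    """Hypothetical Tensor subclass, used only for illustration."""


# as_subclass reinterprets an existing tensor as an instance of the subclass.
t = torch.ones(3).as_subclass(TaggedTensor)

# torch functions, Tensor methods, operators and indexing all dispatch through
# __torch_function__, whose default implementation wraps outputs in the subclass.
results = {
    "torch.add": torch.add(t, 1),
    "method": t.add(1),
    "operator": t + 1,
    "indexing": t[:2],
}
for name, value in results.items():
    print(name, type(value).__name__)  # prints TaggedTensor for every entry

# Inside DisableTorchFunction, __torch_function__ is skipped entirely,
# so the same operator returns a plain torch.Tensor.
with torch._C.DisableTorchFunction():
    plain = t + 1
print(type(plain).__name__)  # Tensor
```

Because the default `torch.Tensor.__torch_function__` is a single hook, one override on the subclass would intercept functions, methods and operators uniformly, which is goal 7 above.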
47 lines
1.8 KiB
Python
import torch
from torch._C import _disabled_torch_function_impl
from collections import OrderedDict


class Parameter(torch.Tensor):
    r"""A kind of Tensor that is to be considered a module parameter.

    Parameters are :class:`~torch.Tensor` subclasses that have a
    very special property when used with :class:`Module` s: when they're
    assigned as Module attributes they are automatically added to the list of
    its parameters, and will appear, e.g., in the :meth:`~Module.parameters` iterator.
    Assigning a plain Tensor doesn't have such an effect. This is because one might
    want to cache some temporary state, like the last hidden state of an RNN, in
    the model. If there were no such class as :class:`Parameter`, these
    temporaries would get registered too.

    Arguments:
        data (Tensor): parameter tensor.
        requires_grad (bool, optional): if the parameter requires gradient. See
            :ref:`excluding-subgraphs` for more details. Default: `True`
    """
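
    def __new__(cls, data=None, requires_grad=True):
        # Fall back to an empty tensor when no data is given.
        if data is None:
            data = torch.Tensor()
        # Re-wrap a detached alias of `data` (sharing its storage) as an
        # instance of `cls`, with the requested requires_grad flag.
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        # Reuse the existing copy if this Parameter was already copied
        # during the current deepcopy call.
        if id(self) in memo:
            return memo[id(self)]
        else:
            # Clone the data (keeping its memory format) so the copy owns new storage.
            result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
            memo[id(self)] = result
            return result

    def __repr__(self):
        return 'Parameter containing:\n' + super(Parameter, self).__repr__()

    def __reduce_ex__(self, proto):
        # See Note [Don't serialize hooks]
        # Pickle as a call to _rebuild_parameter with the data, the
        # requires_grad flag and an empty hooks dict instead of the real hooks.
        return (
            torch._utils._rebuild_parameter,
            (self.data, self.requires_grad, OrderedDict())
        )

    # Opt out of __torch_function__ subclass propagation: operations on a
    # Parameter return plain Tensors instead of Parameters.
    __torch_function__ = _disabled_torch_function_impl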
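A short sketch can exercise the behaviour this class implements: registration on assignment (described in the docstring), the disabled `__torch_function__`, `__deepcopy__`, and pickling through `__reduce_ex__`. It assumes a recent PyTorch in which `torch.nn.Parameter` is this class; `Cache` and its attributes are hypothetical names chosen to mirror the docstring's cached-hidden-state example, and plain `pickle` stands in for `torch.save`.

```python
import copy
import pickle

import torch
from torch import nn


class Cache(nn.Module):
    """Hypothetical module with one Parameter and one plain cached tensor."""

    def __init__(self):
        super().__init__()
        # Assigning a Parameter registers it as a module parameter.
        self.weight = nn.Parameter(torch.randn(2, 2))
        # Assigning a plain Tensor only stores an attribute; nothing is registered.
        self.last_hidden = torch.zeros(2)


m = Cache()
names = [name for name, _ in m.named_parameters()]
print(names)  # ['weight']

p = m.weight
# __torch_function__ is disabled, so operations return plain Tensors.
print(type(p * 2).__name__)  # Tensor

# __deepcopy__ clones the data into a new Parameter that keeps requires_grad.
q = copy.deepcopy(p)
print(type(q).__name__, q.requires_grad)  # Parameter True

# The memo makes a Parameter that appears twice come out as a single copy.
a, b = copy.deepcopy([p, p])
print(a is b)  # True

# __reduce_ex__ lets a Parameter survive a pickle round trip.
restored = pickle.loads(pickle.dumps(p))
print(type(restored).__name__, restored.requires_grad)  # Parameter True
```

Because `named_parameters()` lists only `weight`, an optimizer built from `m.parameters()` would never update `last_hidden`, which is exactly why temporaries are kept as plain tensors.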