pytorch/torch/nn/parameter.py
Samuel Marks e6779d4357 [*.py] Rename "Arguments:" to "Args:" (#49736)
Summary:
I've written custom parsers and emitters for everything from docstrings to classes and functions. However, I recently came across an issue when I was parsing/generating from the TensorFlow codebase: inconsistent use of `Args:` and `Arguments:` in its docstrings.

```sh
(pytorch#c348fae)$ for name in 'Args:' 'Arguments:'; do
    printf '%-10s %04d\n' "$name" "$(rg -IFtpy --count-matches "$name" | paste -s -d+ -- | bc)"; done
Args:      1095
Arguments: 0336
```

It is easy enough to extend my parsers to support both variants; however, it looks like `Arguments:` is wrong anyway, as per:

  - https://google.github.io/styleguide/pyguide.html#doc-function-args @ [`ddccc0f`](https://github.com/google/styleguide/blob/ddccc0f/pyguide.md)

  - https://chromium.googlesource.com/chromiumos/docs/+/master/styleguide/python.md#describing-arguments-in-docstrings @ [`9fc0fc0`](https://chromium.googlesource.com/chromiumos/docs/+/9fc0fc0/styleguide/python.md)

  - https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html @ [`c0ae8e3`](https://github.com/sphinx-contrib/napoleon/blob/c0ae8e3/docs/source/example_google.rst)

Therefore, only `Args:` is valid. This PR replaces them throughout the codebase.

PS: For related PRs, see tensorflow/tensorflow/pull/45420

PPS: The trackbacks automatically appearing below are sending the same changes to other repositories in the [PyTorch](https://github.com/pytorch) organisation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49736

Reviewed By: albanD

Differential Revision: D25710534

Pulled By: soumith

fbshipit-source-id: 61e8ff01abb433e9f78185c2d1d0cbd7c22c1619
2020-12-28 09:34:47 -08:00

141 lines
5.6 KiB
Python

import torch
from torch._C import _disabled_torch_function_impl
from collections import OrderedDict
class Parameter(torch.Tensor):
    r"""A kind of Tensor that is to be considered a module parameter.

    Parameters are :class:`~torch.Tensor` subclasses, that have a
    very special property when used with :class:`Module` s - when they're
    assigned as Module attributes they are automatically added to the list of
    its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator.
    Assigning a Tensor doesn't have such effect. This is because one might
    want to cache some temporary state, like last hidden state of the RNN, in
    the model. If there was no such class as :class:`Parameter`, these
    temporaries would get registered too.

    Args:
        data (Tensor): parameter tensor. When omitted, an empty tensor
            (``torch.Tensor()``) is used.
        requires_grad (bool, optional): if the parameter requires gradient. See
            :ref:`excluding-subgraphs` for more details. Default: `True`
    """

    def __new__(cls, data=None, requires_grad=True):
        # Fall back to an empty tensor so that ``Parameter()`` is valid.
        source = torch.Tensor() if data is None else data
        # Re-type ``source`` as ``cls`` with the requested requires_grad flag.
        return torch.Tensor._make_subclass(cls, source, requires_grad)

    def __deepcopy__(self, memo):
        """Deep-copy by cloning the underlying data into a fresh parameter.

        ``memo`` is the dict that :func:`copy.deepcopy` threads through a
        copy pass. Consulting it first keeps objects that are shared in the
        original shared in the copy.
        """
        key = id(self)
        if key in memo:
            return memo[key]
        cloned = self.data.clone(memory_format=torch.preserve_format)
        duplicate = type(self)(cloned, self.requires_grad)
        memo[key] = duplicate
        return duplicate

    def __repr__(self):
        header = 'Parameter containing:\n'
        return header + super().__repr__()

    def __reduce_ex__(self, proto):
        # See Note [Don't serialize hooks]
        # The hooks slot is always an empty OrderedDict, so hooks never
        # round-trip through pickling.
        rebuild_args = (self.data, self.requires_grad, OrderedDict())
        return torch._utils._rebuild_parameter, rebuild_args

    # Plain tensor semantics: operations on a Parameter are not routed
    # through a custom __torch_function__ override.
    __torch_function__ = _disabled_torch_function_impl
class UninitializedParameter(Parameter):
    r"""A parameter that is not initialized.

    Uninitialized Parameters are a special case of :class:`torch.nn.Parameter`
    where the shape of the data is still unknown.

    Unlike a :class:`torch.nn.Parameter`, uninitialized parameters
    hold no data and attempting to access some properties, like their shape,
    will throw a runtime error. The only operations that can be performed on an
    uninitialized parameter are changing its datatype, moving it to a different
    device and converting it to a regular :class:`torch.nn.Parameter`.

    Args:
        requires_grad (bool, optional): if the parameter requires gradient.
            Default: `True`
        device (:class:`torch.device`, optional): device of the empty
            placeholder tensor. Default: same as ``torch.Tensor()``.
        dtype (:class:`torch.dtype`, optional): dtype of the empty placeholder
            tensor. Default: same as ``torch.Tensor()``.
    """

    # Tensor functions that ``__torch_function__`` lets through; any other
    # torch function called on an uninitialized parameter raises ValueError.
    # NOTE(review): ``_has_compatible_shallow_copy_type`` is presumably listed
    # so that ``param.data = ...`` style reassignment keeps working — confirm.
    _allowed_methods = [
        torch.Tensor.__hash__,
        torch.Tensor.size,
        torch.Tensor.copy_,
        torch.Tensor.is_floating_point,
        torch.Tensor.half,
        torch.Tensor.float,
        torch.Tensor.double,
        torch.Tensor.char,
        torch.Tensor.short,
        torch.Tensor.int,
        torch.Tensor.long,
        torch.Tensor.cuda,
        torch.Tensor.cpu,
        torch.Tensor.to,
        torch.Tensor.get_device,
        torch._has_compatible_shallow_copy_type]

    def __new__(cls, requires_grad=True, device=None, dtype=None):
        if device is None and dtype is None:
            # Original behaviour: an empty tensor of the default tensor type.
            data = torch.Tensor()
        else:
            data = torch.empty(0, device=device, dtype=dtype)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        """Return a new uninitialized parameter with the same properties.

        ``Parameter.__deepcopy__`` cannot be reused here for two reasons.
        It clones ``self.data``, but ``clone`` is not in ``_allowed_methods``.
        It also calls ``type(self)(data, requires_grad)``, which does not
        match this class's constructor. Because there are no values to copy,
        a fresh placeholder is built instead. It keeps ``requires_grad`` and
        the current device/dtype of the placeholder data.
        """
        if id(self) in memo:
            return memo[id(self)]
        placeholder = self.data
        result = type(self)(self.requires_grad, placeholder.device, placeholder.dtype)
        memo[id(self)] = result
        return result

    def materialize(self, shape, device=None, dtype=None):
        r"""Turn this uninitialized parameter into a regular Parameter in place.

        Allocates uninitialized memory (via ``torch.empty``) of the given
        ``shape`` as this parameter's data and changes its class to
        :class:`torch.nn.Parameter`.

        Args:
            shape (tuple): the shape for the materialized tensor.
            device (:class:`torch.device`, optional): the desired device of the
                materialized parameter. Default: the current device of this
                parameter.
            dtype (:class:`torch.dtype`, optional): the desired dtype of the
                materialized parameter. Default: the current dtype of this
                parameter.
        """
        if device is None:
            device = self.data.device
        if dtype is None:
            dtype = self.data.dtype
        self.data = torch.empty(shape, device=device, dtype=dtype)
        # Swap the class so the object now behaves as an ordinary Parameter.
        self.__class__ = Parameter

    @property
    def shape(self):
        raise RuntimeError(
            'Can\'t access the shape of an uninitialized parameter. '
            'This error usually happens in `load_state_dict` when trying to load '
            'an uninitialized parameter into an initialized one. '
            'Call `forward` to initialize the parameters before accessing their attributes.')

    def share_memory_(self):
        raise RuntimeError(
            'Can\'t share memory on an uninitialized parameter. '
            'Call `forward` to initialize the parameters before calling '
            '`module.share_memory()`.')

    def __repr__(self):
        return 'Uninitialized parameter'

    def __reduce_ex__(self, proto):
        # See Note [Don't serialize hooks]
        # Only ``requires_grad`` is pickled. On unpickling, the placeholder is
        # rebuilt with the default device/dtype.
        return (
            UninitializedParameter,
            (self.requires_grad,)
        )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # method-wrapper is to detect access to Tensor properties that are
        # wrapped in descriptors
        if func in cls._allowed_methods or func.__class__.__name__ == 'method-wrapper':
            if kwargs is None:
                kwargs = {}
            return super().__torch_function__(func, types, args, kwargs)
        raise ValueError(
            'Attempted to use an uninitialized parameter in {}. '
            'This error happens when you are using a `LazyModule` or '
            'explicitly manipulating `torch.nn.parameter.UninitializedParameter` '
            'objects. When using LazyModules Call `forward` with a dummy batch '
            'to initialize the parameters before calling torch functions'.format(func))