from torch.autograd import Variable


class Parameter(Variable):
    """A kind of Variable that is to be considered a module parameter.

    Parameters are :class:`~torch.autograd.Variable` subclasses that have a
    very special property when used with :class:`Module` s - when they're
    assigned as Module attributes, they are automatically added to the list of
    its parameters, and will appear e.g. in the :meth:`~Module.parameters`
    iterator. Assigning a Variable doesn't have such an effect. This is
    because one might want to cache some temporary state, like the last hidden
    state of the RNN, in the model. If there were no such class as
    :class:`Parameter`, these temporaries would get registered too.

    Another difference is that parameters can't be volatile and that they
    require gradient by default.

    Arguments:
        data (Tensor): parameter tensor.
        requires_grad (bool, optional): if the parameter requires gradient. See
            :ref:`excluding-subgraphs` for more details.
    """

    def __init__(self, data, requires_grad=True):
        super(Parameter, self).__init__(data, requires_grad=requires_grad)

    def __deepcopy__(self, memo):
        result = type(self)(self.data.clone(), self.requires_grad)
        memo[id(self)] = result
        return result

    def __repr__(self):
        return 'Parameter containing:' + self.data.__repr__()
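

# A minimal usage sketch (not part of the upstream file) illustrating the
# registration behaviour described in the docstring above. It assumes the
# pre-0.4 API in which Variable and Parameter are distinct classes; the
# module class `Demo` and its attribute names are made up for the example.
if __name__ == '__main__':
    import torch
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super(Demo, self).__init__()
            # Assigned as a Parameter: automatically registered, so it shows
            # up in self.parameters() and can be updated by an optimizer.
            self.weight = Parameter(torch.randn(2, 3))
            # Assigned as a plain Variable: kept only as cached state on the
            # module and NOT returned by self.parameters().
            self.last_hidden = Variable(torch.zeros(2))

    m = Demo()
    # Only `weight` is registered; `last_hidden` is not.
    assert len(list(m.parameters())) == 1
    print(m.weight)  # prints "Parameter containing: ..." via __repr__ above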