import warnings
from collections import OrderedDict
from torch._six import container_abcs
from itertools import islice
import operator

import torch
from .module import Module
from torch._jit_internal import _copy_to_script_wrapper

from typing import Any, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union


T = TypeVar('T')


class Container(Module):

    def __init__(self, **kwargs: Any) -> None:
        super(Container, self).__init__()
        # DeprecationWarning is ignored by default <sigh>
        warnings.warn("nn.Container is deprecated. All of its functionality "
                      "is now implemented in nn.Module. Subclass that instead.")
        for key, value in kwargs.items():
            self.add_module(key, value)
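

# The deprecation warning above points at plain nn.Module subclassing as the
# replacement. A minimal sketch of that pattern (hypothetical layer choices;
# assumes ``import torch.nn as nn``):
#
#     class Net(nn.Module):
#         def __init__(self):
#             super(Net, self).__init__()
#             self.conv = nn.Conv2d(1, 20, 5)   # attribute assignment registers it
#
#         def forward(self, x):
#             return self.conv(x)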


class Sequential(Module):
    r"""A sequential container.
    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.

    To make it easier to understand, here is a small example::

        # Example of using Sequential
        model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Example of using Sequential with OrderedDict
        model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
    """

    @overload
    def __init__(self, *args: Module) -> None:
        ...

    @overload
    def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
        ...

    def __init__(self, *args: Any):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx):
        """Get the idx-th item of the iterator"""
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError('index {} is out of range'.format(idx))
        idx %= size
        return next(islice(iterator, idx, None))

    @_copy_to_script_wrapper
    def __getitem__(self: T, idx) -> T:
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: Module) -> None:
        key = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __dir__(self):
        keys = super(Sequential, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    # NB: We can't really type check this function as the type of input
    # may change dynamically (as is tested in
    # TestScript.test_sequential_intermediary_types). Cannot annotate
    # with Any as TorchScript expects a more precise type
    def forward(self, input):
        for module in self:
            input = module(input)
        return input
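

# A minimal indexing sketch (assumes ``import torch.nn as nn``; not part of
# this module): per ``__getitem__`` above, integer indexing returns the
# contained module, while slicing returns a new ``Sequential`` wrapping the
# selected modules.
#
#     model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
#     model[1]        # the nn.ReLU module itself
#     model[:2]       # a new nn.Sequential holding Linear and ReLU
#     model[-1]       # negative indices wrap around, as in a Python list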


class ModuleList(Module):
    r"""Holds submodules in a list.

    :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    Arguments:
        modules (iterable, optional): an iterable of modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

            def forward(self, x):
                # ModuleList can act as an iterable, or be indexed using ints
                for i, l in enumerate(self.linears):
                    x = self.linears[i // 2](x) + l(x)
                return x
    """

    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
        super(ModuleList, self).__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules"""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError('index {} is out of range'.format(idx))
        if idx < 0:
            idx += len(self)
        return str(idx)

    @_copy_to_script_wrapper
    def __getitem__(self, idx: int) -> Module:
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: Module) -> None:
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # To preserve numbering, self._modules is reconstructed from the
        # remaining modules after deletion.
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    def __iadd__(self: T, modules: Iterable[Module]) -> T:
        return self.extend(modules)

    @_copy_to_script_wrapper
    def __dir__(self):
        keys = super(ModuleList, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def insert(self, index: int, module: Module) -> None:
        r"""Insert a given module before a given index in the list.

        Arguments:
            index (int): index to insert.
            module (nn.Module): module to insert
        """
        for i in range(len(self._modules), index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def append(self: T, module: Module) -> T:
        r"""Appends a given module to the end of the list.

        Arguments:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def extend(self: T, modules: Iterable[Module]) -> T:
        r"""Appends modules from a Python iterable to the end of the list.

        Arguments:
            modules (iterable): iterable of modules to append
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleList.extend should be called with an "
                            "iterable, but got " + type(modules).__name__)
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

    def forward(self):
        raise NotImplementedError()
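

# A minimal mutation sketch (assumes ``import torch.nn as nn``; not part of
# this module): submodules live under the contiguous string keys '0'..'N-1',
# and insert/extend/delete keep that numbering dense, as the methods above show.
#
#     layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
#     layers.insert(1, nn.ReLU())              # shifts later entries up by one key
#     layers.extend([nn.Tanh(), nn.Sigmoid()]) # appended under keys '4' and '5'
#     del layers[0]                            # remaining modules are renumbered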


class ModuleDict(Module):
    r"""Holds submodules in a dictionary.

    :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
    but modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
      ``OrderedDict``, ``dict`` (as of Python 3.6) or another
      :class:`~torch.nn.ModuleDict` (the argument to
      :meth:`~torch.nn.ModuleDict.update`).

    Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict`` before Python version 3.6) does not
    preserve the order of the merged mapping.

    Arguments:
        modules (iterable, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module)

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.choices = nn.ModuleDict({
                        'conv': nn.Conv2d(10, 10, 3),
                        'pool': nn.MaxPool2d(3)
                })
                self.activations = nn.ModuleDict([
                        ['lrelu', nn.LeakyReLU()],
                        ['prelu', nn.PReLU()]
                ])

            def forward(self, x, choice, act):
                x = self.choices[choice](x)
                x = self.activations[act](x)
                return x
    """

    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
        super(ModuleDict, self).__init__()
        if modules is not None:
            self.update(modules)

    @_copy_to_script_wrapper
    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    @_copy_to_script_wrapper
    def __contains__(self, key: str) -> bool:
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Arguments:
            key (string): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    @_copy_to_script_wrapper
    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~torch.nn.ModuleDict` with the key-value pairs from a
        mapping or an iterable, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Arguments:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(modules).__name__)

        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
        else:
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError("ModuleDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is " +
                                    type(m).__name__)
                if not len(m) == 2:
                    raise ValueError("ModuleDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(m)) +
                                     "; 2 is required")
                self[m[0]] = m[1]

    def forward(self):
        raise NotImplementedError()
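

# A minimal update-order sketch (assumes ``import torch.nn as nn``; not part
# of this module): per ``update`` above, mappings are merged via their own
# iteration order, while iterables of key-value pairs keep the order given.
#
#     d = nn.ModuleDict({'conv': nn.Conv2d(3, 8, 3)})
#     d.update([('relu', nn.ReLU()), ('pool', nn.MaxPool2d(2))])
#     list(d.keys())   # ['conv', 'relu', 'pool']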


class ParameterList(Module):
    r"""Holds parameters in a list.

    :class:`~torch.nn.ParameterList` can be indexed like a regular Python
    list, but parameters it contains are properly registered, and will be
    visible by all :class:`~torch.nn.Module` methods.

    Arguments:
        parameters (iterable, optional): an iterable of :class:`~torch.nn.Parameter` to add

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

            def forward(self, x):
                # ParameterList can act as an iterable, or be indexed using ints
                for i, p in enumerate(self.params):
                    x = self.params[i // 2].mm(x) + p.mm(x)
                return x
    """

    def __init__(self, parameters: Optional[Iterable['Parameter']] = None) -> None:
        super(ParameterList, self).__init__()
        if parameters is not None:
            self += parameters

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of parameters"""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError('index {} is out of range'.format(idx))
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: int) -> 'Parameter':
        ...

    @overload
    def __getitem__(self: T, idx: slice) -> T:
        ...

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return self.__class__(list(self._parameters.values())[idx])
        else:
            idx = self._get_abs_string_index(idx)
            return self._parameters[str(idx)]

    def __setitem__(self, idx: int, param: 'Parameter') -> None:
        idx = self._get_abs_string_index(idx)
        return self.register_parameter(str(idx), param)

    def __setattr__(self, key: Any, value: Any) -> None:
        if not isinstance(value, torch.nn.Parameter):
            warnings.warn("Setting attributes on ParameterList is not supported.")
        super(ParameterList, self).__setattr__(key, value)

    def __len__(self) -> int:
        return len(self._parameters)

    def __iter__(self) -> Iterator['Parameter']:
        return iter(self._parameters.values())

    def __iadd__(self: T, parameters: Iterable['Parameter']) -> T:
        return self.extend(parameters)

    def __dir__(self):
        keys = super(ParameterList, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def append(self: T, parameter: 'Parameter') -> T:
        """Appends a given parameter at the end of the list.

        Arguments:
            parameter (nn.Parameter): parameter to append
        """
        self.register_parameter(str(len(self)), parameter)
        return self

    def extend(self: T, parameters: Iterable['Parameter']) -> T:
        """Appends parameters from a Python iterable to the end of the list.

        Arguments:
            parameters (iterable): iterable of parameters to append
        """
        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError("ParameterList.extend should be called with an "
                            "iterable, but got " + type(parameters).__name__)
        offset = len(self)
        for i, param in enumerate(parameters):
            self.register_parameter(str(offset + i), param)
        return self

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in self._parameters.items():
            size_str = 'x'.join(str(size) for size in p.size())
            device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
            parastr = 'Parameter containing: [{} of size {}{}]'.format(
                torch.typename(p), size_str, device_str)
            child_lines.append(' (' + str(k) + '): ' + parastr)
        tmpstr = '\n'.join(child_lines)
        return tmpstr

    def __call__(self, input):
        raise RuntimeError('ParameterList should not be called.')

    def _replicate_for_data_parallel(self):
        warnings.warn("nn.ParameterList is being used with DataParallel but this is not "
                      "supported. This list will appear empty for the models replicated "
                      "on each GPU except the original one.")

        return super(ParameterList, self)._replicate_for_data_parallel()
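

# A minimal usage sketch (assumes ``import torch`` and ``import torch.nn as nn``;
# not part of this module): parameters added through the list API are
# registered under string keys '0'..'N-1' and are reported by parameters().
#
#     plist = nn.ParameterList([nn.Parameter(torch.randn(2, 2))])
#     plist.append(nn.Parameter(torch.zeros(3)))
#     len(list(plist.parameters()))   # 2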


class ParameterDict(Module):
    r"""Holds parameters in a dictionary.

    ParameterDict can be indexed like a regular Python dictionary, but parameters it
    contains are properly registered, and will be visible by all Module methods.

    :class:`~torch.nn.ParameterDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ParameterDict.update`, the order of the merged ``OrderedDict``
      or another :class:`~torch.nn.ParameterDict` (the argument to
      :meth:`~torch.nn.ParameterDict.update`).

    Note that :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict``) does not preserve the order of the
    merged mapping.

    Arguments:
        parameters (iterable, optional): a mapping (dictionary) of
            (string : :class:`~torch.nn.Parameter`) or an iterable of key-value pairs
            of type (string, :class:`~torch.nn.Parameter`)

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.params = nn.ParameterDict({
                        'left': nn.Parameter(torch.randn(5, 10)),
                        'right': nn.Parameter(torch.randn(5, 10))
                })

            def forward(self, x, choice):
                x = self.params[choice].mm(x)
                return x
    """

    def __init__(self, parameters: Optional[Mapping[str, 'Parameter']] = None) -> None:
        super(ParameterDict, self).__init__()
        if parameters is not None:
            self.update(parameters)

    def __getitem__(self, key: str) -> 'Parameter':
        return self._parameters[key]

    def __setitem__(self, key: str, parameter: 'Parameter') -> None:
        self.register_parameter(key, parameter)

    def __delitem__(self, key: str) -> None:
        del self._parameters[key]

    def __setattr__(self, key: Any, value: Any) -> None:
        if not isinstance(value, torch.nn.Parameter):
            warnings.warn("Setting attributes on ParameterDict is not supported.")
        super(ParameterDict, self).__setattr__(key, value)

    def __len__(self) -> int:
        return len(self._parameters)

    def __iter__(self) -> Iterator[str]:
        return iter(self._parameters.keys())

    def __contains__(self, key: str) -> bool:
        return key in self._parameters

    def clear(self) -> None:
        """Remove all items from the ParameterDict."""
        self._parameters.clear()

    def pop(self, key: str) -> 'Parameter':
        r"""Remove key from the ParameterDict and return its parameter.

        Arguments:
            key (string): key to pop from the ParameterDict
        """
        v = self[key]
        del self[key]
        return v

    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ParameterDict keys."""
        return self._parameters.keys()

    def items(self) -> Iterable[Tuple[str, 'Parameter']]:
        r"""Return an iterable of the ParameterDict key/value pairs."""
        return self._parameters.items()

    def values(self) -> Iterable['Parameter']:
        r"""Return an iterable of the ParameterDict values."""
        return self._parameters.values()

    def update(self, parameters: Mapping[str, 'Parameter']) -> None:
        r"""Update the :class:`~torch.nn.ParameterDict` with the key-value pairs from a
        mapping or an iterable, overwriting existing keys.

        .. note::
            If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Arguments:
            parameters (iterable): a mapping (dictionary) from string to
                :class:`~torch.nn.Parameter`, or an iterable of
                key-value pairs of type (string, :class:`~torch.nn.Parameter`)
        """
        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError("ParameterDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(parameters).__name__)

        if isinstance(parameters, (OrderedDict, ParameterDict)):
            for key, parameter in parameters.items():
                self[key] = parameter
        elif isinstance(parameters, container_abcs.Mapping):
            # Plain mappings may be unordered, so merge their items in sorted
            # key order to keep the result deterministic.
            for key, parameter in sorted(parameters.items()):
                self[key] = parameter
        else:
            for j, p in enumerate(parameters):
                if not isinstance(p, container_abcs.Iterable):
                    raise TypeError("ParameterDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is " +
                                    type(p).__name__)
                if not len(p) == 2:
                    raise ValueError("ParameterDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(p)) +
                                     "; 2 is required")
                self[p[0]] = p[1]

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in self._parameters.items():
            size_str = 'x'.join(str(size) for size in p.size())
            device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
            parastr = 'Parameter containing: [{} of size {}{}]'.format(
                torch.typename(p), size_str, device_str)
            child_lines.append(' (' + k + '): ' + parastr)
        tmpstr = '\n'.join(child_lines)
        return tmpstr

    def __call__(self, input):
        raise RuntimeError('ParameterDict should not be called.')

    def _replicate_for_data_parallel(self):
        warnings.warn("nn.ParameterDict is being used with DataParallel but this is not "
                      "supported. This dict will appear empty for the models replicated "
                      "on each GPU except the original one.")

        return super(ParameterDict, self)._replicate_for_data_parallel()
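

# A minimal update-order sketch (assumes ``import torch`` and
# ``import torch.nn as nn``; not part of this module): per ``update`` above,
# plain dicts are merged in sorted key order, while OrderedDict and
# ParameterDict arguments keep their own order, matching the docstring note.
#
#     pd = nn.ParameterDict()
#     pd.update({'b': nn.Parameter(torch.ones(1)), 'a': nn.Parameter(torch.ones(1))})
#     list(pd.keys())   # ['a', 'b']  (sorted, since a plain dict was passed)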