mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This PR introduces a helper function named `torch.nn.utils.skip_init()` that accepts a module class object + `args` / `kwargs` and instantiates the module while skipping initialization of parameter / buffer values. See discussion at https://github.com/pytorch/pytorch/issues/29523 for more context.

Example usage:

```python
import torch

m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
print(m.weight)

m2 = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1, device='cuda')
print(m2.weight)

m3 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=5, out_features=1)
print(m3.weight)
```

```
Parameter containing:
tensor([[-3.3011e+28, 4.5915e-41, -3.3009e+28, 4.5915e-41, 0.0000e+00]], requires_grad=True)
Parameter containing:
tensor([[-2.5339e+27, 4.5915e-41, -2.5367e+27, 4.5915e-41, 0.0000e+00]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]], requires_grad=True)
```

Bikeshedding on the name / namespace is welcome, as well as comments on the design itself - just wanted to get something out there for discussion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57555
Reviewed By: zou3519
Differential Revision: D28640613
Pulled By: jbschlosser
fbshipit-source-id: 5654f2e5af5530425ab7a9e357b6ba0d807e967f
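Skipping the default initialization pays off when a custom initialization follows anyway. A minimal sketch of that workflow, assuming PyTorch is installed; the Xavier-uniform weights and zero bias stand in for whatever custom scheme you actually need and are not part of this PR:

```python
import torch
import torch.nn as nn

# Fall back to CPU so the sketch runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build a 5 -> 1 Linear layer without running its default initialization.
# skip_init forwards `device`, so the (uninitialized) storage is allocated there directly.
layer = torch.nn.utils.skip_init(nn.Linear, 5, 1, device=device)

# Apply the custom scheme; nn.init functions fill the tensors in place
# without recording autograd history.
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)

print(layer.weight.shape)  # torch.Size([1, 5])
print(layer.bias)          # all zeros after zeros_()
```

Because the default initialization never runs, the only values the layer ever holds are the ones written by the custom scheme.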
52 lines
2.1 KiB
Python
import inspect
import torch


def skip_init(module_cls, *args, **kwargs):
    r"""
    Given a module class object and args / kwargs, instantiates the module without initializing
    parameters / buffers. This can be useful if initialization is slow or if custom initialization will
    be performed, making the default initialization unnecessary. There are some caveats to this, due to
    the way this function is implemented:

    1. The module must accept a `device` arg in its constructor that is passed to any parameters
    or buffers created during construction.

    2. The module must not perform any computation on parameters in its constructor except
    initialization (i.e. functions from :mod:`torch.nn.init`).

    If these conditions are satisfied, the module can be instantiated with parameter / buffer values
    uninitialized, as if having been created using :func:`torch.empty`.

    Args:
        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
        args: args to pass to the module's constructor
        kwargs: kwargs to pass to the module's constructor

    Returns:
        Instantiated module with uninitialized parameters / buffers

    Example::

        >>> import torch
        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
        >>> m.weight
        Parameter containing:
        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
               requires_grad=True)
        >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
        >>> m2.weight
        Parameter containing:
        tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24,
                 4.5915e-41]], requires_grad=True)

    """
    if not issubclass(module_cls, torch.nn.Module):
        raise RuntimeError('Expected a Module; got {}'.format(module_cls))
    if 'device' not in inspect.signature(module_cls).parameters:
        raise RuntimeError('Module must support a \'device\' arg to skip initialization')

    # Build the module on the meta device: parameters / buffers get shapes and dtypes
    # but no storage, so the constructor's initialization does no real work.
    # to_empty() then allocates uninitialized storage on the requested device
    # (CPU if none was given).
    final_device = kwargs.pop('device', 'cpu')
    kwargs['device'] = 'meta'
    return module_cls(*args, **kwargs).to_empty(device=final_device)
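The last three lines of `skip_init` carry the whole technique: construct the module on the `meta` device, then call `to_empty()`. The sketch below, assuming a PyTorch release that ships both `torch.nn.utils.skip_init` and `Module.to_empty`, replays those two steps by hand and shows how the two docstring caveats play out. `TinyBlock` and `NoDeviceBlock` are hypothetical modules written only for this illustration.

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    # Hypothetical module. It satisfies caveat 1 by forwarding `device`
    # to every parameter and buffer it creates.
    def __init__(self, features, device=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(features, features, device=device))
        self.register_buffer("scale", torch.ones(features, device=device))
        nn.init.uniform_(self.weight, -0.1, 0.1)  # plain initialization: allowed
        # Violates caveat 2: a tensor derived from a parameter. On the meta
        # device it is itself a data-less meta tensor, and to_empty() only
        # touches parameters and buffers, so it stays on "meta".
        self.weight_sum = self.weight.sum()


class NoDeviceBlock(nn.Module):
    # Hypothetical module with no `device` arg (violates caveat 1).
    def __init__(self, features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(features, features))


# The two steps skip_init performs, written out by hand:
meta_block = TinyBlock(4, device="meta")   # shapes/dtypes only, no storage
block = meta_block.to_empty(device="cpu")  # uninitialized CPU storage for params and buffers

# The helper does the same in one call (device defaults to "cpu").
block2 = torch.nn.utils.skip_init(TinyBlock, 4)

print(block.weight.is_meta, block.scale.is_meta, block.weight_sum.is_meta)  # False False True

# skip_init rejects modules whose constructor has no `device` parameter.
rejected = False
try:
    torch.nn.utils.skip_init(NoDeviceBlock, 4)
except RuntimeError:
    rejected = True
print(rejected)  # True
```

`to_empty()` replaces buffers as well as parameters, so `scale` does not keep the ones it was constructed with; this matches the docstring's promise of uninitialized parameter / buffer values. `weight_sum` shows the failure mode behind caveat 2: it was computed from a meta tensor, so it carries no usable data even after the module is materialized.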