import torch
from torch.autograd import Variable


def parameters_to_vector(parameters):
    r"""Convert parameters to one vector.

    Arguments:
        parameters (Iterable[Variable]): an iterable of Variables that are the
            parameters of a model.

    Returns:
        The parameters represented by a single vector.
    """
    # Flag for the device where the parameters are located
    param_device = None

    vec = []
    for param in parameters:
        # Ensure that all parameters are located on the same device
        param_device = _check_param_device(param, param_device)

        # Flatten the parameter and collect it
        vec.append(param.view(-1))
    return torch.cat(vec)
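
# Usage sketch (illustrative only, not part of the original module): flatten
# the parameters of a small, hypothetical model into a single vector. Assumes
# torch.nn is importable and that all parameters live on one device.
#
#     >>> import torch.nn as nn
#     >>> model = nn.Linear(4, 3)     # weight: 3*4 = 12 elements, bias: 3
#     >>> vec = parameters_to_vector(model.parameters())
#     >>> vec.size()
#     torch.Size([15])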


def vector_to_parameters(vec, parameters):
    r"""Convert one vector to the parameters.

    Arguments:
        vec (Variable): a single vector representing the parameters of a model.
        parameters (Iterable[Variable]): an iterable of Variables that are the
            parameters of a model.
    """
    # Ensure vec is of type Variable
    if not isinstance(vec, Variable):
        raise TypeError('expected torch.autograd.Variable, but got: {}'
                        .format(torch.typename(vec)))
    # Flag for the device where the parameters are located
    param_device = None

    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure that all parameters are located on the same device
        param_device = _check_param_device(param, param_device)

        # The number of elements in the parameter
        num_param = torch.prod(torch.LongTensor(list(param.size())))
        # Slice the vector, reshape it, and replace the old data of the parameter
        param.data = vec[pointer:pointer + num_param].view(param.size()).data

        # Increment the pointer
        pointer += num_param
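
# Usage sketch (illustrative only, not part of the original module): write a
# flat vector back into a model's parameters, e.g. to restore a snapshot taken
# with parameters_to_vector. Assumes the vector length matches the total
# number of parameter elements and that `model` is a hypothetical nn.Module.
#
#     >>> import torch.nn as nn
#     >>> model = nn.Linear(4, 3)
#     >>> snapshot = parameters_to_vector(model.parameters())
#     >>> vector_to_parameters(snapshot * 0, model.parameters())  # zero out
#     >>> vector_to_parameters(snapshot, model.parameters())      # restore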


def _check_param_device(param, old_param_device):
    r"""Check that the parameters are located on the same device.

    Currently, converting between model parameters and a single vector is not
    supported when the parameters live in multiple allocations, e.g. on
    different GPUs or on a mixture of CPU and GPU.

    Arguments:
        param (Variable): a Variable that is a parameter of a model.
        old_param_device (int): the device where the first parameter of a
            model is allocated.

    Returns:
        old_param_device (int): the device of the parameters, reported the
            first time this function is called (GPU index, or -1 for the CPU).
    """
    # First parameter encountered: record its device
    if old_param_device is None:
        old_param_device = param.get_device() if param.is_cuda else -1
    else:
        warn = False
        if param.is_cuda:  # Check if on the same GPU
            warn = (param.get_device() != old_param_device)
        else:  # Check if on the CPU
            warn = (old_param_device != -1)
        if warn:
            raise TypeError('Found two parameters on different devices, '
                            'this is currently not supported.')
    return old_param_device
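
# Behaviour sketch (illustrative only): the helper returns the GPU index of
# the first parameter it sees, or -1 for the CPU, and raises a TypeError as
# soon as a later parameter is found on a different device.
#
#     >>> p = Variable(torch.zeros(2))
#     >>> dev = _check_param_device(p, None)  # -1, since p lives on the CPU
#     >>> _check_param_device(p, dev)         # same device, so no error
#     -1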