pytorch/torch/nn/parallel/replicate.py
2016-09-27 15:45:45 -07:00

35 lines
1.1 KiB
Python

from copy import copy
from collections import OrderedDict
from .functions import Broadcast
from ..modules.container import Container
def _replicate_module(module, gpu, param_remap):
if module is None:
return module
replica = copy(module)
replica._parameters = OrderedDict()
for key, param in module._parameters.items():
replica._parameters[key] = param_remap[param]
if isinstance(replica, Container):
replica.modules = OrderedDict()
for name, child in module.modules.items():
replica.modules[name] = _replicate_module(child, gpu, param_remap)
return replica
def replicate(module, device_ids):
    """Replicate ``module`` once per device in ``device_ids``.

    Each parameter is broadcast to every device in one shot; parameters
    shared by several submodules are broadcast only once (tracked via
    ``seen_params``). Returns a list with one replica per device, whose
    parameters point at that device's copies.
    """
    seen_params = set()
    # One {original param -> on-device copy} mapping per target device.
    param_remap = [{} for _ in device_ids]
    for param in module.parameters():
        if param in seen_params:
            # Parameter shared by multiple submodules: broadcast it once.
            continue
        seen_params.add(param)
        param_copies = Broadcast(device_ids)(param)
        # Renamed from ``copy`` so the loop variable no longer shadows
        # the ``copy`` function imported at the top of the file.
        for param_copy, remap in zip(param_copies, param_remap):
            remap[param] = param_copy
    return [_replicate_module(module, device_id, remap)
            for device_id, remap in zip(device_ids, param_remap)]