mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
from copy import copy
|
|
from collections import OrderedDict
|
|
|
|
from .functions import Broadcast
|
|
from ..modules.container import Container
|
|
|
|
|
|
def _replicate_module(module, gpu, param_remap):
    """Return a shallow copy of ``module`` whose parameters are remapped.

    The copy shares every attribute with ``module`` except its parameter
    dictionary, whose entries are replaced by the corresponding values in
    ``param_remap``. If the module is a ``Container``, its children are
    replicated recursively with the same mapping. Because the mapping is
    shared, a parameter used by several submodules maps to one copy in the
    replica.

    Args:
        module: the module to replicate, or ``None``.
        gpu: device id of this replica. It is only forwarded to the recursive
            calls; it is not otherwise used here.
        param_remap: dict mapping each original (non-None) parameter to its
            copy for this replica.

    Returns:
        The replica, or ``None`` when ``module`` is ``None``.
    """
    if module is None:
        return None
    replica = copy(module)
    # copy() is shallow, so the replica would otherwise share the original's
    # parameter dict. Give it a fresh one so the original is not mutated.
    replica._parameters = OrderedDict()
    for key, param in module._parameters.items():
        # A parameter slot may hold None (presumably an optional parameter
        # that was not created). There is nothing to remap for it, and looking
        # None up in param_remap would raise KeyError.
        if param is not None:
            param = param_remap[param]
        replica._parameters[key] = param
    if isinstance(replica, Container):
        # For the same reason, give the replica its own child dict instead of
        # mutating the original container's.
        replica.modules = OrderedDict()
        for name, child in module.modules.items():
            replica.modules[name] = _replicate_module(child, gpu, param_remap)
    return replica
|
|
|
|
|
|
def replicate(module, device_ids):
    """Create one replica of ``module`` for each entry in ``device_ids``.

    Each distinct parameter of ``module`` is broadcast once with
    ``Broadcast(device_ids)``. The i-th broadcast result becomes that
    parameter's value in the i-th replica. The replicas themselves are built
    by ``_replicate_module``.

    Args:
        module: the module to replicate.
        device_ids: sequence of device ids. It is iterated more than once, so
            it must not be a one-shot iterator.

    Returns:
        A list of replicas, in the same order as ``device_ids``.
    """
    # One {original parameter: copy} mapping per target device.
    param_remaps = [{} for _ in device_ids]
    seen_params = set()
    for param in module.parameters():
        # Skip parameters already handled so each one is broadcast only once
        # and every module using it maps to the same copy.
        if param in seen_params:
            continue
        seen_params.add(param)
        param_copies = Broadcast(device_ids)(param)
        # Named `param_copy` rather than `copy` to avoid shadowing the
        # module-level `copy` import (pyflakes F402).
        for param_copy, remap in zip(param_copies, param_remaps):
            remap[param] = param_copy
    return [_replicate_module(module, device_id, remap)
            for device_id, remap in zip(device_ids, param_remaps)]
|
|
|