mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
* Codemod to update our codebase to 0.4 standard * Update some of the test scripts * remove Variable in test_clip_grad_value * fix _symbolic_override_wrapper_maker
68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
import threading
|
|
import torch
|
|
|
|
|
|
def get_a_var(obj):
    """Return the first ``torch.Tensor`` found inside ``obj``, or ``None``.

    ``obj`` itself is checked first. Otherwise the elements of a list/tuple,
    or the ``(key, value)`` pairs of a dict, are searched recursively
    (depth-first, in iteration order) and the first tensor found is returned.
    Any other object, or a container holding no tensor, yields ``None``.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        candidates = obj
    elif isinstance(obj, dict):
        # Each item is a (key, value) pair, so keys are searched as well.
        candidates = obj.items()
    else:
        return None

    for candidate in candidates:
        found = get_a_var(candidate)
        if isinstance(found, torch.Tensor):
            return found
    return None


def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    """Apply each module to its matching input, one thread per module.

    Replica ``i`` is called as ``modules[i](*inputs[i], **kwargs_tup[i])``
    inside ``torch.cuda.device(devices[i])``. When more than one module is
    given, each call runs on its own ``threading.Thread``; a single module is
    called directly on the current thread, and no modules yields ``[]``.

    Args:
        modules: sequence of callables, one per replica.
        inputs: sequence of per-replica positional arguments, same length as
            ``modules``. A list/tuple entry is unpacked into positional
            arguments; any other object (e.g. a bare Tensor) is passed as the
            sole positional argument.
        kwargs_tup: optional sequence of per-replica keyword-argument dicts;
            defaults to an empty dict for every replica.
        devices: optional sequence of per-replica devices handed to
            ``torch.cuda.device``. A ``None`` entry means "use the device of
            the first tensor found in that replica's input".

    Returns:
        list: the outputs, in the same order as ``modules``.

    Raises:
        AssertionError: if the argument sequences have different lengths.
        ValueError: if a replica has no device and no tensor in its input.
        Exception: the first exception (by replica index) raised by any
            replica is re-raised in the calling thread.
    """
    assert len(modules) == len(inputs), \
        "modules and inputs must have the same length"
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup), \
            "modules and kwargs_tup must have the same length"
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices), \
            "modules and devices must have the same length"
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    # Failures are stored apart from outputs so that a module which
    # legitimately *returns* an Exception object is not treated as failing.
    errors = {}
    # Grad mode is thread-local in PyTorch: capture the caller's setting and
    # re-apply it inside every worker thread.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, replica_input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        try:
            # Resolve the device inside the ``try`` so a failure is reported
            # to the caller instead of killing the thread silently (which
            # would later surface as an unrelated KeyError on ``results``).
            if device is None:
                var = get_a_var(replica_input)
                if var is None:
                    raise ValueError(
                        "parallel_apply: replica {} has no device and no "
                        "tensor in its input to infer one from".format(i))
                device = var.get_device()
            # Wrap a non-sequence input (e.g. a bare Tensor) so it is passed
            # as one argument instead of being unpacked along dimension 0.
            if not isinstance(replica_input, (list, tuple)):
                replica_input = (replica_input,)
            with torch.cuda.device(device):
                output = module(*replica_input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                errors[i] = e

    jobs = list(zip(modules, inputs, kwargs_tup, devices))
    if len(jobs) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, replica_input, kwargs,
                                          device))
                   for i, (module, replica_input, kwargs, device)
                   in enumerate(jobs)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    elif jobs:
        # A single replica does not need an extra thread.
        _worker(0, *jobs[0])
    # With no modules there is nothing to run; the result is an empty list.

    outputs = []
    for i in range(len(jobs)):
        if i in errors:
            raise errors[i]
        outputs.append(results[i])
    return outputs