mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This PR re-lands - [Typing] Fix PEP 484 Violation (#105022) - Update mypy to 1.4.1 (#91983) that were reverted due to the conflict with internal source repo. Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional) Plus a few real fixes: - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi` - Add missing return statement to `torch._export. deserialize_graph` - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights` - Add assert in `torch/optim/optimizer.py` that Optional list is not None TODO (in followup PR): - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py` Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04: - Add hack to squash older libstdc++ from conda environment in favor of the one from the OS to `.ci/docker/install_conda.sh` - Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where it is done) Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227 Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
from torch import nn
|
|
from typing import List, Optional
|
|
|
|
# Public API of this module: the only name exported by a star-import.
__all__ = ["partition_model"]
|
|
|
|
def partition_model(
        module: nn.Sequential,
        balance: List[int],
        devices: Optional[List[int]] = None) -> nn.Sequential:
    """
    Given an :class:`nn.Sequential <torch.nn.Sequential>` module, partitions
    the model across multiple GPU devices according to the provided
    ``balance`` and ``devices``.

    Consecutive layers of ``module`` are grouped, in order, into one
    :class:`nn.Sequential <torch.nn.Sequential>` per entry of ``balance``;
    each group is then passed through ``.to(device)`` for its device.

    Args:
        module (:class:`nn.Sequential <torch.nn.Sequential>`):
            Sequential model representing the pipe.
        balance (List[int]):
            List indicating the number of layers in each partition. The
            entries must be non-negative and must sum to ``len(module)``.
        devices (List[int], optional):
            List indicating the device to use for each partition. Defaults to
            ``range(len(balance))``. Must provide at least one device per
            partition; surplus entries are ignored.

    Returns:
        :class:`nn.Sequential <torch.nn.Sequential>` whose children are the
        partitions (each itself an ``nn.Sequential``), in ``balance`` order.

    Raises:
        ValueError: If ``balance`` contains a negative entry, if it does not
            account for exactly ``len(module)`` layers, or if ``devices`` has
            fewer entries than ``balance``.
    """
    # Validate everything before the first ``.to(device)`` call so that
    # inconsistent arguments are reported clearly instead of silently
    # dropping trailing layers or failing midway with a bare IndexError.
    if any(num_layers < 0 for num_layers in balance):
        raise ValueError(
            f"balance must contain only non-negative layer counts, got {balance}"
        )

    total_layers = sum(balance)
    if total_layers != len(module):
        raise ValueError(
            f"balance {balance} accounts for {total_layers} layer(s), "
            f"but module has {len(module)} layer(s)"
        )

    if devices is not None and len(devices) < len(balance):
        raise ValueError(
            f"expected at least {len(balance)} device(s) for {len(balance)} "
            f"partition(s), got {len(devices)}"
        )

    # Children of the input model in their registered order.
    all_layers = list(module)

    balanced_pipe = []
    start = 0
    for partition_idx, num_layers in enumerate(balance):
        stop = start + num_layers
        # Without an explicit device list, partition k goes to device index k.
        device = partition_idx if devices is None else devices[partition_idx]
        partition = nn.Sequential(*all_layers[start:stop])
        balanced_pipe.append(partition.to(device))
        start = stop

    return nn.Sequential(*balanced_pipe)
|