[fx-acc] Saturate host by replicating partitions onto idle devices (#60064)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60064

This implements a host saturation optimization to maximize utilization of the available devices. It uses a greedy heuristic to replicate all partitions on the used devices onto another set of idle devices that have enough memory. The added unit test shows the following example:

```
partition_0: 192 bytes; partition_1: 48 bytes
dev_0: 200 bytes, [partition_0]
dev_1: 200 bytes, [partition_1]
dev_2: 100 bytes
dev_3: 100 bytes
dev_4: 200 bytes
dev_5: 100 bytes
```

Before host saturation, `partition_0` is assigned to dev_0 and `partition_1` is assigned to dev_1. After host saturation, `partition_0` is replicated to dev_4 because it is the only idle device that can hold all partitions on dev_0. `partition_1` is replicated to dev_2 because it is the idle device with the smallest memory that is still large enough to hold all partitions on dev_1.

Test Plan:
```
buck test mode/opt //caffe2/test:test_fx_experimental -- --exact 'caffe2/test:test_fx_experimental - test_saturate_host (test_fx_experimental.TestFXExperimental)'
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/8444249343103429
    ✓ ListingSuccess: caffe2/test:test_fx_experimental - main (1.322)
    ✓ Pass: caffe2/test:test_fx_experimental - test_saturate_host (test_fx_experimental.TestFXExperimental) (1.322)
Summary
  Pass: 1
  ListingSuccess: 1
```

An e2e test will be added to `test_fx_glow.py` in a followup diff.

Reviewed By: gcatron

Differential Revision: D29039998

fbshipit-source-id: 57518aadf668f7f05abd6ff73224c16b5d2a12ac
parent a344b09db2
commit 9fbbab88da
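To make the greedy heuristic described in the summary concrete, here is a minimal, self-contained sketch of the device-selection step, using only the byte figures from the example above. It is an illustration, not the actual `Partitioner.saturate_host` code: the function name `pick_replica_devices` and the plain-dict inputs are invented for this sketch.

```python
# Minimal sketch of the greedy replication step described in the summary
# (illustrative names and plain dicts; not the actual torch.fx Partitioner API).
from typing import Dict, Optional


def pick_replica_devices(
    used_bytes: Dict[str, int], idle_capacity: Dict[str, int]
) -> Optional[Dict[str, str]]:
    """For every used device, pick the smallest idle device that can hold all of
    its partitions. Returns None if any used device cannot be replicated."""
    mapping: Dict[str, str] = {}
    remaining = dict(idle_capacity)
    for dev, needed in used_bytes.items():
        candidates = [d for d, cap in remaining.items() if cap >= needed]
        if not candidates:
            return None  # all-or-nothing: give up if any used device has no match
        best = min(candidates, key=lambda d: remaining[d])  # smallest sufficient device
        mapping[dev] = best
        del remaining[best]  # each idle device can back at most one used device
    return mapping


# The example from the summary: dev_0 holds 192 bytes, dev_1 holds 48 bytes.
used = {"dev_0": 192, "dev_1": 48}
idle = {"dev_2": 100, "dev_3": 100, "dev_4": 200, "dev_5": 100}
print(pick_replica_devices(used, idle))  # {'dev_0': 'dev_4', 'dev_1': 'dev_2'}
```

Running it pairs dev_0 with dev_4 and dev_1 with dev_2, matching the assignments asserted in `test_saturate_host` below.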
```diff
@@ -609,6 +609,47 @@ class TestFXExperimental(JitTestCase):
         )
         assert (input1 * input2) == traced(input1, input2)
 
+    def test_saturate_host(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                self.linear = torch.nn.Linear(4, 4)
+
+            def forward(self, a):
+                add_1 = a + torch.rand(4)
+                add_2 = add_1 + torch.rand(4)
+                linear_1 = self.linear(add_1)
+                add_3 = add_2 + linear_1
+                add_4 = add_2 + add_3
+                return add_4
+
+        m = TestModule()
+        traced = symbolic_trace(m)
+        a = torch.rand(4)
+        graph_manipulation.get_size_of_all_nodes(traced, [a])
+        devices = [
+            Device("dev_0", 200, 0),
+            Device("dev_1", 200, 1),
+            Device("dev_2", 100, 2),
+            Device("dev_3", 100, 3),
+            Device("dev_4", 200, 4),
+            Device("dev_5", 100, 5),
+        ]
+        partitioner = Partitioner()
+        # Without host saturation, the model will be split into two partitions.
+        # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes.
+        partitioner_config = PartitionerConfig(devices, saturate_host=True)
+        ret = partitioner.partition_graph(traced, m, partitioner_config)
+        module_with_submodules = ret.module_with_submodules
+        self.assertEqual(traced(a), module_with_submodules(a))
+
+        partitions = partitioner.partitions
+        self.assertEqual(len(partitions), 2)
+        # With host saturation, partition 0 will be replicated to dev_4, and partition 1
+        # will be replicated to dev_2.
+        self.assertEqual(partitions[0].logical_device_ids, [0, 4])
+        self.assertEqual(partitions[1].logical_device_ids, [1, 2])
+
     @skipIfNoTorchVision
     def test_conv_bn_fusion(self):
         rn18 = resnet18().eval()
```
```diff
@@ -165,6 +165,51 @@ def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int
     return node_to_partition
 
 
+def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]:
+    """Get a mapping from device logical ID to Device object."""
+    logical_id_to_device: Dict[int, Device] = {}
+    for d in devices:
+        logical_id_to_device[d.logical_id] = d
+    return logical_id_to_device
+
+
+def get_device_partition_stats(
+    partitions: List[Partition], devices: List[Device]
+) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]:
+    """Given a list of partitions and a list of devices, returns:
+    1. A mapping from device to partitions on it;
+    2. A mapping from device to its remaining memory size;
+    3. A list of partitions that do not have a device.
+    """
+    # logical id to device
+    logical_id_to_device = get_logical_id_to_device(devices)
+    # Track partitions on device
+    device_to_partitions: Dict[Device, List[Partition]] = {}
+    # Track device's left mem size
+    device_to_left_mem_bytes: Dict[Device, int] = {}
+    for d in devices:
+        device_to_partitions[d] = []
+        device_to_left_mem_bytes[d] = d.available_mem_bytes
+
+    # Deal with the partitions that already have a device
+    # and also collect all partitions without a device (no_device_partitions)
+    no_device_partitions = []
+    for partition in partitions:
+        if partition.logical_device_ids != []:
+            for logical_id in partition.logical_device_ids:
+                device = logical_id_to_device[logical_id]
+                device_to_partitions[device].append(partition)
+                device_to_left_mem_bytes[device] -= partition.used_mem_bytes
+        else:
+            no_device_partitions.append(partition)
+
+    return (
+        device_to_partitions,
+        device_to_left_mem_bytes,
+        no_device_partitions,
+    )
+
+
 def get_device_to_partitions_mapping(
     partitions: List[Partition], devices: List[Device]
 ):
```
```diff
@@ -204,27 +249,12 @@ def get_device_to_partitions_mapping
                 return True
         return False
 
-    # logical id to device
-    logical_id_to_device: Dict[int, Device] = {}
-    # Track partitions on device
-    device_to_partitions: Dict[Device, List[Partition]] = {}
-    # Track device's left mem size
-    device_to_left_mem_bytes: Dict[Device, int] = {}
-    for d in devices:
-        logical_id_to_device[d.logical_id] = d
-        device_to_partitions[d] = []
-        device_to_left_mem_bytes[d] = d.available_mem_bytes
-    # Deal with the partitions that already have a device
-    # and also collect all partitions without a device (no_device_partitions)
-    no_device_partitions = []
-    for partition in partitions:
-        if partition.logical_device_ids != []:
-            logical_id = partition.logical_device_ids[0]
-            device = logical_id_to_device[logical_id]
-            device_to_partitions[device] = [partition]
-            device_to_left_mem_bytes[device] -= partition.used_mem_bytes
-        else:
-            no_device_partitions.append(partition)
+    (
+        device_to_partitions,
+        device_to_left_mem_bytes,
+        no_device_partitions,
+    ) = get_device_partition_stats(partitions, devices)
+
     # Find devices for all the partitions without a device
     found_device = True
     for partition in no_device_partitions:
```
```diff
@@ -341,7 +371,14 @@ class Partitioner:
             )
         else:
             self.size_based_partition()
 
+        # Saturate host if possible.
+        if partitioner_config.saturate_host:
+            self.saturate_host()
+
+        # Partition the graph module based on the partition assignment.
         module_with_submodules = self.do_partition()
 
+        # The DAG contains DAGNodes with info of each partition's input nodes, output nodes
+        # and how partitions are connected.
        dag = self.dump_dag(module_with_submodules)
```
```diff
@@ -459,6 +496,75 @@ class Partitioner:
             raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping")
         return
 
+    def saturate_host(self) -> None:
+        """Saturate host by assigning replicates to unused devices with enough memory.
+        It uses a greedy approach to find the next available set of devices to place all split
+        partitions: for each used device, it searches for an idle device with the smallest memory
+        size that can hold all the partitions located on that device; if the search is successful
+        for all used devices, it then assigns the new devices' logical IDs to the corresponding
+        partitions.
+        """
+        (
+            device_to_partitions,
+            device_to_left_mem_bytes,
+            no_device_partitions,
+        ) = get_device_partition_stats(self.partitions, self.devices)
+
+        assert (
+            len(no_device_partitions) == 0
+        ), f"Expected no_device_partitions to be empty, but got {len(no_device_partitions)}"
+
+        # Devices that hold partitions
+        used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]
+        # Track replicates of the assigned devices
+        replicated_device_to_used_device: Dict[Device, Device] = {}
+
+        while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len(
+            self.devices
+        ):
+            # Success flag for this round
+            success = True
+            # Devices that have not been assigned
+            idle_devices = [
+                d
+                for d in self.devices
+                if d not in used_devices and d not in replicated_device_to_used_device
+            ]
+            # Temporary mapping from replicated device to original device
+            temp_replicate_mapping = {}
+
+            # Find a new device to replicate all partitions on a used device
+            for used_device in used_devices:
+                # Idle devices that have enough memory
+                available_devices = [
+                    d
+                    for d in idle_devices
+                    if d.available_mem_bytes
+                    >= used_device.available_mem_bytes
+                    - device_to_left_mem_bytes[used_device]
+                ]
+                if len(available_devices) == 0:
+                    success = False
+                    break
+                new_device = min(available_devices, key=lambda d: d.available_mem_bytes)
+                idle_devices.remove(new_device)
+                temp_replicate_mapping[new_device] = used_device
+
+            if not success:
+                break
+            replicated_device_to_used_device.update(temp_replicate_mapping)
+
+        # Update logical device IDs assigned to the partitions
+        for (
+            replicate_device,
+            original_device,
+        ) in replicated_device_to_used_device.items():
+            logical_id = replicate_device.logical_id
+            for partition in device_to_partitions[original_device]:
+                partition.logical_device_ids.append(logical_id)
+        for p in self.partitions:
+            print(p.logical_device_ids)
+
     def do_partition(self) -> GraphModule:
         """Return a new fx module with submodule nodes (partitions)."""
         module_with_submodules = split_module(
```
```diff
@@ -469,7 +575,7 @@ class Partitioner:
         return module_with_submodules
 
     def dump_dag(self, module_with_submodules: GraphModule) -> DAG:
-        """Return the dag structure and the new fx module with submodules"""
+        """Return the dag structure and the new fx module with submodules."""
         dag = DAG()
         for node in module_with_submodules.graph.nodes:
             if node.op == "output":
```
```diff
@@ -93,6 +93,8 @@ class PartitionerConfig(NamedTuple):
     node_to_latency_mapping: Dict[Node, NodeLatency] = {}
     node_to_partition_mapping: Dict[Node, int] = {}
     partition_to_logical_device_mapping: Dict[int, List[int]] = {}
+    # Saturate host by replicating partitions to the remaining idle devices.
+    saturate_host: bool = False
 
 
 def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
```
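For reference, the new `saturate_host` flag is consumed through `PartitionerConfig`, exactly as in the unit test above. A rough usage sketch follows; the import paths and the `SmallModule` stand-in are assumptions (paths may differ across PyTorch versions), while the `Device`, `PartitionerConfig`, and `partition_graph` calls mirror the test code:

```python
import torch
from torch.fx import symbolic_trace
# Module paths below are assumptions; they may differ across PyTorch versions.
from torch.fx.experimental import graph_manipulation
from torch.fx.experimental.accelerator_partitioner import Partitioner
from torch.fx.experimental.partitioner_utils import Device, PartitionerConfig


class SmallModule(torch.nn.Module):
    """Tiny stand-in module, analogous to TestModule in the unit test."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, a):
        return self.linear(a + torch.rand(4))


m = SmallModule()
traced = symbolic_trace(m)
a = torch.rand(4)
# Record per-node sizes so the partitioner knows how many bytes each partition needs.
graph_manipulation.get_size_of_all_nodes(traced, [a])

# Device(name, available_mem_bytes, logical_id), as in the unit test.
devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1), Device("dev_4", 200, 4)]
config = PartitionerConfig(devices, saturate_host=True)  # enable host saturation
ret = Partitioner().partition_graph(traced, m, config)
module_with_submodules = ret.module_with_submodules
```

With `saturate_host=True`, any device left idle after the initial size-based assignment may receive a replica of an already-assigned device's partitions, as exercised by `test_saturate_host`.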