[fx-acc] Saturate host by replicating partitions onto idle devices (#60064)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60064

This implements a host saturation optimization to maximize the utilization of the available devices.
It uses a greedy heuristic to replicate all partitions on the used devices to another set of idle devices with enough memory.

The added unit test illustrates this with the following example:

```
partition_0: 192 bytes; partition_1: 48 bytes
dev_0: 200 bytes, [partition_0]
dev_1: 200 bytes, [partition_1]
dev_2: 100 bytes, []
dev_3: 100 bytes, []
dev_4: 200 bytes, []
dev_5: 100 bytes, []
```

Before host saturation, `partition_0` is assigned to dev_0 and `partition_1` is assigned to dev_1.
After host saturation, `partition_0` is replicated to dev_4, the only idle device with enough memory (200 bytes) to hold all partitions on dev_0 (192 bytes). `partition_1` is replicated to dev_2, the idle device with the smallest memory that is still large enough to hold all partitions on dev_1 (48 bytes).
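
As an illustration, here is a minimal, self-contained sketch of the greedy selection. It is not the actual torch.fx implementation; `pick_replicas` and the reduced two-field `Device` are hypothetical names used only for this example:

```
from typing import Dict, List, NamedTuple, Optional

class Device(NamedTuple):
    name: str
    available_mem_bytes: int

def pick_replicas(
    used_mem: Dict[Device, int],  # used device -> bytes occupied by its partitions
    idle: List[Device],
) -> Optional[Dict[Device, Device]]:
    """Greedily map every used device to the smallest idle device that fits.

    Returns None if any used device cannot be matched; replication is
    all-or-nothing, as in the saturate_host heuristic below.
    """
    remaining = list(idle)
    mapping: Dict[Device, Device] = {}
    for dev, occupied in used_mem.items():
        candidates = [d for d in remaining if d.available_mem_bytes >= occupied]
        if not candidates:
            return None
        # Pick the idle device with the smallest memory that still fits.
        best = min(candidates, key=lambda d: d.available_mem_bytes)
        remaining.remove(best)
        mapping[dev] = best
    return mapping

# The example above: 192 bytes occupied on dev_0, 48 bytes on dev_1.
used = {Device("dev_0", 200): 192, Device("dev_1", 200): 48}
idle = [Device("dev_2", 100), Device("dev_3", 100), Device("dev_4", 200), Device("dev_5", 100)]
print(pick_replicas(used, idle))  # dev_0 -> dev_4, dev_1 -> dev_2
```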

Test Plan:
```
buck test mode/opt //caffe2/test:test_fx_experimental -- --exact 'caffe2/test:test_fx_experimental - test_saturate_host (test_fx_experimental.TestFXExperimental)'

Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/8444249343103429
    ✓ ListingSuccess: caffe2/test:test_fx_experimental - main (1.322)
    ✓ Pass: caffe2/test:test_fx_experimental - test_saturate_host (test_fx_experimental.TestFXExperimental) (1.322)
Summary
  Pass: 1
  ListingSuccess: 1
```

An e2e test will be added to `test_fx_glow.py` in a follow-up diff.

Reviewed By: gcatron

Differential Revision: D29039998

fbshipit-source-id: 57518aadf668f7f05abd6ff73224c16b5d2a12ac
Authored by Hangchen Yu on 2021-06-15 23:03:33 -07:00; committed by Facebook GitHub Bot.
Parent: a344b09db2
Commit: 9fbbab88da
3 changed files with 171 additions and 22 deletions

@@ -609,6 +609,47 @@ class TestFXExperimental(JitTestCase):
        )
        assert (input1 * input2) == traced(input1, input2)

    def test_saturate_host(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 200, 0),
            Device("dev_1", 200, 1),
            Device("dev_2", 100, 2),
            Device("dev_3", 100, 3),
            Device("dev_4", 200, 4),
            Device("dev_5", 100, 5),
        ]
        partitioner = Partitioner()
        # Without host saturation, the model will be split into two partitions:
        # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes.
        partitioner_config = PartitionerConfig(devices, saturate_host=True)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))

        partitions = partitioner.partitions
        self.assertEqual(len(partitions), 2)
        # With host saturation, partition 0 will be replicated to dev_4, and partition 1
        # will be replicated to dev_2.
        self.assertEqual(partitions[0].logical_device_ids, [0, 4])
        self.assertEqual(partitions[1].logical_device_ids, [1, 2])

    @skipIfNoTorchVision
    def test_conv_bn_fusion(self):
        rn18 = resnet18().eval()

@@ -165,6 +165,51 @@ def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int
    return node_to_partition


def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]:
    """Get a mapping from device logical ID to Device object."""
    logical_id_to_device: Dict[int, Device] = {}
    for d in devices:
        logical_id_to_device[d.logical_id] = d
    return logical_id_to_device


def get_device_partition_stats(
    partitions: List[Partition], devices: List[Device]
) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]:
    """Given a list of partitions and a list of devices, returns:
    1. A mapping from device to the partitions on it;
    2. A mapping from device to its remaining memory size;
    3. A list of partitions that do not have a device assigned.
    """
    # Map logical IDs to devices
    logical_id_to_device = get_logical_id_to_device(devices)
    # Track partitions on each device
    device_to_partitions: Dict[Device, List[Partition]] = {}
    # Track each device's remaining memory size
    device_to_left_mem_bytes: Dict[Device, int] = {}
    for d in devices:
        device_to_partitions[d] = []
        device_to_left_mem_bytes[d] = d.available_mem_bytes
    # Deal with the partitions that already have a device
    # and also collect all partitions without a device (no_device_partitions)
    no_device_partitions = []
    for partition in partitions:
        if partition.logical_device_ids != []:
            for logical_id in partition.logical_device_ids:
                device = logical_id_to_device[logical_id]
                device_to_partitions[device].append(partition)
                device_to_left_mem_bytes[device] -= partition.used_mem_bytes
        else:
            no_device_partitions.append(partition)
    return (
        device_to_partitions,
        device_to_left_mem_bytes,
        no_device_partitions,
    )
def get_device_to_partitions_mapping(
    partitions: List[Partition], devices: List[Device]
):

@@ -204,27 +249,12 @@ def get_device_to_partitions_mapping(
                 return True
         return False
 
-    # logical id to device
-    logical_id_to_device: Dict[int, Device] = {}
-    # Track partitions on device
-    device_to_partitions: Dict[Device, List[Partition]] = {}
-    # Track device's left mem size
-    device_to_left_mem_bytes: Dict[Device, int] = {}
-    for d in devices:
-        logical_id_to_device[d.logical_id] = d
-        device_to_partitions[d] = []
-        device_to_left_mem_bytes[d] = d.available_mem_bytes
-    # Deal with the partitions that already have a device
-    # and also collect all partitions without a device (no_device_partitions)
-    no_device_partitions = []
-    for partition in partitions:
-        if partition.logical_device_ids != []:
-            logical_id = partition.logical_device_ids[0]
-            device = logical_id_to_device[logical_id]
-            device_to_partitions[device] = [partition]
-            device_to_left_mem_bytes[device] -= partition.used_mem_bytes
-        else:
-            no_device_partitions.append(partition)
+    (
+        device_to_partitions,
+        device_to_left_mem_bytes,
+        no_device_partitions,
+    ) = get_device_partition_stats(partitions, devices)
 
     # Find devices for all the partitions without a device
     found_device = True
     for partition in no_device_partitions:
@@ -341,7 +371,14 @@
            )
        else:
            self.size_based_partition()

        # Saturate host if possible.
        if partitioner_config.saturate_host:
            self.saturate_host()

        # Partition the graph module based on the partition assignment.
        module_with_submodules = self.do_partition()

        # The DAG contains DAGNodes with info of each partition's input nodes, output nodes
        # and how partitions are connected.
        dag = self.dump_dag(module_with_submodules)
@@ -459,6 +496,75 @@
            raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping")
        return

    def saturate_host(self) -> None:
        """Saturate host by assigning replicas to unused devices with enough memory.

        It uses a greedy approach to find a set of idle devices onto which all split
        partitions can be replicated: for each used device, it searches for an idle
        device with the minimal memory size that can hold all the partitions located
        on that device. If the search succeeds for all used devices, it assigns the
        new devices' logical IDs to the corresponding partitions.
        """
        (
            device_to_partitions,
            device_to_left_mem_bytes,
            no_device_partitions,
        ) = get_device_partition_stats(self.partitions, self.devices)
        assert (
            len(no_device_partitions) == 0
        ), f"Expected no_device_partitions to be empty, but got {len(no_device_partitions)} partitions"

        # Devices that hold partitions
        used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]
        # Track replicas of the assigned devices
        replicated_device_to_used_device: Dict[Device, Device] = {}
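
        # (Editorial note, not in the original diff:) a replication round maps
        # every used device to a fresh idle device, so keep looping while the
        # number of idle devices (total - used - already replicated) is at
        # least len(used_devices) -- the condition below, rearranged.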
        while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len(
            self.devices
        ):
            # Success flag for this round
            success = True
            # Devices that have not been assigned
            idle_devices = [
                d
                for d in self.devices
                if d not in used_devices and d not in replicated_device_to_used_device
            ]
            # Temporary mapping from replicated device to original device
            temp_replicate_mapping = {}

            # Find a new device to replicate all partitions on a used device
            for used_device in used_devices:
                # Idle devices that have enough memory
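                # (Editorial note:) the occupied bytes on used_device equal its
                # total memory minus what is left, so a replica needs at least
                # that much capacity.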
                available_devices = [
                    d
                    for d in idle_devices
                    if d.available_mem_bytes
                    >= used_device.available_mem_bytes
                    - device_to_left_mem_bytes[used_device]
                ]
                if len(available_devices) == 0:
                    success = False
                    break
                new_device = min(available_devices, key=lambda d: d.available_mem_bytes)
                idle_devices.remove(new_device)
                temp_replicate_mapping[new_device] = used_device

            if not success:
                break
            replicated_device_to_used_device.update(temp_replicate_mapping)

        # Update logical device IDs assigned to the partitions
        for (
            replicate_device,
            original_device,
        ) in replicated_device_to_used_device.items():
            logical_id = replicate_device.logical_id
            for partition in device_to_partitions[original_device]:
                partition.logical_device_ids.append(logical_id)

    def do_partition(self) -> GraphModule:
        """Return a new fx module with submodule nodes (partitions)."""
        module_with_submodules = split_module(
@@ -469,7 +575,7 @@
         return module_with_submodules
 
     def dump_dag(self, module_with_submodules: GraphModule) -> DAG:
-        """Return the dag structure and the new fx module with submodules"""
+        """Return the dag structure and the new fx module with submodules."""
         dag = DAG()
         for node in module_with_submodules.graph.nodes:
             if node.op == "output":

@@ -93,6 +93,8 @@ class PartitionerConfig(NamedTuple):
    node_to_latency_mapping: Dict[Node, NodeLatency] = {}
    node_to_partition_mapping: Dict[Node, int] = {}
    partition_to_logical_device_mapping: Dict[int, List[int]] = {}
    # Saturate host by replicating partitions to the remaining idle devices.
    saturate_host: bool = False


def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
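
For reference, a minimal usage sketch mirroring the unit test above; the import path is an assumption based on the torch.fx experimental modules this diff touches:

```
# Hypothetical usage; Device takes (name, available_mem_bytes, logical_id).
from torch.fx.experimental.partitioner_utils import Device, PartitionerConfig

devices = [
    Device("dev_0", 200, 0),
    Device("dev_1", 200, 1),
    Device("dev_4", 200, 4),
]
# saturate_host defaults to False; setting it opts in to replicating
# partitions onto idle devices after the initial partitioning.
config = PartitionerConfig(devices, saturate_host=True)
```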