Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary: Reopens PR https://github.com/pytorch/pytorch/pull/45837. This PR adds a new feature to the Partitioner() class called size_based_partition. Given a list of devices with the same memory size, it distributes the graph nodes across those devices. To implement this feature, several helper functions are added in Partitioner.py and GraphManipulation.py. A unit test is also added in test/test_fx_experimental.py. Pull Request resolved: https://github.com/pytorch/pytorch/pull/46282 Reviewed By: gcatron Differential Revision: D24288470 Pulled By: scottxu0730 fbshipit-source-id: e81b1e0c56e34f61e497d868882126216eba7538
96 lines
3.1 KiB
Python
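# Tests for the experimental FX Partitioner (torch.fx.experimental): each test
# partitions a small traced module across three devices and checks that the
# partitioned module behaves like the original.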
import torch
from torch.fx.symbolic_trace import symbolic_trace
from torch.fx.experimental import GraphManipulation
from torch.fx.experimental.Partitioner import Partitioner, Device
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase


class TestFXExperimental(JitTestCase):
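    # A single add should fit into one partition; the partitioned module must
    # still return the same result as the traced original.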
    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device('dev_0', 125),
            Device('dev_1', 125),
            Device('dev_2', 125)
        ]
        ret = partitioner.partition_graph(traced, m, devices)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a, b), module_with_submodules(a, b))

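    # Exercises the new size_based_partition behavior: the add -> linear -> add
    # graph is distributed across three devices with the same memory size, and
    # the stitched-together module must match the original. The node count below
    # refers to the top-level graph after partitioning.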
    def test_size_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                e = torch.rand(4)
                add_2 = linear + e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device('dev_0', 125),
            Device('dev_1', 125),
            Device('dev_2', 125)
        ]
        ret = partitioner.partition_graph(traced, m, devices)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert len(module_with_submodules.graph.nodes) == 7

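    # Checks that partitions are combined where possible; the combined module
    # must still match the original, leaving 5 nodes in the top-level graph.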
    def test_partition_combining(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_0 = torch.nn.Linear(4, 4)

            def forward(self, a, b):
                add_1 = a + b
                c = self.linear_0(a)
                add_2 = c + add_1
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device('dev_0', 125),
            Device('dev_1', 125),
            Device('dev_2', 125)
        ]
        ret = partitioner.partition_graph(traced, m, devices)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert len(module_with_submodules.graph.nodes) == 5


if __name__ == '__main__':
    run_tests()