Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-08 07:39:33 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50151

**Summary**
This commit adds a graph transformation pass that merges several matrix multiplications that use the same RHS operand into one large matrix multiplication. The LHS operands from all of the smaller matrix multiplications are concatenated together and used as an input in the large matrix multiply, and the result is split in order to obtain the same products as the original set of matrix multiplications.

**Test Plan**
This commit adds a simple unit test with two matrix multiplications that share the same RHS operand.

`python test/test_fx_experimental.py -k merge_matmul -v`

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D25809409

Pulled By: SplitInfinity

fbshipit-source-id: fb55c044a54dea9f07b71aa60d44b7a8f3966ed0
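For illustration only, the tensor-level identity that the pass relies on can be written directly in eager PyTorch: concatenate the LHS operands of matmuls that share an RHS, run one large matmul, then split the result back into the individual products. This is a minimal sketch of that identity, not the FX graph transformation added by this PR; the helper name `merged_matmul_reference` is made up for the example.

```python
# Illustrative sketch of the identity merge_matmul exploits (not the FX pass itself).
import torch

def merged_matmul_reference(lhs_list, rhs):
    # Concatenate the LHS operands along the row dimension.
    merged_lhs = torch.cat(lhs_list, dim=0)
    # One large matmul against the shared RHS.
    merged_out = torch.matmul(merged_lhs, rhs)
    # Split the result back into the per-operand products.
    return torch.split(merged_out, [lhs.shape[0] for lhs in lhs_list], dim=0)

x, y = torch.randn(3, 5), torch.randn(7, 5)
rhs = torch.randn(5, 4)
a, b = merged_matmul_reference([x, y], rhs)
assert torch.allclose(a, torch.matmul(x, rhs))
assert torch.allclose(b, torch.matmul(y, rhs))
```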
973 lines
37 KiB
Python
import torch
import unittest
import sys
import operator
from typing import Callable, Dict, Union, List
from torch.fx.symbolic_trace import symbolic_trace
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.experimental import graph_manipulation
from torch.fx.experimental.accelerator_partitioner import Partitioner
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
from torch.fx.experimental.subgraph_creation_example import split_module
from torch.fx.experimental.partitioner_utils import (
    NodeLatency,
    get_partition_to_latency_mapping,
    get_latency_of_partitioned_graph,
    Device,
    PartitionerConfig,
    PartitionMode
)
from torch.fx.experimental.fuser import fuse
from torch.fx.experimental import merge_matmul

try:
    from torchvision.models import resnet18
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
    return GraphModule(
        root if isinstance(root, torch.nn.Module) else torch.nn.Module(),
        RewritingTracer().trace(root),
    )


class TestFXExperimental(JitTestCase):
    def test_serialize_graph(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.e = torch.rand(4)
                self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

            def forward(self, a, b, c):
                add_1 = a + b
                conv1 = self.conv(c)
                linear = self.linear(add_1 + conv1)
                add_2 = linear + self.e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        c = torch.rand(3, 3, 2, 2)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, c])

        partitioner = Partitioner()
        devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        # Fix for now to add type/shape to output
        for node in traced.graph.nodes:
            if node.op == "output":
                node.shape = a.shape
                node.dtype = a.dtype
        for mod in module_with_submodules.modules():
            if isinstance(mod, GraphModule):
                for node in mod.graph.nodes:
                    node.shape = a.shape
                    node.dtype = a.dtype
        for node in module_with_submodules.graph.nodes:
            node.shape = a.shape
            node.dtype = a.dtype

        agm1 = graph_manipulation.AcceleratedGraphModule(traced)
        agm2 = graph_manipulation.AcceleratedGraphModule(module_with_submodules)
        assert len(agm1.weights) == 4
        assert len(agm2.weights) == 4
        assert len(agm1.serialized_graph["nodes"]) == 10
        assert len(agm1.serialized_graph["weights"]) == 4
        assert len(agm1.serialized_graph["modules"]) == 0
        assert len(agm2.serialized_graph["nodes"]) == 6
        assert len(agm2.serialized_graph["weights"]) == 4
        assert len(agm2.serialized_graph["modules"]) == 1
        assert agm1.serialized_graph["weights"]["linear.weight"]["shape"] == "[4, 4]"
        assert (
            agm1.serialized_graph["weights"]["linear.weight"]["dtype"]
            == "torch.float32"
        )
        assert (
            agm1.serialized_graph["weights"]["linear.weight"]["is_quantized"] is False
        )
        assert agm1.serialized_graph["nodes"][0]["shape"] == "[4]"
        assert agm1.serialized_graph["nodes"][0]["dtype"] == "torch.float32"
        assert agm1.serialized_graph["nodes"][0]["target"] == "a"
        assert agm1.serialized_graph["nodes"][0]["op_code"] == "placeholder"
        assert agm1.serialized_graph["nodes"][0]["name"] == "a"
        assert agm1.serialized_graph["nodes"][6]["args"][0]["name"] == "add_2"
        assert agm1.serialized_graph["nodes"][6]["args"][0]["is_node"] is True

        # Test quantization info serialization.
        x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
        q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
        q_tensor_channel = torch.quantize_per_channel(
            x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8
        )
        result = graph_manipulation.serialize_tensor_quantization(q_tensor)
        result2 = graph_manipulation.serialize_tensor_quantization(q_tensor_channel)
        assert result["q_scheme"] == "torch.per_tensor_affine"
        assert result["q_scale"] == 1.0
        assert result2["q_scheme"] == "torch.per_channel_affine"
        assert len(result2["q_per_channel_scales"]) == 2

    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2)
        ]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert dag.nodes[0].logical_device_ids == [0]

    def test_lack_of_devices(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        catch_runtime_error = False
        try:
            ret = partitioner.partition_graph(traced, m, partitioner_config)
        except RuntimeError:
            catch_runtime_error = True
        assert catch_runtime_error

    def test_large_node_error(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                linear = self.linear(a)
                add = linear + a
                return add

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 40, 0),
            Device("dev_1", 40, 0),
            Device("dev_2", 40, 0),
            Device("dev_3", 40, 0),
            Device("dev_4", 40, 0)
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        catch_runtime_error = False
        try:
            ret = partitioner.partition_graph(traced, m, partitioner_config)
        except RuntimeError:
            catch_runtime_error = True
        assert catch_runtime_error

    def test_partition_node_manipulation(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                add_1 = a + b
                add_2 = add_1 + torch.rand(4)
                add_3 = add_2 + torch.rand(4)
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a, b = torch.rand(4), torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device('dev_0', 1000, 0)]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        partition = partitioner.partitions[0]
        assert partition.used_mem_bytes == 112
        # Select add_3 node to remove
        selected_node = None
        for node in partition.nodes:
            if node.name == 'add_3':
                selected_node = node
        partition.remove_node(selected_node)
        assert(partition.used_mem_bytes == 80)

    def test_size_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.c = torch.rand(4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                add_2 = linear + self.c
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2)
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        for i, node in enumerate(dag.nodes):
            assert node.logical_device_ids == [i]

    def test_partition_device_mapping(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                b = torch.rand(4)
                add_1 = a + b
                linear_1 = self.linear(add_1)
                add_2 = torch.rand(4) + a
                add_3 = add_2 + linear_1
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        for i, node in enumerate(dag.nodes):
            if i == 1:
                assert node.logical_device_ids == [1]
            else:
                assert node.logical_device_ids == [0]

    def test_sparse_nn_partition(self):
        class MyRecommendationModule(torch.nn.Module):
            def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
                layers = torch.nn.ModuleList()
                for _ in range(num_of_layers):
                    ll = torch.nn.Linear(input_size, output_size)
                    layers.append(ll)
                    layers.append(torch.nn.ReLU())
                return layers

            def __init__(self):
                super(MyRecommendationModule, self).__init__()
                layers = self.create_mlp(4, 4, 4)
                self.bottom_layers = torch.nn.Sequential(*layers)
                layers = self.create_mlp(3, 24, 24)
                self.top_layers = torch.nn.Sequential(*layers)
                self.embedding_layers = torch.nn.ModuleList()
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)
                for i in range(3):
                    el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True)
                    self.embedding_layers.append(el)
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)

            def forward(self, a, b, offset):
                x = self.bottom_layers(a)
                y = []
                c = []
                for i in range(len(self.embedding_layers)):
                    temp = torch.randint(10, (8,))
                    c.append(temp + b)
                for i in range(len(self.embedding_layers)):
                    if i % 2 == 0:
                        y.append(self.embedding_layers[i](c[i], offset))
                    else:
                        y.append(
                            self.embedding_layers[i](torch.randint(10, (8,)), offset)
                        )
                z = torch.cat([x] + y, dim=1)
                p = self.top_layers(z)
                return p

        m = MyRecommendationModule()
        a = torch.rand(2, 4)
        b = torch.randint(10, (8,))
        offset = torch.randint(1, (2,))
        traced = symbolic_trace(m)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset])
        devices = [
            Device("dev_0", 33000000, 0),
            Device("dev_1", 33000000, 1),
            Device("dev_2", 33000000, 2)
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
        assert len(module_with_submodules.graph.nodes) == 24

    def test_partition_latency(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Given a fx module, generate node latency for each node
            based on the size of each node
            """
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        for p in partition_to_latency_mapping:
            if p.partition_id == 0:
                assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
            else:
                assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
        transfer_rate_bytes_per_sec = 2
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
        )
        assert critical_path_latency_sec == 208.0

    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {'output', 'placeholder', 'get_attr'}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(node.size_bytes.total_size, 1)
                    else:
                        node_to_latency_mapping[node] = NodeLatency(node.size_bytes.total_size, node.size_bytes.output_size)
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2),
            Device('dev_3', 125, 3)
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.cost_aware,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(partitions, node_to_latency_mapping)
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions,
            partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec
        )
        assert critical_path_latency_sec == 160.

    def test_kl_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.b = torch.rand(4)
                self.c = torch.rand(4)
                self.d = torch.rand(4)

            def forward(self, a):
                add_1 = a + self.b
                add_2 = add_1 + self.c
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + self.d
                add_5 = add_3 + add_4
                return add_4

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        transfer_rate_bytes_per_sec = 2
        devices = [
            Device('dev_0', 200, 0),
            Device('dev_1', 200, 1),
            Device('dev_2', 200, 2),
            Device('dev_3', 200, 3)
        ]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.kl_based,
            transfer_rate_bytes_per_sec=transfer_rate_bytes_per_sec,
            node_to_latency_mapping=node_to_latency_mapping
        )
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        dag = ret.dag
        assert dag.nodes[0].size_bytes == 176
        assert dag.nodes[1].size_bytes == 112
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitioner.partitions,
            node_to_latency_mapping
        )
        cost = get_latency_of_partitioned_graph(
            partitioner.partitions,
            partition_to_latency_mapping,
            transfer_rate_bytes_per_sec
        )
        assert cost == 208.

    def test_aot_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.b = torch.rand(4)
                self.c = torch.rand(4)

            def forward(self, a):
                add_1 = a + self.b
                add_2 = self.c + add_1
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        node_to_partition_id = {}
        partition_to_logical_devices = {}
        count = 0
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        for node in traced.graph.nodes:
            if node.op not in {'placeholder', 'get_attr', 'output'}:
                node_to_partition_id[node] = count
                partition_to_logical_devices[count] = [0]
                count += 1
        devices = [Device('dev_0', 200, 0)]
        partitioner_config = PartitionerConfig(
            devices=devices,
            mode=PartitionMode.aot_based,
            node_to_partition_mapping=node_to_partition_id,
            partition_to_logical_device_mapping=partition_to_logical_devices
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(module_with_submodules(a), traced(a))
        for node in dag.nodes:
            assert node.size_bytes == 48
            assert node.logical_device_ids == [0]

    def test_replace_target_nodes_with(self):
        class testModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = testModule()
        traced = symbolic_trace(m)
        input1 = torch.randn(1)
        input2 = torch.randn(1)
        assert (input1 + input2) == traced(input1, input2)
        graph_manipulation.replace_target_nodes_with(
            fx_module=traced,
            old_op="call_function",
            old_target=operator.add,
            new_op="call_function",
            new_target=operator.mul,
        )
        assert (input1 * input2) == traced(input1, input2)

    @skipIfNoTorchVision
    def test_conv_bn_fusion(self):
        rn18 = resnet18().eval()
        traced = symbolic_trace(rn18)
        fused = fuse(traced)

        self.assertTrue(all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()))

        N, C, H, W = 20, 3, 224, 224
        inp = torch.randn(N, C, H, W)

        self.assertEqual(fused(inp), rn18(inp))

    def test_call_to_assert_no_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint(traced)

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, "test message"
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint(traced)

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, "test message"):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_empty_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, ""
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint(traced)

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_multiline_message(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                error_msg = """
An error message with
terrible spacing
                """
                assert a == b, error_msg
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint(traced)

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        error_msg = """
An error message with
terrible spacing
        """
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, error_msg):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_subgraph_creation(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x, y):
                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                w = self.linear(y).clamp(min=0.0, max=1.0)
                return z + w

        # symbolically trace model
        my_module = MyModule()
        my_module_traced = symbolic_trace(my_module)

        # random mod partitioning
        partition_counter = 0
        NPARTITIONS = 3

        def mod_partition(node: Node):
            nonlocal partition_counter
            partition = partition_counter % NPARTITIONS
            partition_counter = (partition_counter + 1) % NPARTITIONS
            return partition

        # split module in module with submodules
        module_with_submodules = split_module(my_module_traced, my_module, mod_partition)

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        orig_out = my_module_traced(x, y)
        submodules_out = module_with_submodules(x, y)

        self.assertEqual(orig_out, submodules_out)

    @skipIfNoTorchVision
    def test_subgraph_trivial_resnet(self):
        # Smoke test that trivially splitting resnet into 1 partition works.
        # There was an issue before that caused submodule names to be aliased.
        m = resnet18()
        traced = symbolic_trace(m)
        a = torch.rand(64, 3, 7, 7)
        module_with_submodules = split_module(traced, m, lambda node: 0)
        module_with_submodules(a)

    def test_subgraph_uniquename(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b, c, d):
                add_1 = a + b
                add_2 = add_1 + c
                linear_1 = self.linear(add_1)
                add_3 = add_2 + d
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4)
        mm = MyModule()
        traced = symbolic_trace(mm)

        def split_cb(node : torch.fx.Node):
            if node.name == 'a' or node.name == 'b' or node.name == 'add':
                return 0
            else:
                return 1
        module_with_submodule = split_module(traced, mm, split_cb)
        self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d))

    def test_traceable_function_with_nonstandard_name(self):
        def foo(x):
            return torch.relu(x)

        traced = symbolic_trace_with_rewrite(foo)

    def test_to_folder(self):
        class Test(torch.nn.Module):
            def __init__(self):
                super(Test, self).__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
                self.linear = torch.nn.Linear(2, 2)
                self.attr = torch.randn(2)
                self.register_buffer('attr2', torch.randn(2))

            def forward(self, x):
                return self.linear(self.seq(self.W + self.attr + self.attr2 + x))

        mod = symbolic_trace(Test())
        module_name = 'Foo'
        import tempfile
        from pathlib import Path
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = Path(tmp_dir)
            mod.to_folder(tmp_dir, module_name)
            # Recipe taken from here:
            # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
            import importlib.util
            spec = importlib.util.spec_from_file_location(module_name, tmp_dir / '__init__.py')
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            t = torch.randn(2, 2)
            self.assertEqual(module.Foo()(t), mod(t))

    def test_fetch(self):
        attrs_for_lowering: Dict[str, List[str]] = {
            "torch.nn.modules.conv.Conv2d": [
                "weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"
            ],
            "torch.nn.modules.batchnorm.BatchNorm2d": [
                "weight", "bias", "running_mean", "running_var", "eps"
            ],
        }

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 2)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, a):
                a = self.conv(a)
                a += a
                return self.bn(a)

        mod = TestModule()
        traced = symbolic_trace(mod)
        lift_lowering_attrs_to_nodes(traced)

        for node in traced.graph.nodes:
            if node.op == "call_module":
                assert hasattr(node, "attrs_for_lowering")
                para_list = attrs_for_lowering[node.attrs_for_lowering["name"]]

                # node.attrs_for_lowering has an additional field for the class name
                assert len(para_list) + 1 == len(node.attrs_for_lowering)
                for p_name in para_list:
                    assert p_name in node.attrs_for_lowering

    def test_merge_matmuls(self):
        """
        A collection of test cases for torch.fx.experimental.merge_matmul,
        a graph transformation that merges matrix multiplication operations.
        """
        # Utility function for counting matmuls for test assertions.
        def _count_matmuls(mod):
            gm = torch.fx.symbolic_trace(mod)

            num_matmuls = 0
            for node in gm.graph.nodes:
                if node.target == torch.matmul:
                    num_matmuls += 1

            return num_matmuls

        # Simple test case in which there are two matmuls of the same size to merge.
        class SimpleMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x, y):
                a = torch.matmul(x, self.rhs)
                b = torch.matmul(y, self.rhs)
                return a + b

        # Initialize inputs.
        a = torch.randn(3, 3)
        b = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct SimpleMergeMatmulModule and call merge_matmul on it.
        module = SimpleMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a, b)
        after = opt_module(a, b)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have 2 matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Test case in which there are multiple matmuls of different sizes to merge.
        class FiveMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, a, b, c, d, e):
                s = torch.Tensor((0))
                matmuls = []

                # For some reason using a list comprehension or for-loop for this
                # doesn't work.
                matmuls.append(torch.matmul(a, self.rhs))
                matmuls.append(torch.matmul(b, self.rhs))
                matmuls.append(torch.matmul(c, self.rhs))
                matmuls.append(torch.matmul(d, self.rhs))
                matmuls.append(torch.matmul(e, self.rhs))

                for m in matmuls:
                    s += torch.sum(m)

                return s

        # Initialize inputs.
        inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]

        # Initialize RHS.
        rhs = torch.randn(5, 4)

        # Construct FiveMergeMatmulModule and call merge_matmul on it.
        module = FiveMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(*inputs)
        after = opt_module(*inputs)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have len(inputs) matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), len(inputs))
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Simple test case in which two matmuls cannot be merged due to a data dependency between
        # the LHS operands.
        class UnmergeableMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x):
                a = torch.matmul(x, self.rhs)
                a_abs = torch.abs(a)
                b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
                return b

        # Initialize inputs.
        a = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct UnmergeableMatmulModule and call merge_matmul on it.
        module = UnmergeableMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a)
        after = opt_module(a)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; the number of matrix multiplications should not have changed.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 2)

if __name__ == "__main__":
    run_tests()