pytorch/test/dynamo/test_graph_region_tracker.py
Xinya Zhang e769026bcb [ROCm] Remove HIPBLASLT_ALLOW_TF32 from codebase (#162998)
A few UT failures are caused by `HIPBLASLT_ALLOW_TF32`

Fixes #157094
Fixes #157093
Fixes #157092
Fixes #157091
Fixes #157064
Fixes #157063
Fixes #157062
Fixes #157061
Fixes #157042
Fixes #157041
Fixes #157039
Fixes #157004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162998
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-09-18 13:53:48 +00:00

375 lines
12 KiB
Python

# Owner(s): ["module: dynamo"]
import contextlib
import torch
import torch.fx
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import extract_graph_and_tracker
from torch.utils._pytree import tree_map
class GraphRegionTrackerTests(TestCase):
def setUp(self):
self.exit_stack = contextlib.ExitStack()
self.exit_stack.enter_context(
torch._dynamo.config.patch("track_nodes_for_deduplication", True)
)
super().setUp()
def tearDown(self):
self.exit_stack.close()
super().tearDown()
def get_result(self, fn, *args, **kwargs):
graph, region_tracker = extract_graph_and_tracker(fn, *args, **kwargs)
region_groups = region_tracker.get_identical_regions(graph)
region_groups = tree_map(lambda n: n.name, region_groups)
return str(region_groups)
def get_mutation_tracking(self, fn, *args, **kwargs):
_, region_tracker = extract_graph_and_tracker(fn, *args, **kwargs)
return str(region_tracker.node_to_mutated_arg_positions)
def test_get_regions_single_region_group(self):
def inner_fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
def fn(x, y):
o0 = inner_fn(x, y)
o1 = torch.sin(y)
o2 = inner_fn(x, o1)
o3 = inner_fn(x, y)
o4 = o3 * o3
return o2 * o4 + o0
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(10, 10),
torch.ones(10, 20),
),
"""[[['x0', 'y0', 'sum_1', 'sum_2', 'z'], ['x0_1', 'y0_1', 'sum_3', 'sum_4', 'z_1'],\
['x0_2', 'y0_2', 'sum_5', 'sum_6', 'z_2']]]""",
)
def test_get_regions_multiple_region_groups(self):
def inner_fn(x, y):
x1 = x + 1
y1 = y + 2
z = x1.sum() + y1.sum()
return z
def inner_fn2(a, b):
a += 2
b += 3
c = a * b.cos().sum()
return c
def fn(x, y):
x0 = torch.cos(x)
y0 = torch.sin(y)
o1 = inner_fn2(x0, y0)
o0 = inner_fn(x, y)
o1 = torch.sin(o0)
o2 = inner_fn(x, y0)
o2 = inner_fn2(x0, y0)
o3 = inner_fn(x, y)
return o1 * o2 + o3
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(10, 10),
torch.ones(10, 20),
),
"""[[['x1', 'y1', 'sum_2', 'sum_3', 'z'], ['x1_1', 'y1_1', 'sum_4', 'sum_5', 'z_1'],\
['x1_2', 'y1_2', 'sum_7', 'sum_8', 'z_2']], [['a', 'b', 'cos_1', 'sum_1', 'c'], ['a_1', 'b_1', 'cos_2', 'sum_6', 'c_1']]]""",
)
def test_no_single_node_regions(self):
def inner_fn(x):
return x + 1
def fn(x):
o0 = inner_fn(x)
o1 = inner_fn(x)
o2 = inner_fn(x)
return o0 + o1 + o2
self.assertExpectedInline(self.get_result(fn, torch.ones(10, 10)), """[]""")
def test_mismatched_arg_shapes(self):
def inner_fn(x, y):
x1 = x + 1
y1 = y + 2
z = x1.sum() + y1.sum()
return z
def inner_fn2(a, b):
a += 2
b += 3
c = a * b.cos().sum()
return c
def fn(x, y):
x0 = torch.cos(x)
y0 = torch.sin(y)
o1 = inner_fn2(x0, y0)
o0 = inner_fn(x, o1)
o1 = torch.sin(o0)
o2 = inner_fn(x, y0)
o2 = inner_fn2(o2, y0)
o3 = inner_fn(x, y)
return o1 * o2 + o3
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(10, 10),
torch.ones(10, 20),
),
"""[[['y1_1', 'sum_5'], ['y1_2', 'sum_8']], [['x1', 'sum_2', 'z'], ['x1_1', 'sum_4', 'z_1'], \
['x1_2', 'sum_7', 'z_2']], [['b', 'cos_1', 'sum_1'], ['b_1', 'cos_2', 'sum_6']]]""",
)
def test_mismatched_dtypes(self):
def inner_fn(x, y):
x1 = x * 1
y1 = y + 1
return x1 + y1.sum()
def fn(x, y):
x0 = torch.sin(x)
y0 = torch.cos(y)
o0 = inner_fn(x0, y0)
o2 = inner_fn(x0, y0)
o4 = inner_fn(x0, y0)
o5 = inner_fn(x0, y0)
o1 = inner_fn(x0.to(torch.bfloat16), y0.to(torch.bfloat16))
o3 = o1 + o2
return o3 * o0 + o4 + o5
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(10, 10),
torch.ones(10, 20),
),
"""[[['x1', 'y1', 'sum_1', 'o0'], ['x1_1', 'y1_1', 'sum_2', 'o2'], \
['x1_2', 'y1_2', 'sum_3', 'o4'], ['x1_3', 'y1_3', 'sum_4', 'o5']]]""",
)
def test_nested_args(self):
def inner_fn(xs, ys):
out = torch._foreach_add(xs, ys)
return out[0] + out[1].sum()
def fn(x, y, z):
x0 = torch.sin(x)
y0 = torch.cos(y)
z0 = torch.sin(z)
o0 = inner_fn([x0, z0], [x0, y0])
o2 = inner_fn([x0, z0], [x0, y0])
o4 = inner_fn([x0, z0], [x0, y0])
o5 = inner_fn([x0, z0], [x0, y0])
o1 = inner_fn(
[x0.to(torch.bfloat16), z0.to(torch.bfloat16)],
[x0.to(torch.bfloat16), y0.to(torch.bfloat16)],
)
o3 = o1 + o2
return o3 * o0 + o4 + o5
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(10, 10),
torch.rand(10, 20),
torch.ones(10, 20),
),
"""[[['_foreach_add', 'getitem', 'getitem_1', 'sum_1', 'o0'], ['_foreach_add_1', \
'getitem_2', 'getitem_3', 'sum_2', 'o2'], ['_foreach_add_2', 'getitem_4', 'getitem_5', 'sum_3', \
'o4'], ['_foreach_add_3', 'getitem_6', 'getitem_7', 'sum_4', 'o5']]]""",
)
def test_mismatched_global_state(self):
def inner_fn(x, y):
x1 = x * 1
y1 = y + 1
return x1 + y1.sum()
def fn(x, y, c):
x0 = torch.sin(x)
y0 = torch.cos(y)
o4 = inner_fn(x0, y0)
o5 = inner_fn(x0, y0)
if isinstance(c, tuple):
c[0]()
o0 = inner_fn(x0, y0)
o2 = inner_fn(x0, y0)
c[1]()
else:
with c():
o0 = inner_fn(x0, y0)
o2 = inner_fn(x0, y0)
return o0 + o2 + o4 + o5
def create_toggle_fns(property):
old_value = getattr(torch.backends.cuda.matmul, property)
def toggle_property():
setattr(torch.backends.cuda.matmul, property, not old_value)
def reset_property():
setattr(torch.backends.cuda.matmul, property, old_value)
return toggle_property, reset_property
old_dtype = torch.get_default_dtype()
def set_default_dtype_bfloat16():
torch.set_default_dtype(torch.bfloat16)
def reset_default_dtype():
torch.set_default_dtype(old_dtype)
for ctx in [
lambda: torch.set_grad_enabled(False),
torch.autograd.grad_mode.inference_mode,
lambda: torch.autograd.graph.disable_saved_tensors_hooks(
"This is not supported"
),
# lambda: torch.set_num_threads(2), : Unsupported
(set_default_dtype_bfloat16, reset_default_dtype),
(
lambda: torch.use_deterministic_algorithms(True),
lambda: torch.use_deterministic_algorithms(False),
),
# (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
# lambda: torch.use_deterministic_algorithms(False)), : Unsupported
create_toggle_fns("allow_bf16_reduced_precision_reduction"),
create_toggle_fns("allow_fp16_reduced_precision_reduction"),
create_toggle_fns("allow_tf32"),
]:
self.assertExpectedInline(
self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
"""[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
[['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
)
def test_mutation_tracking_simple(self):
def fn(x, y, z):
x0 = torch.sin(x)
y0 = torch.cos(y)
z.sin_()
y0.add_(z)
return x0.sum() + y0.sum()
self.assertExpectedInline(
self.get_mutation_tracking(
fn,
torch.rand(10, 10),
torch.rand(10, 20),
torch.ones(10, 20),
),
"""{sin_: OrderedSet([0]), add_: OrderedSet([0])}""",
)
def test_mutation_tracking_setitem(self):
def fn(x):
y = x + 1
y[0] = 3
return y
self.assertExpectedInline(
self.get_mutation_tracking(fn, torch.rand(10, 10)),
"""{setitem: OrderedSet([0])}""",
)
def test_mutation_tracking_allow_in_graph(self):
@torch._dynamo.allow_in_graph
def fn_mut(x, y):
x.add_(y)
return x.sum() + y.sum()
def fn(x, y):
z = x + y
o0 = fn_mut(z, y)
z.sin_()
return x + o0
self.assertExpectedInline(
self.get_mutation_tracking(
fn,
torch.rand(20, 10),
torch.rand(20, 10),
),
"""{o0: OrderedSet([0]), sin_: OrderedSet([0])}""",
)
def test_non_tensor_arg_hashing(self):
def inner(x, w, t):
y = x + x
return torch.conv2d(y, w, None, *t)
def fn(x, y):
o1 = inner(x, y, ((1, 1), (0, 0), (1, 1), 1))
o2 = inner(x, y, ((1, 1), (0, 0), (1, 1), 1))
o3 = inner(x, y, ((1, 1), (0, 0), (1, 1), 1))
o4 = inner(x, y, ((2, 2), (0, 0), (1, 1), 1))
return o1.sum() + o2.sum() + o3.sum() + o4.sum()
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(32, 256, 56, 56),
torch.nn.Parameter(torch.rand(512, 256, 1, 1)),
),
"""[[['y', 'o1'], ['y_1', 'o2'], ['y_2', 'o3']]]""",
)
def test_region_sorting(self):
from torch._dynamo.graph_region_tracker import _sort_with_ref_region
index_to_rank = {0: 0, 2: 1, 1: 2}
regions = [[0, 1, 2], [1, 2, 0]]
_sort_with_ref_region(index_to_rank, regions)
self.assertExpectedInline(regions, """[[0, 2, 1], [1, 0, 2]]""")
def test_no_duplicate_tracking(self):
def inner_fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
def fn(x, y):
o0 = inner_fn(x, y)
o1 = torch.sin(y)
o2 = inner_fn(x, o1)
o3 = inner_fn(x, y)
o4 = o3 * o3
return o2 * o4 + o0
graph, tracker = extract_graph_and_tracker(
fn, torch.rand(10, 10), torch.ones(10, 20)
)
self.assertExpectedInline(
tracker.node_to_duplicates,
"""{l_x_: [l_x_], x0: [x0, x0_1, x0_2], l_y_: [l_y_], y0: [y0, y0_1, y0_2], sum_1: \
[sum_1, sum_3, sum_5], sum_2: [sum_2, sum_4, sum_6], z: [z, z_1, z_2], o1: [o1], x0_1: [x0, x0_1, x0_2], y0_1: [y0, y0_1, y0_2], \
sum_3: [sum_1, sum_3, sum_5], sum_4: [sum_2, sum_4, sum_6], \
z_1: [z, z_1, z_2], x0_2: [x0, x0_1, x0_2], y0_2: [y0, y0_1, y0_2], sum_5: [sum_1, sum_3, sum_5], sum_6: [sum_2, sum_4, sum_6], \
z_2: [z, z_1, z_2], o4: [o4], mul_1: [mul_1], add_9: [add_9]}""",
)
key = next(iter(tracker.node_to_duplicates.keys()))
tracker.track_node(None, key) # this will fail if the node is added again
if __name__ == "__main__":
    # Script entry point: hand control to run_tests from torch._dynamo.test_case.
    import torch._dynamo.test_case as _dynamo_test_case

    _dynamo_test_case.run_tests()