pytorch/test/dynamo/test_reconstruct.py
Guilherme Leobas 34127fc688 Only reconstruct dict if needed (#141606)
Fixes #141452

This is a follow-up to PR #134876, which optimized dict reconstruction to generate code only when a value changed. This PR covers the general case and emits no instructions at all if the dictionary remains unchanged.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141606
Approved by: https://github.com/zou3519
2024-12-03 02:22:34 +00:00

309 lines
9.2 KiB
Python

# Owner(s): ["module: dynamo"]
import contextlib
import dis
import unittest
from typing import List
import torch
import torch._dynamo.test_case
from torch.testing._internal.common_utils import IS_FBCODE
def _filter_instructions(instructions, opname):
    """Return, as a new list, the instructions whose ``opname`` equals *opname*.

    Input order is preserved. *instructions* may be any iterable, for example
    the result of ``dis.get_instructions``.
    """
    return [inst for inst in instructions if inst.opname == opname]
class ReconstructTest(torch._dynamo.test_case.TestCase):
    """
    Check how much bytecode Dynamo emits to reconstruct dicts after a
    compiled frame, by inspecting the BUILD_MAP / LOAD_CONST instructions of
    the generated code object.
    """

    @contextlib.contextmanager
    def register_bytecode_hook(self, fn):
        """
        Call ``fn`` with the instructions of every code object Dynamo
        generates while this context is active.

        ``fn`` receives ``list(dis.get_instructions(out_code))`` for the
        transformed code. ``torch._dynamo.reset()`` is called first, so
        previously cached compilations do not bypass the hook. The hook is
        removed on exit, even if the body raises.
        """

        def hook(code, out_code):
            fn(list(dis.get_instructions(out_code)))
            # Return None so Dynamo keeps using ``out_code``. A non-None return
            # value replaces the generated bytecode. Returning the original
            # ``code`` would make the "compiled" function run the untransformed
            # code, and the end-to-end equality checks below would test nothing.
            return None

        torch._dynamo.reset()
        handle = torch._dynamo.convert_frame.register_bytecode_hook(hook)
        try:
            yield
        finally:
            handle.remove()

    def test_ConstDict_optimize_reconstruct(self):
        """
        Emit code to reconstruct only the key that changed
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct only d[40]
            self.assertEqual(build_map[0].argval, 1)

        def f(d, t):
            d[40] = t + 1

        t = torch.randn(3, 4)
        d = {1: t}
        d_opt = d.copy()
        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_ConstDict_pop_reconstruct(self):
        """
        If something is popped from the dict, we reconstruct everything
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything: {1: ..., 40: ...}
            self.assertEqual(build_map[0].argval, 2)

        def f(d, t):
            d.pop(2)
            d[40] = t + 1

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()
        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    @unittest.expectedFailure
    def test_ConstDict_popitem_reconstruct(self):
        """
        If something is popped from the dict, we reconstruct everything
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything (only one entry is left after popitem)
            self.assertEqual(build_map[0].argval, 1)

        def f(d, t):
            d.popitem()

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()
        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_ConstDict_popitem_reconstruct_graph_break(self):
        """
        If something is popped from the dict, we reconstruct everything.
        Calling dict.popitem will graph break.
        """

        def f(d, t):
            d.popitem()

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()
        f(d, t)

        opt_f = torch.compile(backend="eager")(f)
        opt_f(d_opt, t)
        self.assertEqual(d, d_opt)

    def test_ConstDict_del_reconstruct(self):
        """
        If something is deleted from the dict, we reconstruct everything
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything: {1: ..., 40: ...}
            self.assertEqual(build_map[0].argval, 2)

        def f(d, t):
            del d[2]
            d[40] = t + 1

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()
        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_ConstDict_get_reconstruct(self):
        """
        dict.get shouldn't affect anything
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            self.assertEqual(build_map[0].argval, 1)
            load_const = _filter_instructions(instructions, "LOAD_CONST")
            # Compare against the loaded constant values. Checking 123 against
            # the Instruction objects themselves is always "not in" and would
            # never detect that the untouched key 123 was reconstructed.
            self.assertNotIn(123, [inst.argval for inst in load_const])

        def f(d, t):
            d[456] = d.get(456) + t

        t = torch.randn(3, 4)
        d = {123: t, 456: t + 1}
        d_opt = d.copy()
        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_ConstDict_clear_reconstruct(self):
        """
        If dict.clear() is used, we reconstruct everything
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything (only d[3] exists after clear)
            self.assertEqual(build_map[0].argval, 1)

        def f(d, t):
            d.clear()
            d[3] = t + 3

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()
        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_create_dict_reconstruct(self):
        """
        If dict is created inside a function, everything needs to be reconstructed
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything
            self.assertEqual(build_map[0].argval, 2)

        def f(t):
            return {1: t, 2: t + 1}

        t = torch.randn(3, 4)
        d = f(t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            d_opt = opt_f(t)
            self.assertEqual(d, d_opt)

    @unittest.skipIf(
        IS_FBCODE, "capturing functional_call is not enabled by default in FB_CODE"
    )
    def test_functional_call_reconstruct(self):
        """
        PyTorch shouldn't codegen any key/value when functional_call is used
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            # don't reconstruct anything
            self.assertEqual(len(build_map), 0)

        m = torch.nn.Linear(3, 3)
        new_bias = torch.randn(3)
        new_weight = torch.randn(3, 3)

        def fn(new_weight, new_bias, x):
            return torch.func.functional_call(
                m, {"weight": new_weight, "bias": new_bias}, x
            )

        x = torch.randn(2, 3)
        expected = torch.nn.functional.linear(x, new_weight, new_bias)
        with self.register_bytecode_hook(hook):
            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
            got = opt_fn(new_weight, new_bias, x)
            self.assertEqual(expected, got)

    @unittest.skipIf(
        IS_FBCODE, "capturing functional_call is not enabled by default in FB_CODE"
    )
    def test_functional_call_reconstruct_2(self):
        """
        PyTorch shouldn't codegen any key/value when functional_call is used
        (nested ModuleDict variant, state passed in as a state_dict).
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            # don't reconstruct anything
            self.assertEqual(len(build_map), 0)

        class DummyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.ModuleDict(
                    {
                        "b": torch.nn.ModuleDict(
                            {
                                "c": torch.nn.ModuleDict(
                                    {
                                        "d": torch.nn.ModuleDict(
                                            {"e": torch.nn.Linear(10, 10, bias=False)}
                                        )
                                    }
                                )
                            }
                        )
                    }
                )

            def forward(self, x):
                return self.a.b.c.d.e(x)

        model = DummyModule()

        def fn(model, states, x):
            return torch.func.functional_call(model, states, x)

        states = model.state_dict()
        # Input width must match the innermost Linear(10, 10).
        x = torch.randn(10, 10)
        expected = fn(model, states, x)

        with self.register_bytecode_hook(hook):
            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
            got = opt_fn(model, states, x)
            self.assertEqual(expected, got)
if __name__ == "__main__":
    # Run this file's tests through Dynamo's test-case runner, reached via
    # the already-imported torch._dynamo.test_case module.
    torch._dynamo.test_case.run_tests()