pytorch/test/inductor/test_memory_planning.py
bobrenjc93 f649ee73ce Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache missing due to PGO. At a high level, the problem looks as follows:

Run 1, Invocation 1: We do a static compile and save some example values to the PGO/automatic-dynamic profile.

Run 1, Invocation 2: We detect varying inputs, do a dynamic compile, get a dynamic graph, and save to PGO. Crucially, what we save to PGO is a superset of what is actually dynamic: if we notice an input varying, we mark it as dynamic in PGO even if that value later gets specialized. When a value gets specialized, we remove its symbol from the graph. This produces an interesting conundrum: even though we end up with an isomorphic graph, PGO makes the second run cache miss. Let's see how.

Run 2, Invocation 1: We fetch the PGO profile, over-mark things as dynamic, get an FX graph, look it up in the cache and... whoops, cache miss! This is the aforementioned behavior at work: the PGO profile causes us to over-allocate symbols. In practice, we end up saving a graph in the cache with symbols x:s1, y:s3, and on the second attempt we cache miss with x:s1, y:s6, because symbols s3, s4, s5 were all optimistically marked dynamic by PGO and subsequently specialized.
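To make the renumbering concrete, here is a toy sketch of the failure mode (the `SequentialShapeEnv` class, its method names, and its numbers are invented for illustration and heavily simplified relative to the real ShapeEnv):

```python
# Illustrative toy, not PyTorch's actual ShapeEnv.
class SequentialShapeEnv:
    def __init__(self):
        self.counter = 0
        self.ids = {}

    def allocate(self, source):
        # Each new dynamic source takes the next counter value, so an id
        # depends on how many symbols were allocated before it.
        self.counter += 1
        self.ids[source] = f"s{self.counter}"
        return self.ids[source]


# Run 1: only x and y end up dynamic -> graph keyed on {x: s1, y: s2}.
env1 = SequentialShapeEnv()
env1.allocate("x")  # s1
env1.allocate("y")  # s2

# Run 2: PGO optimistically marks a and b dynamic too; they specialize
# away later, but they already consumed counter values.
env2 = SequentialShapeEnv()
env2.allocate("x")  # s1
env2.allocate("a")  # s2, later specialized and removed from the graph
env2.allocate("b")  # s3, later specialized and removed from the graph
env2.allocate("y")  # s4: same isomorphic graph, different cache key
```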

We solve this problem by hashing the source names, which gives each source a reasonably stable symbol id across runs. To prevent catastrophic symbol collisions, we resolve any hash collision with linear probing, so no two sources ever share an id.
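A minimal sketch of the idea, assuming a fixed-size id space (the `stable_symbol_id` name, the SHA-256 choice, and the table size are illustrative; the actual implementation is in this PR):

```python
import hashlib


def stable_symbol_id(source: str, taken: set, table_size: int = 1 << 31) -> int:
    """Derive a symbol id from the source name so it is stable across runs."""
    # Use a deterministic hash; Python's built-in hash() is salted per process.
    idx = int(hashlib.sha256(source.encode()).hexdigest(), 16) % table_size
    # Linear probing: on a collision, step forward to the next free slot.
    while idx in taken:
        idx = (idx + 1) % table_size
    taken.add(idx)
    return idx


taken: set = set()
s_x = stable_symbol_id("L['x'].size()[0]", taken)
s_y = stable_symbol_id("L['y'].size()[0]", taken)
# The same source names map to the same ids on the next run, no matter how
# many other symbols PGO optimistically allocated and later specialized away.
```

Because an id now follows its source name rather than the allocation order, a symbol that PGO optimistically allocates and later specializes away no longer shifts the ids of everything allocated after it.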

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-28 05:36:32 +00:00


# Owner(s): ["module: inductor"]
import sys
import unittest

from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu

if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_memory_planning yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")  # noqa: F821

import torch
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_cpp_code
from torch.export import Dim

@requires_gpu()
@config.patch(memory_planning=True)
class TestMemoryPlanning(TestCase):
    device = GPU_TYPE

    def _generate(self, *, device):
        """
        Generate a simple test case that has multiple simultaneously-live
        intermediate tensors.
        """

        class Foo(torch.nn.Module):
            def forward(self, x, y, z):
                t0 = x.matmul(y)
                t1 = x.matmul(z)
                t0 = x.transpose(0, 1).matmul(t1)
                t1 = x.matmul(t0)
                return t0.sum() + t1.sum()

        x = torch.randn((3, 2), device=device)
        y = torch.randn((2, 4), device=device)
        z = torch.randn((2, 3), device=device)
        return (Foo(), (x, y, z))
    def test_python_wrapper(self):
        f, args = self._generate(device=GPU_TYPE)
        compiled = torch.compile(f, dynamic=True)
        result, code = run_and_get_cpp_code(compiled, *args)

        # The intermediates should be carved out of a single pool allocation.
        FileCheck().check(
            "pool1 = empty_strided_"
            + GPU_TYPE
            + "((4*s27*s77 + align(4*s77*s77), ), (1, )"
        ).check_next(
            "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s77, s77), (s77, 1))"
        ).check(
            "buf1 = alloc_from_pool(pool1, align(4*s77*s77),"
        ).run(code)
        self.assertTrue(same(f(*args), result))
    def test_cpp_wrapper(self):
        f, args = self._generate(device=GPU_TYPE)
        compiled = torch.compile(f, dynamic=True)
        with config.patch({"cpp_wrapper": True}):
            result, code = run_and_get_cpp_code(compiled, *args)

        # The C++ wrapper should likewise allocate buf0 and buf1 from pool1.
        FileCheck().check(
            "aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_float32, 2, int_array_4, int_array_5, &tmp_tensor_handle_0)"
        ).check_next(
            "auto buf0 = RAIIAtenTensorHandle(tmp_tensor_handle_0);"
        ).check(
            "auto buf1 = RAIIAtenTensorHandle(tmp_tensor_handle_1);"
        ).run(code)
        self.assertTrue(same(f(*args), result))
    @skipIfXpu(msg="aoti doesn't work on XPU")
    def test_aoti(self):
        try:
            from .test_aot_inductor import AOTIRunnerUtil
        except ImportError:
            from test_aot_inductor import (  # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library
                AOTIRunnerUtil,
            )

        f, args = self._generate(device=GPU_TYPE)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = ({0: dim0_x}, None, None)
        result, code = run_and_get_cpp_code(
            lambda: AOTIRunnerUtil.run(f, args, dynamic_shapes=dynamic_shapes)
        )

        FileCheck().check(
            "int64_t int_array_2[] = {24L + align(12L*s77), };"
        ).check_next(
            "int64_t int_array_3[] = {1L, };"
        ).check_next(
            "AtenTensorHandle pool1_handle;"
        ).check_next(
            "aoti_torch_empty_strided(1, int_array_2, int_array_3,"
        ).check_next(
            "RAIIAtenTensorHandle pool1(pool1_handle);"
        ).check_next(
            "int64_t int_array_4[] = {s77, 3L};"
        ).check_next(
            "int64_t int_array_5[] = {3L, 1L};"
        ).check_next(
            "AtenTensorHandle tmp_tensor_handle_0;"
        ).check_next(
            "aoti_torch__alloc_from_pool(pool1, 0"
        ).run(code)
        self.assertTrue(same(f(*args), result))

if __name__ == "__main__":
    if HAS_GPU:
        run_tests()