A few things to note:

1. Customers like vllm use a custom backend (e.g. VllmBackend) that splits the graph and calls standalone_compile for each split. If the bisector were allowed to override the backend, we would never bisect through the custom backend. `test_configs.bisect_keep_custom_backend_for_inductor` keeps the custom backend in place while bisecting inductor.

2. pre_grad_graph bisecting and lowering bisecting do not yet compose well with each other, since an issue may be captured by whichever one is tried first. `test_configs.bisect_pre_grad_graph` enables the 'pre_grad_graph' bisecting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166344
Approved by: https://github.com/eellison
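For illustration, here is a minimal sketch (not part of the PR or of the test file below) of how the two knobs might be combined when bisecting a model compiled through a custom, graph-splitting backend. `run_model` and `my_custom_backend` are hypothetical stand-ins; only `config.patch`, `CompilerBisector.do_bisect`, and `torch._inductor.standalone_compile` come from the code below.

import torch
from torch._inductor import config
from torch._inductor.compiler_bisector import CompilerBisector


def run_model(x):
    # Hypothetical stand-in for the real model under test.
    return (x + 1).relu()


def my_custom_backend(gm, example_inputs):
    # Hypothetical stand-in for a vllm-style backend; a real one would split
    # the graph and standalone_compile each piece (see test_bisect_pre_grad_graph).
    return torch._inductor.standalone_compile(
        gm, example_inputs, dynamic_shapes="from_example_inputs", options={}
    )


def test_fn():
    # True when compiled and eager agree; False flags the bug being bisected.
    torch._dynamo.reset()
    inputs = torch.randn(1024, device="cuda")
    compiled = torch.compile(run_model, backend=my_custom_backend)
    return torch.allclose(compiled(inputs), run_model(inputs))


with config.patch(
    {
        # Keep the custom backend instead of letting the bisector swap in plain inductor.
        "test_configs.bisect_keep_custom_backend_for_inductor": True,
        # Also bisect at the pre-grad graph level.
        "test_configs.bisect_pre_grad_graph": True,
    }
):
    result = CompilerBisector.do_bisect(test_fn)
    # On a real failure, result.backend / result.subsystem / result.bisect_number
    # identify the culprit, as in the tests below.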
336 lines · 11 KiB · Python
# Owner(s): ["module: dynamo"]

from contextlib import contextmanager
from importlib import import_module

import torch
import torch._prims_common as utils
from torch._dynamo.utils import preserve_rng_state
from torch._inductor import config
from torch._inductor.compiler_bisector import CompilerBisector
from torch._inductor.test_case import TestCase
from torch.library import _scoped_library, Library
from torch.testing._internal.triton_utils import requires_cuda_and_triton
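
# Each test below builds a `test_fn` that returns True when compiled and eager
# results agree and False when the injected bug reproduces;
# `CompilerBisector.do_bisect(test_fn)` then narrows the failure down to a
# backend, subsystem, and bisect number.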


aten = torch.ops.aten


f32 = torch.float32
i64 = torch.int64
i32 = torch.int32


@requires_cuda_and_triton
class TestCompilerBisector(TestCase):
    test_ns = "_test_bisector"

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        if hasattr(self, "lib"):
            del self.lib.m
            del self.lib

    def get_op(self, name):
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.lib = lib
        return lib
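
    # Swap in a decomposition for aten.exponential that produces NaNs; the
    # bisector should land on the decomposition subsystem of
    # aot_eager_decomp_partition.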
    def test_bad_decomp(self):
        import_module("torch._inductor.compile_fx")

        def bad_exp_decomp(self, rate=1, generator=None):
            assert generator is None
            torch._check(
                not utils.is_complex_dtype(self.dtype)
                and not utils.is_integer_dtype(self.dtype)
                and not utils.is_boolean_dtype(self.dtype),
                lambda: f"Exponential distribution is a continuous probability distribution. \
                dtype must be a floating point but you specified {self.dtype}",
            )
            torch._check(
                rate > 0.0,
                lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
            )
            return torch.rand_like(self) * float("nan")

        @contextmanager
        def patch_exp_decomp():
            from torch._inductor.compile_fx import select_decomp_table as old_decomp

            def get_decomp():
                out = old_decomp()
                out = out.copy()
                out[aten.exponential.default] = bad_exp_decomp
                return out

            torch._inductor.compile_fx.select_decomp_table = get_decomp
            try:
                yield

            finally:
                torch._inductor.compile_fx.select_decomp_table = old_decomp

        def vq(x):
            return (x + 3).exponential_() * 10.5

        def test_fn():
            torch._dynamo.reset()
            with patch_exp_decomp():
                vq_compiled = torch.compile(vq)
                x = torch.randn(4, 400, 256).cuda()
                with torch._dynamo.utils.preserve_rng_state():
                    vq(x)
                out_compiled = vq_compiled(x)

            return not out_compiled.isnan().any()

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "aot_eager_decomp_partition")
        self.assertEqual(out.subsystem, "decomposition")
        self.assertEqual(out.bisect_number, 1)
        self.assertTrue("aten.exponential" in out.debug_info)
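
    # A pre-grad custom pass rewrites `x + 1` into `x + 2`, so compiled output
    # diverges from eager; bisection should land on pre_grad_passes.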
    def test_pre_grad(self):
        import operator

        from torch._inductor import config

        # similar setup to test_joint_graph (see below)
        def pass_fn(graph: torch.fx.Graph):
            nodes = graph.find_nodes(op="call_function", target=operator.add)
            assert len(nodes) == 1
            args = list(nodes[0].args)
            args[1] = 2
            nodes[0].args = tuple(args)

        config.pre_grad_custom_pass = pass_fn

        def foo(x):
            return x + 1

        def test_fn():
            torch._dynamo.reset()

            inp = torch.rand([10])

            out = foo(inp)
            out_c = torch.compile(foo)(inp)

            return torch.allclose(out, out_c)

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "pre_grad_passes")
        self.assertEqual(out.bisect_number, 0)
        self.assertTrue("pre_grad_custom_pass" in out.debug_info)
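
    # Same `x + 1` -> `x + 2` mutation, but injected as a joint-graph post
    # pass; bisection should land on joint_graph_passes.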
    def test_joint_graph(self):
        from torch._inductor import config

        def pass_fn(graph: torch.fx.Graph):
            nodes = graph.find_nodes(
                op="call_function", target=torch.ops.aten.add.Tensor
            )
            assert len(nodes) == 1
            args = list(nodes[0].args)
            args[1] = 2
            nodes[0].args = tuple(args)

        config.joint_custom_post_pass = pass_fn

        def foo(x):
            return x + 1

        def test_fn():
            torch._dynamo.reset()

            inp = torch.rand([10], device="cuda")

            out = foo(inp)
            out_c = torch.compile(foo)(inp)

            return torch.allclose(out, out_c)

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "joint_graph_passes")
        self.assertEqual(out.bisect_number, 4)
        self.assertTrue("joint_custom_post_pass" in out.debug_info)
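
    # Eager and compiled RNG only match when inductor falls back to eager
    # random; bisection should land on inductor_fallback_random.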
    def test_rng(self):
        def foo():
            return torch.rand([10], device="cuda") + 1

        def test_fn():
            torch._dynamo.reset()

            with preserve_rng_state():
                out = foo()
            with preserve_rng_state():
                out_c = torch.compile(foo)()

            return torch.allclose(out, out_c)

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "inductor_fallback_random")
        self.assertTrue("inductor_fallback_random" in out.debug_info)
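
    # Register a custom op whose meta function returns a view while the real
    # kernel clones; the crossref backend should flag the mismatch.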
    def test_crossref(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define("foo(Tensor x) -> Tensor")
            op = self.get_op("foo")

            class Foo(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                    with torch._C._AutoDispatchBelowAutograd():
                        with torch._C._ExcludeDispatchKeyGuard(
                            torch._C.DispatchKeySet(
                                torch._C.DispatchKey.ADInplaceOrView
                            )
                        ):
                            return op(x)

                @staticmethod
                def backward(ctx, gx):
                    return gx

            def foo_impl(x):
                return x.view_as(x).clone()

            def foo_meta(x):
                return x.view_as(x)

            lib.impl("foo", Foo.apply, "Autograd")
            lib.impl("foo", foo_impl, "CPU")
            lib.impl("foo", foo_meta, "Meta")

            x = torch.tensor(3.14159 / 3, requires_grad=True)

            def test_fn():
                torch._dynamo.reset()

                try:
                    torch.testing.assert_close(torch.compile(op)(x), op(x))
                except Exception:
                    return False
                return True

            out = CompilerBisector.do_bisect(test_fn)
            self.assertEqual(out.backend, "aot_eager_decomp_partition_crossref")
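
    # The bf16 scale computation can differ from eager bit-for-bit unless
    # inductor emulates eager's precision casts; bisection should land on
    # inductor_emulate_precision_casts.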
    def test_emulate_precision_casts(self):
        def test_fn():
            torch._dynamo.reset()

            def calculate_scale(inp):
                amax = torch.abs(torch.max(inp))
                scale = 448.0 / torch.clamp(amax, min=1e-12)
                scale = scale.to(torch.float32)
                return scale

            dtype = torch.bfloat16
            torch.manual_seed(0)
            inp = torch.randn(16, 16, 768, dtype=dtype, device="cuda")
            eager_scale = calculate_scale(inp)
            compile_scale = torch.compile(calculate_scale)(inp)

            return torch.equal(eager_scale, compile_scale)

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "inductor_emulate_precision_casts")
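
    # triton.inject_relu_bug_TESTING_ONLY makes the relu lowering produce
    # wrong results; bisection should land on lowering number 2.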
    def test_bad_lowering(self):
        def test_fn():
            torch._dynamo.reset()
            with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):

                def my_func(x):
                    return ((x * -1) - 0.01).relu()

                inp = torch.rand([100], device="cuda")

                return torch.allclose(torch.compile(my_func)(inp), my_func(inp))

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "lowerings")
        self.assertEqual(out.bisect_number, 2)
        self.assertTrue("relu" in out.debug_info)

    def test_eager_backend(self):
        # should indicate problem with first backend
        def test_fn():
            return False

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "eager")
        self.assertEqual(out.subsystem, None)
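
    # Exercises the two test_configs knobs from the PR description: MyBackend
    # emulates a vllm-style backend that splits the graph and
    # standalone_compiles each split, and the bisector keeps that backend
    # while bisecting the pre-grad graph.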
    @config.patch(
        {
            "test_configs.bisect_pre_grad_graph": True,
            "test_configs.bisect_keep_custom_backend_for_inductor": True,
        }
    )
    def test_bisect_pre_grad_graph(self):
        def f(x):
            for i in range(5):
                x = x + 1
            return x.relu()

        class MyBackend:
            def __call__(self, gm, example_inputs):
                node_idx = 0

                def node_to_graph_id(node):
                    nonlocal node_idx
                    out = 0 if node_idx < 3 else 1
                    node_idx += 1
                    return out

                split_gm = torch.fx.passes.split_module.split_module(
                    gm, None, node_to_graph_id, keep_original_order=True
                )

                for name, submod in split_gm.named_modules():
                    if "submod_" in name:
                        # the test case is simple enough that the original
                        # example_inputs also work for each submodule
                        submod.forward = torch._inductor.standalone_compile(
                            submod,
                            example_inputs,
                            dynamic_shapes="from_example_inputs",
                            options={},
                        )

                return split_gm

        def test_fn():
            torch._dynamo.reset()

            x = torch.randn(1024, device="cuda")
            with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):
                opt_f = torch.compile(f, backend=MyBackend())
                return torch.allclose(opt_f(x), f(x))

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "pre_grad_graph")
        self.assertEqual(out.bisect_number, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()