From 4911b80b8ebda4fc5053b95f22bae2232af52abb Mon Sep 17 00:00:00 2001
From: Adnan Akhundov
Date: Mon, 10 Jul 2023 02:20:23 -0700
Subject: [PATCH] [inductor] addmm + ReLU / GELU fusion pass (#104132)

Summary:
Add a new path in `post_grad.py` for replacing addmm + ReLU / GELU
activation with the corresponding `_addmm_activation` call (with
`use_gelu=False` or `True`, respectively). The replacement is done only
when `max_autotune_gemm=False` and the activation is fusible.

Test Plan:

$ python test/inductor/test_pattern_matcher.py -k test_addmm_activation -v
(__main__.TestPaternMatcher.test_addmm_activation) ...
/data/users/aakhundov/pytorch/torch/_inductor/compile_fx.py:128: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Using FallbackKernel: aten._addmm_activation.default
Using FallbackKernel: aten._addmm_activation.default
/data/users/aakhundov/pytorch/torch/_dynamo/eval_frame.py:373: UserWarning: changing options to `torch.compile()` may require calling `torch._dynamo.reset()` to take effect
  warnings.warn(
frames [('total', 1), ('ok', 1)]
stats [('calls_captured', 2), ('unique_graphs', 1)]
aot_autograd [('total', 1), ('ok', 1)]
inductor []
ok

----------------------------------------------------------------------
Ran 1 test in 13.415s

OK

Reviewers: @eellison

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104132
Approved by: https://github.com/eellison, https://github.com/jansel
---
 ...DecompTest.test_aten_core_operators.expect |  2 +
 ...asDecompTest.test_has_decomposition.expect |  2 -
 test/inductor/test_pattern_matcher.py         | 48 ++++++++++++
 test/inductor/test_select_algorithm.py        |  2 +-
 torch/_decomp/decompositions.py               | 20 +++++
 torch/_inductor/compile_fx.py                 |  4 +-
 torch/_inductor/fx_passes/fuse_attention.py   |  7 --
 torch/_inductor/fx_passes/pad_mm.py           | 10 ---
 torch/_inductor/fx_passes/post_grad.py        | 76 ++++++++++++++++++-
 torch/_inductor/lowering.py                   |  1 +
 torch/_inductor/pattern_matcher.py            | 25 ++++--
 11 files changed, 164 insertions(+), 33 deletions(-)

diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index 89e54a184a9..41ec24da9c5 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -4,6 +4,8 @@ aten::__irshift__.Scalar
 aten::__irshift__.Tensor
 aten::_adaptive_avg_pool2d
 aten::_adaptive_avg_pool2d.out
+aten::_addmm_activation
+aten::_addmm_activation.out
 aten::_euclidean_dist.out
 aten::_fused_dropout
 aten::_fused_dropout.out
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index dd32e1a9ef2..e0a0467c1d1 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -18,8 +18,6 @@ aten::_add_relu.Tensor
 aten::_add_relu.out
 aten::_add_relu_.Scalar
 aten::_add_relu_.Tensor
-aten::_addmm_activation
-aten::_addmm_activation.out
 aten::_aminmax
 aten::_aminmax.dim
 aten::_aminmax.dim_out
diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index aef07ea3571..21e5fdfc4a2 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -85,6 +85,54 @@ class TestPaternMatcher(TestCase):
         self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
         self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
 
+    def test_addmm_activation(self):
+        def fn_addmm_relu(input, mat1, mat2):
+            return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))
+
+        def fn_addmm_gelu(input, mat1, mat2):
+            return torch.nn.functional.gelu(torch.addmm(input, mat1, mat2))
+
+        args = [
+            torch.randn(20, device="cuda"),  # input
+            torch.randn(10, 15, device="cuda"),  # mat1
+            torch.randn(15, 20, device="cuda"),  # mat2
+        ]
+
+        for fn, atol in (
+            (fn_addmm_relu, 1e-8),
+            # higher tolerance due to the "tanh" approximation
+            # in fused GELU epilogue vs. "none" without fusion
+            (fn_addmm_gelu, 1e-3),
+        ):
+            expected = fn(*args)
+            actual, (code,) = run_and_get_code(torch.compile(fn), *args)
+            torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
+            self.assertTrue("_addmm_activation" in code)
+
+        for fn in (fn_addmm_relu, fn_addmm_gelu):
+            counters.clear()
+            torch.compile(
+                fn,
+                # replacement disabled on max_autotune_gemm
+                options={"max_autotune_gemm": True},
+            )(*args)
+            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
+            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
+
+        args_not_replaced = [
+            # addmm + activation with a rank-2 input
+            # is not fusible, hence not replaced
+            torch.randn(10, 20, device="cuda"),  # input
+            torch.randn(10, 15, device="cuda"),  # mat1
+            torch.randn(15, 20, device="cuda"),  # mat2
+        ]
+
+        for fn in (fn_addmm_relu, fn_addmm_gelu):
+            counters.clear()
+            torch.compile(fn)(*args_not_replaced)
+            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
+            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
+
     def test_cat_mm(self):
         def fn(a, b, c):
             return torch.cat(
diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py
index 56faaf6630c..e66005c5a72 100644
--- a/test/inductor/test_select_algorithm.py
+++ b/test/inductor/test_select_algorithm.py
@@ -62,7 +62,7 @@ class TestSelectAlgorithm(TestCase):
         foo(
             torch.randn(64, 32, device="cuda"),
             torch.randn(16, 32, device="cuda"),
-            torch.randn(16, device="cuda"),
+            torch.randn(1, 16, device="cuda"),
         )
         # Autotuning checks correctness of each version
         self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index eeb38bcfad7..ad94269a92d 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1189,6 +1189,26 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
     return out + beta * self
 
 
+@register_decomposition(aten._addmm_activation)
+@out_wrapper()
+@pw_cast_for_opmath
+def _addmm_activation(
+    self: Tensor,
+    mat1: Tensor,
+    mat2: Tensor,
+    beta: int = 1,
+    alpha: int = 1,
+    use_gelu: bool = False,
+):
+    out = addmm(self, mat1, mat2, beta, alpha)
+    if use_gelu:
+        if self.is_cuda:
+            return aten.gelu(out, approximate="tanh")
+        else:
+            return aten.gelu(out)
+    return aten.relu(out)
+
+
 @register_decomposition(aten.addmv)
 @out_wrapper()
 @pw_cast_for_opmath
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 7f495460743..a35b4e1bff9 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -479,9 +479,7 @@ def fx_codegen_and_compile(
 
     with V.set_fake_mode(fake_mode):
         # has some issues with memory in training
-        locality_reorder = is_inference and config.reordering
-
-        post_grad_passes(gm, locality_reorder=locality_reorder)
+        post_grad_passes(gm, is_inference=is_inference)
 
     V.debug.fx_graph_transformed(gm, example_inputs)
 
     with V.set_fake_mode(fake_mode):
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index afa511315bf..05712e0d91b 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -301,9 +301,6 @@ def _sfdp_scale_factor_check(scale_factor_op):
 
 @functools.lru_cache(None)
 def _sfdp_init():
-    from ..._dynamo.utils import counters
-
-    counters_ref = counters["inductor"].copy()
     from .joint_graph import patterns
 
     if torch.cuda.is_available():
@@ -415,7 +412,3 @@ def _sfdp_init():
             extra_check=extra_check,
             scalar_workaround=workaround,
         )
-
-    counters[
-        "inductor"
-    ] = counters_ref  # clear view matches encountered during sdpa tracing
diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py
index a4611f0a087..9ad33a7c60c 100644
--- a/torch/_inductor/fx_passes/pad_mm.py
+++ b/torch/_inductor/fx_passes/pad_mm.py
@@ -388,10 +388,6 @@ def pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
 
 @functools.lru_cache(None)
 def _pad_mm_init():
-    from ..._dynamo.utils import counters
-
-    counters_ref = counters["inductor"].copy()
-
     from .joint_graph import patterns
 
     if torch.cuda.is_available():
@@ -415,8 +411,6 @@ def _pad_mm_init():
     # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
     rep = {"beta": 0.213377, "alpha": 0.113377}
 
-    counters_ref = counters["inductor"].copy()
-
     for pattern, replacement, args, workaround, extra_check in [
         (
             mm_pattern,
@@ -459,7 +453,3 @@ def _pad_mm_init():
             extra_check=extra_check,
             scalar_workaround=workaround,
         )
-
-    counters[
-        "inductor"
-    ] = counters_ref  # clear view matches encountered during mm tracing
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 2acd2c3e0ba..ac548b74939 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -5,6 +5,7 @@ import operator
 
 import torch
 import torch._inductor as inductor
+
 from .. import config, ir, pattern_matcher
 
 from ..lowering import lowerings as L
@@ -15,6 +16,7 @@ from ..pattern_matcher import (
     filter_nodes,
     get_arg_value,
     Ignored,
+    inference_graph,
     init_once_fakemode,
     KeywordArg,
     ListOf,
@@ -22,6 +24,7 @@ from ..pattern_matcher import (
     MULTIPLE,
     PatternMatcherPass,
     register_graph_pattern,
+    register_replacement,
     stable_topological_sort,
 )
 from ..virtualized import V
@@ -37,9 +40,11 @@ pass_patterns = [
     PatternMatcherPass(),
     PatternMatcherPass(),
 ]
+# patterns applied only in inference
+inference_patterns = PatternMatcherPass()
 
 
-def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
+def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     """
     Passes that run on after grad.  This is called once on the forwards graph
     and once on the backwards graph.
@@ -50,7 +55,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
         # has some issues with mutation in inference mode
         gm.graph.eliminate_dead_code()
 
-    if locality_reorder:
+    if is_inference and config.reordering:
         reorder_for_locality(gm.graph)
 
     if config.pattern_matcher:
@@ -58,6 +63,8 @@ def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
 
         for patterns in pass_patterns:
             patterns.apply(gm.graph)
+        if is_inference:
+            inference_patterns.apply(gm.graph)
 
     stable_topological_sort(gm.graph)
     gm.recompile()
@@ -74,6 +81,7 @@ def lazy_init():
     from .quantization import register_quantization_lowerings
 
     register_quantization_lowerings()
+    register_addmm_activation_replacement()
 
 
 def reorder_for_locality(graph: torch.fx.Graph):
@@ -344,6 +352,70 @@ def addmm(match, mat1, mat2, inp):
     return L[aten.add](inp, L[aten.mm](mat1, mat2))
 
 
+def addmm_relu_pattern(input, mat1, mat2):
+    output = aten.addmm(input, mat1, mat2)
+    return aten.relu(output)
+
+
+def addmm_relu_replacement(input, mat1, mat2):
+    return aten._addmm_activation(input, mat1, mat2, use_gelu=False)
+
+
+def addmm_gelu_pattern(input, mat1, mat2):
+    output = aten.addmm(input, mat1, mat2)
+    return aten.gelu(output)
+
+
+def addmm_gelu_replacement(input, mat1, mat2):
+    return aten._addmm_activation(input, mat1, mat2, use_gelu=True)
+
+
+def should_replace_addmm_activation(match):
+    if config.max_autotune_gemm:
+        # keep addmm for tuning
+        return False
+
+    input = match.kwargs["input"].meta["val"]
+    # conditions for epilogue fusion in _addmm_activation
+    return input.is_cuda and input.dim() == 1 and input.is_contiguous()
+
+
+def register_addmm_activation_replacement():
+    if torch.cuda.is_available():
+        # workaround https://github.com/pytorch/pytorch/issues/97894
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # sizes/values don't actually matter for initial trace
+    # once we get a possible match we re-trace with the actual values and verify the match still holds
+
+    inp = functools.partial(torch.empty, (5,), device=device)
+    mat1 = functools.partial(torch.empty, (3, 4), device=device)
+    mat2 = functools.partial(torch.empty, (4, 5), device=device)
+
+    for pattern, replacement, args in [
+        (
+            addmm_relu_pattern,
+            addmm_relu_replacement,
+            [inp(), mat1(), mat2()],
+        ),
+        (
+            addmm_gelu_pattern,
+            addmm_gelu_replacement,
+            [inp(), mat1(), mat2()],
+        ),
+    ]:
+        register_replacement(
+            pattern,
+            replacement,
+            args,
+            inference_graph,
+            inference_patterns,
+            extra_check=should_replace_addmm_activation,
+        )
+
+
 def is_valid_splitwithsizes_cat(match):
     split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
     cat_nodes = filter_nodes(match.nodes, aten.cat)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 467861414a2..9dfcaa38d08 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1670,6 +1670,7 @@ make_fallback(aten.adaptive_max_pool2d)
 make_fallback(aten.adaptive_max_pool3d)
 make_fallback(aten.addbmm)
 make_fallback(aten.addmv, warn=False)
+make_fallback(aten._addmm_activation, warn=False)
 make_fallback(aten.avg_pool3d)
 make_fallback(aten.block_diag)
 make_fallback(aten._cdist_forward)
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 2d0ed209272..409ad112c3f 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -17,6 +17,7 @@ from torch._prims_common import is_integer_dtype
 from torch.fx import Node
 from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
 from torch.fx.immutable_collections import immutable_dict, immutable_list
+
 from .._functorch import config as functorch_config
 from .._functorch.aot_autograd import aot_function, make_boxed_func
 from .._functorch.partitioners import default_partition
@@ -726,13 +727,14 @@ def register_replacement(
         if grad and is_integer_dtype(args[i].dtype):
             return False
 
-        args[i] = torch.empty_strided(
-            args[i].size(),
-            args[i].stride(),
-            dtype=args[i].dtype,
-            device=args[i].device,
-            requires_grad=grad,
-        )
+        with torch._dynamo.utils.detect_fake_mode(args):
+            args[i] = torch.empty_strided(
+                args[i].size(),
+                args[i].stride(),
+                dtype=args[i].dtype,
+                device=args[i].device,
+                requires_grad=grad,
+            )
     specific_graph = trace_fn(search_fn, args)
     specific_pattern = fx_to_pattern(
         specific_graph, argnames=argnames, exclusive_arg_names=exclusive_arg_names
@@ -1070,10 +1072,17 @@ def init_once_fakemode(fn):
     @functools.lru_cache(None)
     @functools.wraps(fn)
     def lazy_init():
+        counters_ref = counters["inductor"].copy()
+
         with torch._guards.tracing(
             None
         ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
-            return fn()
+            result = fn()
+
+        # clear view matches encountered during tracing
+        counters["inductor"] = counters_ref
+
+        return result
 
     return lazy_init
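
Note for local verification: the snippet below is a minimal sketch condensed from
test_addmm_activation above, not part of the patch. It assumes a CUDA device with
this change applied, the default inductor config (max_autotune_gemm=False), and
that run_and_get_code (the helper the test file uses) is importable from
torch._inductor.utils.

import torch
from torch._inductor.utils import run_and_get_code  # import path assumed

def fn(inp, mat1, mat2):
    # addmm + ReLU, the shape the new post-grad pattern matches
    return torch.nn.functional.relu(torch.addmm(inp, mat1, mat2))

args = [
    torch.randn(20, device="cuda"),      # rank-1, contiguous input -> fusible epilogue
    torch.randn(10, 15, device="cuda"),  # mat1
    torch.randn(15, 20, device="cuda"),  # mat2
]

# With the default config, the pattern matcher should rewrite addmm + relu
# into aten._addmm_activation(..., use_gelu=False) in the generated code.
_, (code,) = run_and_get_code(torch.compile(fn), *args)
assert "_addmm_activation" in code

Passing a rank-2 input (e.g. torch.randn(10, 20, device="cuda")) instead should
leave the graph unreplaced, as exercised by args_not_replaced in the test above.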