[inductor] addmm + ReLU / GELU fusion pass (#104132)

Summary:

Add a new pattern-replacement path in `post_grad.py` that rewrites addmm followed by a ReLU / GELU activation into the corresponding `_addmm_activation` call (with `use_gelu=False` or `True`, respectively). The replacement is applied only when `max_autotune_gemm=False` and the epilogue is actually fusible by `_addmm_activation` (CUDA, rank-1 contiguous bias).
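
For illustration, a minimal sketch of the rewrite at the op level (shapes are arbitrary; `torch.ops.aten._addmm_activation` is the fused ATen op the pattern is replaced with, and a CUDA device is assumed since that is where the epilogue fusion applies):

import torch

aten = torch.ops.aten

inp = torch.randn(20, device="cuda")       # rank-1, contiguous bias
mat1 = torch.randn(10, 15, device="cuda")
mat2 = torch.randn(15, 20, device="cuda")

# What the pattern matches in the post-grad graph ...
eager = torch.relu(torch.addmm(inp, mat1, mat2))

# ... and what it is rewritten to (use_gelu=True for the GELU variant).
fused = aten._addmm_activation(inp, mat1, mat2, use_gelu=False)

torch.testing.assert_close(eager, fused)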

Test Plan:

$ python test/inductor/test_pattern_matcher.py -k test_addmm_activation -v

(__main__.TestPaternMatcher.test_addmm_activation) ... /data/users/aakhundov/pytorch/torch/_inductor/compile_fx.py:128: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Using FallbackKernel: aten._addmm_activation.default
Using FallbackKernel: aten._addmm_activation.default
/data/users/aakhundov/pytorch/torch/_dynamo/eval_frame.py:373: UserWarning: changing options to `torch.compile()` may require calling `torch._dynamo.reset()` to take effect
  warnings.warn(
frames [('total', 1), ('ok', 1)]
stats [('calls_captured', 2), ('unique_graphs', 1)]
aot_autograd [('total', 1), ('ok', 1)]
inductor []
ok

----------------------------------------------------------------------
Ran 1 test in 13.415s

OK

Reviewers: @eellison

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104132
Approved by: https://github.com/eellison, https://github.com/jansel
Author: Adnan Akhundov (2023-07-10 02:20:23 -07:00), committed by PyTorch MergeBot
parent 7166df8094
commit 4911b80b8e
11 changed files with 164 additions and 33 deletions

View File

@@ -4,6 +4,8 @@ aten::__irshift__.Scalar
aten::__irshift__.Tensor
aten::_adaptive_avg_pool2d
aten::_adaptive_avg_pool2d.out
aten::_addmm_activation
aten::_addmm_activation.out
aten::_euclidean_dist.out
aten::_fused_dropout
aten::_fused_dropout.out

View File

@@ -18,8 +18,6 @@ aten::_add_relu.Tensor
aten::_add_relu.out
aten::_add_relu_.Scalar
aten::_add_relu_.Tensor
aten::_addmm_activation
aten::_addmm_activation.out
aten::_aminmax
aten::_aminmax.dim
aten::_aminmax.dim_out

View File

@@ -85,6 +85,54 @@ class TestPaternMatcher(TestCase):
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

    def test_addmm_activation(self):
        def fn_addmm_relu(input, mat1, mat2):
            return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))

        def fn_addmm_gelu(input, mat1, mat2):
            return torch.nn.functional.gelu(torch.addmm(input, mat1, mat2))

        args = [
            torch.randn(20, device="cuda"),  # input
            torch.randn(10, 15, device="cuda"),  # mat1
            torch.randn(15, 20, device="cuda"),  # mat2
        ]

        for fn, atol in (
            (fn_addmm_relu, 1e-8),
            # higher tolerance due to the "tanh" approximation
            # in fused GELU epilogue vs. "none" without fusion
            (fn_addmm_gelu, 1e-3),
        ):
            expected = fn(*args)
            actual, (code,) = run_and_get_code(torch.compile(fn), *args)
            torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
            self.assertTrue("_addmm_activation" in code)

        for fn in (fn_addmm_relu, fn_addmm_gelu):
            counters.clear()
            torch.compile(
                fn,
                # replacement disabled on max_autotune_gemm
                options={"max_autotune_gemm": True},
            )(*args)
            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)

        args_not_replaced = [
            # addmm + activation with a rank-2 input
            # is not fusable, hence not replaced
            torch.randn(10, 20, device="cuda"),  # input
            torch.randn(10, 15, device="cuda"),  # mat1
            torch.randn(15, 20, device="cuda"),  # mat2
        ]
        for fn in (fn_addmm_relu, fn_addmm_gelu):
            counters.clear()
            torch.compile(fn)(*args_not_replaced)
            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)

    def test_cat_mm(self):
        def fn(a, b, c):
            return torch.cat(
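
Outside the test harness above, a quick way to check whether the replacement kicked in for a given function is to inspect the generated code, as the new test does (a sketch; `fn` and the inputs are placeholders):

import torch
from torch._inductor.utils import run_and_get_code

def fn(inp, mat1, mat2):
    return torch.nn.functional.relu(torch.addmm(inp, mat1, mat2))

args = [
    torch.randn(20, device="cuda"),
    torch.randn(10, 15, device="cuda"),
    torch.randn(15, 20, device="cuda"),
]

_, (code,) = run_and_get_code(torch.compile(fn), *args)
# True when addmm + ReLU was rewritten to the fused op; False e.g. with
# max_autotune_gemm enabled or a rank-2 bias.
print("_addmm_activation" in code)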

View File

@@ -62,7 +62,7 @@ class TestSelectAlgorithm(TestCase):
        foo(
            torch.randn(64, 32, device="cuda"),
            torch.randn(16, 32, device="cuda"),
            torch.randn(16, device="cuda"),
            torch.randn(1, 16, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)

View File

@@ -1189,6 +1189,26 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
    return out + beta * self


@register_decomposition(aten._addmm_activation)
@out_wrapper()
@pw_cast_for_opmath
def _addmm_activation(
    self: Tensor,
    mat1: Tensor,
    mat2: Tensor,
    beta: int = 1,
    alpha: int = 1,
    use_gelu: bool = False,
):
    out = addmm(self, mat1, mat2, beta, alpha)
    if use_gelu:
        if self.is_cuda:
            return aten.gelu(out, approximate="tanh")
        else:
            return aten.gelu(out)
    return aten.relu(out)


@register_decomposition(aten.addmv)
@out_wrapper()
@pw_cast_for_opmath
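
Note that the GELU branch of the `_addmm_activation` decomposition above uses the "tanh" approximation on CUDA, matching the fused epilogue, while the unfused `gelu(addmm(...))` defaults to the exact ("none") formula. A small sketch of the resulting numeric gap, which is why the test above uses the looser `atol=1e-3` for the GELU case:

import torch

x = torch.randn(10, 20)
exact = torch.nn.functional.gelu(x)                       # approximate="none" (default)
approx = torch.nn.functional.gelu(x, approximate="tanh")  # what the fused epilogue computes
# Small but well above float32 round-off, hence the looser test tolerance.
print((exact - approx).abs().max())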

View File

@@ -479,9 +479,7 @@ def fx_codegen_and_compile(
    with V.set_fake_mode(fake_mode):
        # has some issues with memory in training
        locality_reorder = is_inference and config.reordering
        post_grad_passes(gm, locality_reorder=locality_reorder)
        post_grad_passes(gm, is_inference=is_inference)
        V.debug.fx_graph_transformed(gm, example_inputs)

    with V.set_fake_mode(fake_mode):

View File

@@ -301,9 +301,6 @@ def _sfdp_scale_factor_check(scale_factor_op):
@functools.lru_cache(None)
def _sfdp_init():
    from ..._dynamo.utils import counters

    counters_ref = counters["inductor"].copy()

    from .joint_graph import patterns

    if torch.cuda.is_available():
@@ -415,7 +412,3 @@ def _sfdp_init():
            extra_check=extra_check,
            scalar_workaround=workaround,
        )

    counters[
        "inductor"
    ] = counters_ref  # clear view matches encountered during sdpa tracing

View File

@@ -388,10 +388,6 @@ def pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
@functools.lru_cache(None)
def _pad_mm_init():
    from ..._dynamo.utils import counters

    counters_ref = counters["inductor"].copy()

    from .joint_graph import patterns

    if torch.cuda.is_available():
@@ -415,8 +411,6 @@ def _pad_mm_init():
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    rep = {"beta": 0.213377, "alpha": 0.113377}

    counters_ref = counters["inductor"].copy()

    for pattern, replacement, args, workaround, extra_check in [
        (
            mm_pattern,
@@ -459,7 +453,3 @@ def _pad_mm_init():
            extra_check=extra_check,
            scalar_workaround=workaround,
        )

    counters[
        "inductor"
    ] = counters_ref  # clear view matches encountered during mm tracing

View File

@@ -5,6 +5,7 @@ import operator

import torch
import torch._inductor as inductor

from .. import config, ir, pattern_matcher
from ..lowering import lowerings as L
@@ -15,6 +16,7 @@ from ..pattern_matcher import (
    filter_nodes,
    get_arg_value,
    Ignored,
    inference_graph,
    init_once_fakemode,
    KeywordArg,
    ListOf,
@@ -22,6 +24,7 @@ from ..pattern_matcher import (
    MULTIPLE,
    PatternMatcherPass,
    register_graph_pattern,
    register_replacement,
    stable_topological_sort,
)
from ..virtualized import V
@@ -37,9 +40,11 @@ pass_patterns = [
    PatternMatcherPass(),
    PatternMatcherPass(),
]
# patterns applied only in inference
inference_patterns = PatternMatcherPass()


def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
    """
    Passes that run on after grad. This is called once on the forwards
    graph and once on the backwards graph.
@@ -50,7 +55,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
        # has some issues with mutation in inference mode
        gm.graph.eliminate_dead_code()

    if locality_reorder:
    if is_inference and config.reordering:
        reorder_for_locality(gm.graph)

    if config.pattern_matcher:
@@ -58,6 +63,8 @@ def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
        for patterns in pass_patterns:
            patterns.apply(gm.graph)
        if is_inference:
            inference_patterns.apply(gm.graph)

    stable_topological_sort(gm.graph)
    gm.recompile()
@@ -74,6 +81,7 @@ def lazy_init():
    from .quantization import register_quantization_lowerings

    register_quantization_lowerings()
    register_addmm_activation_replacement()


def reorder_for_locality(graph: torch.fx.Graph):
@@ -344,6 +352,70 @@ def addmm(match, mat1, mat2, inp):
    return L[aten.add](inp, L[aten.mm](mat1, mat2))


def addmm_relu_pattern(input, mat1, mat2):
    output = aten.addmm(input, mat1, mat2)
    return aten.relu(output)


def addmm_relu_replacement(input, mat1, mat2):
    return aten._addmm_activation(input, mat1, mat2, use_gelu=False)


def addmm_gelu_pattern(input, mat1, mat2):
    output = aten.addmm(input, mat1, mat2)
    return aten.gelu(output)


def addmm_gelu_replacement(input, mat1, mat2):
    return aten._addmm_activation(input, mat1, mat2, use_gelu=True)


def should_replace_addmm_activation(match):
    if config.max_autotune_gemm:
        # keep addmm for tuning
        return False
    input = match.kwargs["input"].meta["val"]
    # conditions of epilogue fusion in _addmm_activation
    return input.is_cuda and input.dim() == 1 and input.is_contiguous()


def register_addmm_activation_replacement():
    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # sizes/values dont actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds
    inp = functools.partial(torch.empty, (5,), device=device)
    mat1 = functools.partial(torch.empty, (3, 4), device=device)
    mat2 = functools.partial(torch.empty, (4, 5), device=device)

    for pattern, replacement, args in [
        (
            addmm_relu_pattern,
            addmm_relu_replacement,
            [inp(), mat1(), mat2()],
        ),
        (
            addmm_gelu_pattern,
            addmm_gelu_replacement,
            [inp(), mat1(), mat2()],
        ),
    ]:
        register_replacement(
            pattern,
            replacement,
            args,
            inference_graph,
            inference_patterns,
            extra_check=should_replace_addmm_activation,
        )


def is_valid_splitwithsizes_cat(match):
    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
    cat_nodes = filter_nodes(match.nodes, aten.cat)
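
The `should_replace_addmm_activation` check above gates the rewrite on `max_autotune_gemm` being off and on the conditions under which `_addmm_activation` can fuse the epilogue. A standalone paraphrase of the shape/device part of that check, to show which call sites qualify (illustrative only, not the matcher API):

import torch

def bias_is_fusible(bias: torch.Tensor) -> bool:
    # Mirrors the conditions in should_replace_addmm_activation.
    return bias.is_cuda and bias.dim() == 1 and bias.is_contiguous()

print(bias_is_fusible(torch.randn(20, device="cuda")))      # True: replaced
print(bias_is_fusible(torch.randn(10, 20, device="cuda")))  # False: rank-2 bias, left as addmm + activation
print(bias_is_fusible(torch.randn(20)))                     # False: CPU tensor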

View File

@@ -1670,6 +1670,7 @@ make_fallback(aten.adaptive_max_pool2d)
make_fallback(aten.adaptive_max_pool3d)
make_fallback(aten.addbmm)
make_fallback(aten.addmv, warn=False)
make_fallback(aten._addmm_activation, warn=False)
make_fallback(aten.avg_pool3d)
make_fallback(aten.block_diag)
make_fallback(aten._cdist_forward)

View File

@@ -17,6 +17,7 @@ from torch._prims_common import is_integer_dtype
from torch.fx import Node
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.immutable_collections import immutable_dict, immutable_list

from .._functorch import config as functorch_config
from .._functorch.aot_autograd import aot_function, make_boxed_func
from .._functorch.partitioners import default_partition
@@ -726,6 +727,7 @@ def register_replacement(
            if grad and is_integer_dtype(args[i].dtype):
                return False

            with torch._dynamo.utils.detect_fake_mode(args):
                args[i] = torch.empty_strided(
                    args[i].size(),
                    args[i].stride(),
@@ -1070,10 +1072,17 @@ def init_once_fakemode(fn):
    @functools.lru_cache(None)
    @functools.wraps(fn)
    def lazy_init():
        counters_ref = counters["inductor"].copy()

        with torch._guards.tracing(
            None
        ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
            return fn()
            result = fn()

        # clear view matches encountered during tracing
        counters["inductor"] = counters_ref

        return result

    return lazy_init
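
With this change, the counters bookkeeping that `_sfdp_init` and `_pad_mm_init` used to do by hand lives in `init_once_fakemode`, so every one-time pattern-registration function gets it automatically. A minimal sketch of the save/restore idiom the decorator applies (generic Python; the real decorator additionally enters the fake-tensor/tracing context shown above):

import collections
import functools

# Stand-in for torch._dynamo.utils.counters.
counters = collections.defaultdict(collections.Counter)

def init_once(fn):
    @functools.lru_cache(None)
    @functools.wraps(fn)
    def lazy_init():
        ref = counters["inductor"].copy()
        result = fn()
        # Drop spurious matches counted while tracing the pattern graphs themselves.
        counters["inductor"] = ref
        return result

    return lazy_init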