[inductor] addmm + ReLU / GELU fusion pass (#104132)

Summary:

Add a new inference-only pattern replacement in `post_grad.py` that rewrites addmm followed by a ReLU / GELU activation into the corresponding `_addmm_activation` call (with `use_gelu=False` or `True`, respectively). The replacement is applied only when `max_autotune_gemm=False` and when the activation is fusible by `_addmm_activation`'s epilogue, i.e. the bias `input` is a contiguous rank-1 CUDA tensor.
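
For context, here is a minimal repro of the intended user-visible effect; it is a sketch adapted from the new `test_addmm_activation` test added below, and assumes a CUDA device:

    import torch
    from torch._inductor.utils import run_and_get_code

    def fn_addmm_relu(inp, mat1, mat2):
        return torch.nn.functional.relu(torch.addmm(inp, mat1, mat2))

    args = (
        torch.randn(20, device="cuda"),      # rank-1, contiguous bias -> fusible
        torch.randn(10, 15, device="cuda"),  # mat1
        torch.randn(15, 20, device="cuda"),  # mat2
    )

    expected = fn_addmm_relu(*args)
    actual, (code,) = run_and_get_code(torch.compile(fn_addmm_relu), *args)

    # With max_autotune_gemm=False (the default), the generated code is expected
    # to call the fused aten._addmm_activation op instead of separate addmm + relu.
    assert "_addmm_activation" in code
    torch.testing.assert_close(actual, expected, atol=1e-8, rtol=0)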

Test Plan:

$ python test/inductor/test_pattern_matcher.py -k test_addmm_activation -v

(__main__.TestPaternMatcher.test_addmm_activation) ... /data/users/aakhundov/pytorch/torch/_inductor/compile_fx.py:128: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Using FallbackKernel: aten._addmm_activation.default
Using FallbackKernel: aten._addmm_activation.default
/data/users/aakhundov/pytorch/torch/_dynamo/eval_frame.py:373: UserWarning: changing options to `torch.compile()` may require calling `torch._dynamo.reset()` to take effect
  warnings.warn(
frames [('total', 1), ('ok', 1)]
stats [('calls_captured', 2), ('unique_graphs', 1)]
aot_autograd [('total', 1), ('ok', 1)]
inductor []
ok

----------------------------------------------------------------------
Ran 1 test in 13.415s

OK

Reviewers: @eellison

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104132
Approved by: https://github.com/eellison, https://github.com/jansel
Adnan Akhundov 2023-07-10 02:20:23 -07:00 committed by PyTorch MergeBot
parent 7166df8094
commit 4911b80b8e
11 changed files with 164 additions and 33 deletions

View File

@@ -4,6 +4,8 @@ aten::__irshift__.Scalar
 aten::__irshift__.Tensor
 aten::_adaptive_avg_pool2d
 aten::_adaptive_avg_pool2d.out
+aten::_addmm_activation
+aten::_addmm_activation.out
 aten::_euclidean_dist.out
 aten::_fused_dropout
 aten::_fused_dropout.out

View File

@@ -18,8 +18,6 @@ aten::_add_relu.Tensor
 aten::_add_relu.out
 aten::_add_relu_.Scalar
 aten::_add_relu_.Tensor
-aten::_addmm_activation
-aten::_addmm_activation.out
 aten::_aminmax
 aten::_aminmax.dim
 aten::_aminmax.dim_out

View File

@@ -85,6 +85,54 @@ class TestPaternMatcher(TestCase):
         self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
         self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
 
+    def test_addmm_activation(self):
+        def fn_addmm_relu(input, mat1, mat2):
+            return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))
+
+        def fn_addmm_gelu(input, mat1, mat2):
+            return torch.nn.functional.gelu(torch.addmm(input, mat1, mat2))
+
+        args = [
+            torch.randn(20, device="cuda"),  # input
+            torch.randn(10, 15, device="cuda"),  # mat1
+            torch.randn(15, 20, device="cuda"),  # mat2
+        ]
+
+        for fn, atol in (
+            (fn_addmm_relu, 1e-8),
+            # higher tolerance due to the "tanh" approximation
+            # in fused GELU epilogue vs. "none" without fusion
+            (fn_addmm_gelu, 1e-3),
+        ):
+            expected = fn(*args)
+            actual, (code,) = run_and_get_code(torch.compile(fn), *args)
+            torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
+            self.assertTrue("_addmm_activation" in code)
+
+        for fn in (fn_addmm_relu, fn_addmm_gelu):
+            counters.clear()
+            torch.compile(
+                fn,
+                # replacement disabled on max_autotune_gemm
+                options={"max_autotune_gemm": True},
+            )(*args)
+            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
+            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
+
+        args_not_replaced = [
+            # addmm + activation with a rank-2 input
+            # is not fusable, hence not replaced
+            torch.randn(10, 20, device="cuda"),  # input
+            torch.randn(10, 15, device="cuda"),  # mat1
+            torch.randn(15, 20, device="cuda"),  # mat2
+        ]
+
+        for fn in (fn_addmm_relu, fn_addmm_gelu):
+            counters.clear()
+            torch.compile(fn)(*args_not_replaced)
+            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
+            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
+
     def test_cat_mm(self):
         def fn(a, b, c):
             return torch.cat(

View File

@@ -62,7 +62,7 @@ class TestSelectAlgorithm(TestCase):
         foo(
             torch.randn(64, 32, device="cuda"),
             torch.randn(16, 32, device="cuda"),
-            torch.randn(16, device="cuda"),
+            torch.randn(1, 16, device="cuda"),
         )
         # Autotuning checks correctness of each version
         self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)

View File

@@ -1189,6 +1189,26 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
     return out + beta * self
 
 
+@register_decomposition(aten._addmm_activation)
+@out_wrapper()
+@pw_cast_for_opmath
+def _addmm_activation(
+    self: Tensor,
+    mat1: Tensor,
+    mat2: Tensor,
+    beta: int = 1,
+    alpha: int = 1,
+    use_gelu: bool = False,
+):
+    out = addmm(self, mat1, mat2, beta, alpha)
+    if use_gelu:
+        if self.is_cuda:
+            return aten.gelu(out, approximate="tanh")
+        else:
+            return aten.gelu(out)
+    return aten.relu(out)
+
+
 @register_decomposition(aten.addmv)
 @out_wrapper()
 @pw_cast_for_opmath

View File

@@ -479,9 +479,7 @@ def fx_codegen_and_compile(
     with V.set_fake_mode(fake_mode):
         # has some issues with memory in training
-        locality_reorder = is_inference and config.reordering
-        post_grad_passes(gm, locality_reorder=locality_reorder)
+        post_grad_passes(gm, is_inference=is_inference)
         V.debug.fx_graph_transformed(gm, example_inputs)
 
     with V.set_fake_mode(fake_mode):

View File

@@ -301,9 +301,6 @@ def _sfdp_scale_factor_check(scale_factor_op):
 @functools.lru_cache(None)
 def _sfdp_init():
-    from ..._dynamo.utils import counters
-
-    counters_ref = counters["inductor"].copy()
 
     from .joint_graph import patterns
 
     if torch.cuda.is_available():

@@ -415,7 +412,3 @@ def _sfdp_init():
             extra_check=extra_check,
             scalar_workaround=workaround,
         )
-
-    counters[
-        "inductor"
-    ] = counters_ref  # clear view matches encountered during sdpa tracing

View File

@@ -388,10 +388,6 @@ def pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
 @functools.lru_cache(None)
 def _pad_mm_init():
-    from ..._dynamo.utils import counters
-
-    counters_ref = counters["inductor"].copy()
-
     from .joint_graph import patterns
 
     if torch.cuda.is_available():

@@ -415,8 +411,6 @@ def _pad_mm_init():
     # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
     rep = {"beta": 0.213377, "alpha": 0.113377}
 
-    counters_ref = counters["inductor"].copy()
-
     for pattern, replacement, args, workaround, extra_check in [
         (
             mm_pattern,

@@ -459,7 +453,3 @@ def _pad_mm_init():
             extra_check=extra_check,
             scalar_workaround=workaround,
         )
-
-    counters[
-        "inductor"
-    ] = counters_ref  # clear view matches encountered during mm tracing

View File

@@ -5,6 +5,7 @@ import operator
 
 import torch
 import torch._inductor as inductor
 
 from .. import config, ir, pattern_matcher
 from ..lowering import lowerings as L
@@ -15,6 +16,7 @@ from ..pattern_matcher import (
     filter_nodes,
     get_arg_value,
     Ignored,
+    inference_graph,
     init_once_fakemode,
     KeywordArg,
     ListOf,

@@ -22,6 +24,7 @@ from ..pattern_matcher import (
     MULTIPLE,
     PatternMatcherPass,
     register_graph_pattern,
+    register_replacement,
     stable_topological_sort,
 )
 from ..virtualized import V
@@ -37,9 +40,11 @@ pass_patterns = [
     PatternMatcherPass(),
     PatternMatcherPass(),
 ]
+# patterns applied only in inference
+inference_patterns = PatternMatcherPass()
 
 
-def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
+def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     """
     Passes that run on after grad. This is called once on the forwards
     graph and once on the backwards graph.

@@ -50,7 +55,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
     # has some issues with mutation in inference mode
     gm.graph.eliminate_dead_code()
 
-    if locality_reorder:
+    if is_inference and config.reordering:
         reorder_for_locality(gm.graph)
 
     if config.pattern_matcher:

@@ -58,6 +63,8 @@ def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
         for patterns in pass_patterns:
             patterns.apply(gm.graph)
+        if is_inference:
+            inference_patterns.apply(gm.graph)
 
     stable_topological_sort(gm.graph)
     gm.recompile()
@@ -74,6 +81,7 @@ def lazy_init():
     from .quantization import register_quantization_lowerings
 
     register_quantization_lowerings()
+    register_addmm_activation_replacement()
 
 
 def reorder_for_locality(graph: torch.fx.Graph):
@@ -344,6 +352,70 @@ def addmm(match, mat1, mat2, inp):
     return L[aten.add](inp, L[aten.mm](mat1, mat2))
 
 
+def addmm_relu_pattern(input, mat1, mat2):
+    output = aten.addmm(input, mat1, mat2)
+    return aten.relu(output)
+
+
+def addmm_relu_replacement(input, mat1, mat2):
+    return aten._addmm_activation(input, mat1, mat2, use_gelu=False)
+
+
+def addmm_gelu_pattern(input, mat1, mat2):
+    output = aten.addmm(input, mat1, mat2)
+    return aten.gelu(output)
+
+
+def addmm_gelu_replacement(input, mat1, mat2):
+    return aten._addmm_activation(input, mat1, mat2, use_gelu=True)
+
+
+def should_replace_addmm_activation(match):
+    if config.max_autotune_gemm:
+        # keep addmm for tuning
+        return False
+    input = match.kwargs["input"].meta["val"]
+    # conditions of epilogue fusion in _addmm_activation
+    return input.is_cuda and input.dim() == 1 and input.is_contiguous()
+
+
+def register_addmm_activation_replacement():
+    if torch.cuda.is_available():
+        # workaround https://github.com/pytorch/pytorch/issues/97894
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # sizes/values dont actually matter for initial trace
+    # once we get a possible match we re-trace with the actual values and verify the match still holds
+    inp = functools.partial(torch.empty, (5,), device=device)
+    mat1 = functools.partial(torch.empty, (3, 4), device=device)
+    mat2 = functools.partial(torch.empty, (4, 5), device=device)
+
+    for pattern, replacement, args in [
+        (
+            addmm_relu_pattern,
+            addmm_relu_replacement,
+            [inp(), mat1(), mat2()],
+        ),
+        (
+            addmm_gelu_pattern,
+            addmm_gelu_replacement,
+            [inp(), mat1(), mat2()],
+        ),
+    ]:
+        register_replacement(
+            pattern,
+            replacement,
+            args,
+            inference_graph,
+            inference_patterns,
+            extra_check=should_replace_addmm_activation,
+        )
+
+
 def is_valid_splitwithsizes_cat(match):
     split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
     cat_nodes = filter_nodes(match.nodes, aten.cat)

View File

@@ -1670,6 +1670,7 @@ make_fallback(aten.adaptive_max_pool2d)
 make_fallback(aten.adaptive_max_pool3d)
 make_fallback(aten.addbmm)
 make_fallback(aten.addmv, warn=False)
+make_fallback(aten._addmm_activation, warn=False)
 make_fallback(aten.avg_pool3d)
 make_fallback(aten.block_diag)
 make_fallback(aten._cdist_forward)

View File

@@ -17,6 +17,7 @@ from torch._prims_common import is_integer_dtype
 from torch.fx import Node
 from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
 from torch.fx.immutable_collections import immutable_dict, immutable_list
 
 from .._functorch import config as functorch_config
 from .._functorch.aot_autograd import aot_function, make_boxed_func
 from .._functorch.partitioners import default_partition
@@ -726,6 +727,7 @@ def register_replacement(
             if grad and is_integer_dtype(args[i].dtype):
                 return False
 
+            with torch._dynamo.utils.detect_fake_mode(args):
                 args[i] = torch.empty_strided(
                     args[i].size(),
                     args[i].stride(),
@@ -1070,10 +1072,17 @@ def init_once_fakemode(fn):
     @functools.lru_cache(None)
     @functools.wraps(fn)
     def lazy_init():
+        counters_ref = counters["inductor"].copy()
+
         with torch._guards.tracing(
             None
         ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
-            return fn()
+            result = fn()
+
+        # clear view matches encountered during tracing
+        counters["inductor"] = counters_ref
+
+        return result
 
     return lazy_init