[inductor] addmm + ReLU / GELU fusion pass (#104132)
Summary:
Add a new pattern-replacement path in `post_grad.py` that rewrites addmm followed by a ReLU / GELU activation into the corresponding `_addmm_activation` call (with `use_gelu=False` or `True`, respectively). The replacement is applied only when `max_autotune_gemm=False` and the addmm + activation pair is fusible (i.e., the bias input is a contiguous rank-1 CUDA tensor).
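For illustration, a minimal sketch of the user-visible effect (assumes a CUDA device and a build containing this change; it mirrors the `test_addmm_activation` test added below, reusing the same `run_and_get_code` inductor test helper):

```python
import torch
from torch._inductor.utils import run_and_get_code  # test helper, as used by the new test


def fn(bias, mat1, mat2):
    # rank-1, contiguous CUDA bias -> the addmm + relu pair is fusible
    return torch.nn.functional.relu(torch.addmm(bias, mat1, mat2))


args = (
    torch.randn(20, device="cuda"),      # bias
    torch.randn(10, 15, device="cuda"),  # mat1
    torch.randn(15, 20, device="cuda"),  # mat2
)

# With the default max_autotune_gemm=False, the post_grad pattern should rewrite
# addmm + relu into a single aten._addmm_activation call in the generated code.
_, (code,) = run_and_get_code(torch.compile(fn), *args)
assert "_addmm_activation" in code
```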
Test Plan:
$ python test/inductor/test_pattern_matcher.py -k test_addmm_activation -v
(__main__.TestPaternMatcher.test_addmm_activation) ... /data/users/aakhundov/pytorch/torch/_inductor/compile_fx.py:128: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
Using FallbackKernel: aten._addmm_activation.default
Using FallbackKernel: aten._addmm_activation.default
/data/users/aakhundov/pytorch/torch/_dynamo/eval_frame.py:373: UserWarning: changing options to `torch.compile()` may require calling `torch._dynamo.reset()` to take effect
warnings.warn(
frames [('total', 1), ('ok', 1)]
stats [('calls_captured', 2), ('unique_graphs', 1)]
aot_autograd [('total', 1), ('ok', 1)]
inductor []
ok
----------------------------------------------------------------------
Ran 1 test in 13.415s
OK
Reviewers: @eellison
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104132
Approved by: https://github.com/eellison, https://github.com/jansel
This commit is contained in:
parent 7166df8094
commit 4911b80b8e
@@ -4,6 +4,8 @@ aten::__irshift__.Scalar
 aten::__irshift__.Tensor
 aten::_adaptive_avg_pool2d
 aten::_adaptive_avg_pool2d.out
+aten::_addmm_activation
+aten::_addmm_activation.out
 aten::_euclidean_dist.out
 aten::_fused_dropout
 aten::_fused_dropout.out
@@ -18,8 +18,6 @@ aten::_add_relu.Tensor
 aten::_add_relu.out
 aten::_add_relu_.Scalar
 aten::_add_relu_.Tensor
-aten::_addmm_activation
-aten::_addmm_activation.out
 aten::_aminmax
 aten::_aminmax.dim
 aten::_aminmax.dim_out
@@ -85,6 +85,54 @@ class TestPaternMatcher(TestCase):
         self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
         self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
 
+    def test_addmm_activation(self):
+        def fn_addmm_relu(input, mat1, mat2):
+            return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))
+
+        def fn_addmm_gelu(input, mat1, mat2):
+            return torch.nn.functional.gelu(torch.addmm(input, mat1, mat2))
+
+        args = [
+            torch.randn(20, device="cuda"),  # input
+            torch.randn(10, 15, device="cuda"),  # mat1
+            torch.randn(15, 20, device="cuda"),  # mat2
+        ]
+
+        for fn, atol in (
+            (fn_addmm_relu, 1e-8),
+            # higher tolerance due to the "tanh" approximation
+            # in fused GELU epilogue vs. "none" without fusion
+            (fn_addmm_gelu, 1e-3),
+        ):
+            expected = fn(*args)
+            actual, (code,) = run_and_get_code(torch.compile(fn), *args)
+            torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
+            self.assertTrue("_addmm_activation" in code)
+
+        for fn in (fn_addmm_relu, fn_addmm_gelu):
+            counters.clear()
+            torch.compile(
+                fn,
+                # replacement disabled on max_autotune_gemm
+                options={"max_autotune_gemm": True},
+            )(*args)
+            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
+            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
+
+        args_not_replaced = [
+            # addmm + activation with a rank-2 input
+            # is not fusable, hence not replaced
+            torch.randn(10, 20, device="cuda"),  # input
+            torch.randn(10, 15, device="cuda"),  # mat1
+            torch.randn(15, 20, device="cuda"),  # mat2
+        ]
+
+        for fn in (fn_addmm_relu, fn_addmm_gelu):
+            counters.clear()
+            torch.compile(fn)(*args_not_replaced)
+            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
+            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
+
     def test_cat_mm(self):
         def fn(a, b, c):
             return torch.cat(
@@ -62,7 +62,7 @@ class TestSelectAlgorithm(TestCase):
         foo(
             torch.randn(64, 32, device="cuda"),
             torch.randn(16, 32, device="cuda"),
-            torch.randn(16, device="cuda"),
+            torch.randn(1, 16, device="cuda"),
         )
         # Autotuning checks correctness of each version
         self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
@@ -1189,6 +1189,26 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
     return out + beta * self
 
 
+@register_decomposition(aten._addmm_activation)
+@out_wrapper()
+@pw_cast_for_opmath
+def _addmm_activation(
+    self: Tensor,
+    mat1: Tensor,
+    mat2: Tensor,
+    beta: int = 1,
+    alpha: int = 1,
+    use_gelu: bool = False,
+):
+    out = addmm(self, mat1, mat2, beta, alpha)
+    if use_gelu:
+        if self.is_cuda:
+            return aten.gelu(out, approximate="tanh")
+        else:
+            return aten.gelu(out)
+    return aten.relu(out)
+
+
 @register_decomposition(aten.addmv)
 @out_wrapper()
 @pw_cast_for_opmath
@@ -479,9 +479,7 @@ def fx_codegen_and_compile(
 
     with V.set_fake_mode(fake_mode):
-        # has some issues with memory in training
-        locality_reorder = is_inference and config.reordering
 
-        post_grad_passes(gm, locality_reorder=locality_reorder)
+        post_grad_passes(gm, is_inference=is_inference)
         V.debug.fx_graph_transformed(gm, example_inputs)
 
     with V.set_fake_mode(fake_mode):
@@ -301,9 +301,6 @@ def _sfdp_scale_factor_check(scale_factor_op):
 
 @functools.lru_cache(None)
 def _sfdp_init():
-    from ..._dynamo.utils import counters
-
-    counters_ref = counters["inductor"].copy()
     from .joint_graph import patterns
 
     if torch.cuda.is_available():
@@ -415,7 +412,3 @@ def _sfdp_init():
             extra_check=extra_check,
             scalar_workaround=workaround,
         )
-
-    counters[
-        "inductor"
-    ] = counters_ref  # clear view matches encountered during sdpa tracing
@@ -388,10 +388,6 @@ def pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
 
 @functools.lru_cache(None)
 def _pad_mm_init():
-    from ..._dynamo.utils import counters
-
-    counters_ref = counters["inductor"].copy()
-
     from .joint_graph import patterns
 
     if torch.cuda.is_available():
@@ -415,8 +411,6 @@ def _pad_mm_init():
     # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
     rep = {"beta": 0.213377, "alpha": 0.113377}
 
-    counters_ref = counters["inductor"].copy()
-
     for pattern, replacement, args, workaround, extra_check in [
         (
             mm_pattern,
@@ -459,7 +453,3 @@ def _pad_mm_init():
             extra_check=extra_check,
             scalar_workaround=workaround,
         )
-
-    counters[
-        "inductor"
-    ] = counters_ref  # clear view matches encountered during mm tracing
@@ -5,6 +5,7 @@ import operator

import torch
import torch._inductor as inductor

from .. import config, ir, pattern_matcher

from ..lowering import lowerings as L
@@ -15,6 +16,7 @@ from ..pattern_matcher import (
     filter_nodes,
     get_arg_value,
     Ignored,
+    inference_graph,
     init_once_fakemode,
     KeywordArg,
     ListOf,
@@ -22,6 +24,7 @@ from ..pattern_matcher import (
     MULTIPLE,
     PatternMatcherPass,
     register_graph_pattern,
+    register_replacement,
     stable_topological_sort,
 )
 from ..virtualized import V
@@ -37,9 +40,11 @@ pass_patterns = [
     PatternMatcherPass(),
     PatternMatcherPass(),
 ]
+# patterns applied only in inference
+inference_patterns = PatternMatcherPass()
 
 
-def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
+def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     """
     Passes that run on after grad. This is called once on the forwards
     graph and once on the backwards graph.
@@ -50,7 +55,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
     # has some issues with mutation in inference mode
     gm.graph.eliminate_dead_code()
 
-    if locality_reorder:
+    if is_inference and config.reordering:
         reorder_for_locality(gm.graph)
 
     if config.pattern_matcher:
@@ -58,6 +63,8 @@ def post_grad_passes(gm: torch.fx.GraphModule, locality_reorder: bool):
 
         for patterns in pass_patterns:
             patterns.apply(gm.graph)
+        if is_inference:
+            inference_patterns.apply(gm.graph)
 
     stable_topological_sort(gm.graph)
     gm.recompile()
@@ -74,6 +81,7 @@ def lazy_init():
     from .quantization import register_quantization_lowerings
 
     register_quantization_lowerings()
+    register_addmm_activation_replacement()
 
 
 def reorder_for_locality(graph: torch.fx.Graph):
@@ -344,6 +352,70 @@ def addmm(match, mat1, mat2, inp):
     return L[aten.add](inp, L[aten.mm](mat1, mat2))
 
 
+def addmm_relu_pattern(input, mat1, mat2):
+    output = aten.addmm(input, mat1, mat2)
+    return aten.relu(output)
+
+
+def addmm_relu_replacement(input, mat1, mat2):
+    return aten._addmm_activation(input, mat1, mat2, use_gelu=False)
+
+
+def addmm_gelu_pattern(input, mat1, mat2):
+    output = aten.addmm(input, mat1, mat2)
+    return aten.gelu(output)
+
+
+def addmm_gelu_replacement(input, mat1, mat2):
+    return aten._addmm_activation(input, mat1, mat2, use_gelu=True)
+
+
+def should_replace_addmm_activation(match):
+    if config.max_autotune_gemm:
+        # keep addmm for tuning
+        return False
+
+    input = match.kwargs["input"].meta["val"]
+    # conditions of epilogue fusion in _addmm_activation
+    return input.is_cuda and input.dim() == 1 and input.is_contiguous()
+
+
+def register_addmm_activation_replacement():
+    if torch.cuda.is_available():
+        # workaround https://github.com/pytorch/pytorch/issues/97894
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # sizes/values dont actually matter for initial trace
+    # once we get a possible match we re-trace with the actual values and verify the match still holds
+
+    inp = functools.partial(torch.empty, (5,), device=device)
+    mat1 = functools.partial(torch.empty, (3, 4), device=device)
+    mat2 = functools.partial(torch.empty, (4, 5), device=device)
+
+    for pattern, replacement, args in [
+        (
+            addmm_relu_pattern,
+            addmm_relu_replacement,
+            [inp(), mat1(), mat2()],
+        ),
+        (
+            addmm_gelu_pattern,
+            addmm_gelu_replacement,
+            [inp(), mat1(), mat2()],
+        ),
+    ]:
+        register_replacement(
+            pattern,
+            replacement,
+            args,
+            inference_graph,
+            inference_patterns,
+            extra_check=should_replace_addmm_activation,
+        )
+
+
 def is_valid_splitwithsizes_cat(match):
     split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
     cat_nodes = filter_nodes(match.nodes, aten.cat)
@@ -1670,6 +1670,7 @@ make_fallback(aten.adaptive_max_pool2d)
 make_fallback(aten.adaptive_max_pool3d)
 make_fallback(aten.addbmm)
 make_fallback(aten.addmv, warn=False)
+make_fallback(aten._addmm_activation, warn=False)
 make_fallback(aten.avg_pool3d)
 make_fallback(aten.block_diag)
 make_fallback(aten._cdist_forward)
@@ -17,6 +17,7 @@ from torch._prims_common import is_integer_dtype
from torch.fx import Node
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.immutable_collections import immutable_dict, immutable_list

from .._functorch import config as functorch_config
from .._functorch.aot_autograd import aot_function, make_boxed_func
from .._functorch.partitioners import default_partition
@@ -726,6 +727,7 @@ def register_replacement(
        if grad and is_integer_dtype(args[i].dtype):
            return False

        with torch._dynamo.utils.detect_fake_mode(args):
            args[i] = torch.empty_strided(
                args[i].size(),
                args[i].stride(),
@@ -1070,10 +1072,17 @@ def init_once_fakemode(fn):
     @functools.lru_cache(None)
     @functools.wraps(fn)
     def lazy_init():
+        counters_ref = counters["inductor"].copy()
+
         with torch._guards.tracing(
             None
         ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
-            return fn()
+            result = fn()
+
+        # clear view matches encountered during tracing
+        counters["inductor"] = counters_ref
+
+        return result
 
     return lazy_init