[cutlass backend] Forward fix for less aligned gemm shapes (#148521)
Differential Revision: [D70600093](https://our.internmc.facebook.com/intern/diff/D70600093/)

1. Does config name filtering still work? Tested; it works.
2. Do we get C++ compile errors? Yes, so potentially we need to filter those configs out manually. For example, we hit this one:

```
static_assert(threads_minor == 0 || (TileSizeK % threads_minor == 0));
```

We need to move some of these assertions to gemm_template.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148521
Approved by: https://github.com/ColinPeppler
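As a sketch of what "moving assertions to gemm_template.py" could look like, the divisibility condition from the static_assert above can be checked on the Python side so that offending configs are dropped during filtering rather than failing at C++ compile time. This is an illustration only: `tile_k` and `threads_minor` are placeholder names, not the actual fields on the CUTLASS op/tile description.

```
# Hypothetical helper, mirroring:
#   static_assert(threads_minor == 0 || (TileSizeK % threads_minor == 0));
# A config violating the condition would be skipped during op filtering
# instead of erroring out when the generated C++ is compiled.
def tile_k_divisible_by_threads_minor(tile_k: int, threads_minor: int) -> bool:
    return threads_minor == 0 or tile_k % threads_minor == 0

assert tile_k_divisible_by_threads_minor(64, 16)       # would compile
assert not tile_k_divisible_by_threads_minor(64, 48)   # would trip the static_assert
```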
This commit is contained in:
parent aac230a511
commit b47d81682d
@@ -898,7 +898,9 @@ class TestCutlassBackend(TestCase):
            "torch._inductor.kernel.mm.autotune_select_algorithm",
            wraps=select_no_algorithm,
        ) as sa:
            with self.assertRaises(InductorError):
            with self.assertRaisesRegex(
                InductorError, r".*NoValidChoicesError.*"
            ):
                torch.compile(my_addmm, dynamic=False)(x, a, b, 1.0, 2.0)
            args, _ = sa.call_args
            op_name, choices, _, __ = args
@@ -944,7 +946,9 @@ class TestCutlassBackend(TestCase):
            "torch._inductor.kernel.mm.autotune_select_algorithm",
            wraps=select_no_algorithm,
        ) as sa:
            with self.assertRaises(InductorError):
            with self.assertRaisesRegex(
                InductorError, r".*NoValidChoicesError.*"
            ):
                torch.compile(addmm, dynamic=False)(x, a, b, 1.0, 1.0)
            args, _ = sa.call_args
            op_name, choices, _, __ = args
@@ -961,6 +965,80 @@ class TestCutlassBackend(TestCase):
                cuda_template_count += 1
            assert cuda_template_count > 0, "No CUDATemplateCaller choices"

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_cutlass_backend_shape_coverage_mm(
        self,
    ):
        """
        Checks if cutlass backend produces some ops for a variety of shapes.

        This test doesn't compile and check the correctness of the ops.

        NOTE: K has to be even.
        """

        inputs = [
            (torch.randn(128, 500).cuda().half(), torch.randn(500, 576).cuda().half()),
            (
                torch.randn(500, 128).cuda().half(),
                torch.randn(128, 576).cuda().half(),
            ),
            (torch.randn(128, 250).cuda().half(), torch.randn(250, 576).cuda().half()),
            (
                torch.randn(250, 128).cuda().half(),
                torch.randn(128, 576).cuda().half(),
            ),
            (
                torch.randn(125, 128).cuda().half(),
                torch.randn(128, 576).cuda().half(),
            ),
        ]

        def select_no_algorithm(*args, **kwargs):
            raise NoValidChoicesError

        with fresh_inductor_cache(), config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "CUTLASS",
                "cuda.cutlass_max_profiling_configs": 2,
                "autotune_fallback_to_aten": False,
            }
        ), mock.patch(
            "torch._inductor.kernel.mm.autotune_select_algorithm",
            wraps=select_no_algorithm,
        ) as sa:
            for input in inputs:
                A, B = input
                M, K = A.shape
                _, N = B.shape

                with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"):
                    torch.compile(torch.mm, dynamic=False)(*input)

                self.assertTrue(
                    sa.called,
                    f"autotune_select_algorithm was not called with shape M={M}, N={N}, K={K}",
                )
                args, _ = sa.call_args
                op_name, choices, _, __ = args
                assert op_name == "mm"
                cuda_template_count = 0
                for choice in choices:
                    if isinstance(choice, CUDATemplateCaller):
                        choice_info = choice.info_dict()
                        op_conf_name = choice_info.get("op_conf_name", "")
                        assert isinstance(op_conf_name, str)
                        cuda_template_count += 1

                self.assertGreater(
                    cuda_template_count,
                    0,
                    "No CUDATemplateCaller choices found for matmul with shape "
                    f"M={M}, N={N}, K={K}",
                )

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_get_max_alignment(self):
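The shapes in the new test are chosen so the contiguous dimension is not always a multiple of 8, the preferred fp16 vector alignment. What follows is only a rough sketch of the idea behind the maximum-alignment computation, not the actual `get_max_alignment` implementation exercised by `test_get_max_alignment` (which also has to account for strides and offsets); the reading that odd sizes (alignment 1) are the reason the docstring requires K to be even is a presumption.

```
# Illustration only: the largest candidate vector width (in elements) that
# evenly divides the contiguous dimension of a row-major operand.
def max_alignment(contiguous_dim: int, candidates=(8, 4, 2, 1)) -> int:
    return next(a for a in candidates if contiguous_dim % a == 0)

assert max_alignment(576) == 8   # fully aligned
assert max_alignment(500) == 4   # "less aligned" shape from the test
assert max_alignment(250) == 2
assert max_alignment(125) == 1   # odd size; in the test 125 only appears as M, with K kept even
```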
@@ -841,6 +841,16 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
        # Set epilogue.
        # TODO: update epilogue functor according to epilogues.
        op.element_epilogue = op.accumulator_type()

        # Set bias layout and alignment.
        status = self._set_bias_layout_and_alignment(op)
        if not status:
            log.debug(
                "Skipping due to bias layout and alignment setting failure. op: %s", op
            )
            return None

        # Apply regex filters at the end when configuration name doesn't change anymore
        if inductor_cuda_config.cutlass_op_allowlist_regex is not None:
            if not re.search(
                inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name()
@@ -852,14 +862,6 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
            ):
                return None

        # Set bias layout and alignment.
        status = self._set_bias_layout_and_alignment(op)
        if not status:
            log.debug(
                "Skipping due to bias layout and alignment setting failure. op: %s", op
            )
            return None

        return op

    def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]":  # type: ignore[name-defined]  # noqa: F821
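The two hunks above move the bias layout/alignment handling ahead of the regex filters, so the filters only run once `op.configuration_name()` no longer changes (point 1 in the summary). For reference, a minimal way to exercise name-based filtering from user code might look like the sketch below; the `"128x128"` pattern is an arbitrary example, and `cuda.cutlass_op_allowlist_regex` is assumed to be the config key backing the `inductor_cuda_config.cutlass_op_allowlist_regex` read in the diff.

```
import torch
import torch._inductor.config as inductor_config

# Keep only CUTLASS ops whose configuration name matches the regex.
with inductor_config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_op_allowlist_regex": "128x128",  # example pattern only
    }
):
    a = torch.randn(128, 500, device="cuda", dtype=torch.half)
    b = torch.randn(500, 576, device="cuda", dtype=torch.half)
    torch.compile(torch.mm, dynamic=False)(a, b)
```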
@@ -1212,46 +1214,29 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
        self,
        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
    ) -> bool:
        import cutlass_library.library as cutlass_lib

        has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None
        if has_bias:
            bias = self.input_nodes[2]
            bias_layout = CUTLASSGemmTemplate.cutlass_layout(bias.get_layout())
            Bias = self.input_nodes[2]
            # bias dtype
            op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(
                Bias.get_layout().dtype
            )
            assert op.C.element == op.D.element, (
                f"Expect C and D to have the same dtype, found {op.C.element} and {op.D.element}"
            )

            # Bias layout
            bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout())
            op.C.layout = bias_layout
            status = self.set_alignment(bias.get_layout(), op.C)

            # Bias alignment
            status = self.set_alignment(Bias.get_layout(), op.C)
            if not status:
                return False
        return True

    def _dtype_match(
        self,
        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
    ) -> bool:
        """
        Checking dtypes of C (i.e. bias) here, since that is the one not checked in the base class.
        """

        if not super()._dtype_match(op):
            return False

        assert cutlass_utils.try_import_cutlass()
        from cutlass_library.library import DataType  # type: ignore[import]

        has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None

        if op.C.element == DataType.void:
            if has_bias:
                # op expects no bias, but bias exists
                return False
        else:
            # op expects bias. Needs to check if bias exists and is of the right dtype
            if not (
                has_bias
                and cutlass_utils.dtype_match(
                    self.input_nodes[2].get_dtype(), op.C.element
                )
            ):
                return False

        op.C.element = cutlass_lib.DataType.void
        return True

    def _define_gemm_instance(