[XLA:GPU] Enable dots with block_n=8 in Triton and the autotuner

This change utilizes recently added Triton support for smaller block sizes.

Skipping occupancy optimization for some configs is essentially a workaround for split_k values that are incompatible with codegen. The impact of these configs is limited, however, because they are only present in non-exhaustive mode, so they mostly get filtered out anyway.

PiperOrigin-RevId: 820617352
Nikita Putikhin 2025-10-17 03:23:29 -07:00 committed by TensorFlower Gardener
parent abc19d2d20
commit cc58fb18fd
7 changed files with 105 additions and 16 deletions
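As a standalone illustration of the motivation (plain C++, not XLA code; the output shape [64,8] is taken from the new tests below): for a dot whose output has a small outer dimension N = 8, block_n = 16 pads half of every tile along N, while block_n = 8 tiles the output exactly.

```cpp
#include <cstdio>
#include <initializer_list>

int main() {
  constexpr int kN = 8;  // Small RHS outer (non-contracting) dimension.
  for (int block_n : {16, 8}) {
    const int n_tiles = (kN + block_n - 1) / block_n;  // ceil(kN / block_n)
    const int padded = n_tiles * block_n - kN;         // wasted columns
    std::printf("block_n=%2d: %d tile(s) along N, %d padded column(s)\n",
                block_n, n_tiles, padded);
  }
  return 0;
}
```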

@@ -1889,6 +1889,39 @@ ENTRY e {
.status());
}
class RhsLayoutParameterizedTritonGemmTest
: public TritonGemmTest,
public ::testing::WithParamInterface<absl::string_view> {};
TEST_P(RhsLayoutParameterizedTritonGemmTest,
BF16WithSmallRHSOuterDimDoesNotCrash) {
std::string hlo_text = absl::Substitute(R"(
triton_dot {
p0 = bf16[64,32] parameter(0)
p1 = bf16[32,8]$0 parameter(1)
ROOT dot = f32[64,8] dot(p0, p1),
lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
ENTRY e {
p0 = bf16[64,32] parameter(0)
p1 = bf16[32,8]$0 parameter(1)
ROOT _ = f32[64,8] fusion(p0, p1), kind=kCustom, calls=triton_dot,
backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config:
{"block_m":64,"block_n":8,"block_k":32,
"split_k":1,"num_stages":1,"num_warps":4,
"num_ctas":1}}}
})",
GetParam());
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2}));
}
INSTANTIATE_TEST_SUITE_P(RhsLayoutParameterizedTritonGemmTestSuite,
RhsLayoutParameterizedTritonGemmTest,
::testing::Values("", "{0, 1}", "{1, 0}"));
TEST_F(TritonGemmTest, BinaryOperationWithSmallInputsIsFused) {
constexpr absl::string_view kHloText = R"(
HloModule m

@@ -4148,6 +4148,55 @@ ENTRY entry {
kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
}
TEST_F(TritonEmitterTest, BF16WithSmallRHSOuterDimDoesNotCrash) {
const std::string kHloText = R"(
flhs {
ROOT flhs.p0 = bf16[64,32] parameter(0)
}
frhs {
ROOT frhs.p0 = bf16[32,8] parameter(0)
}
fdot {
fdot.p0 = bf16[64,32] parameter(0)
fdot.p1 = bf16[32,8] parameter(1)
fdot.lhs = bf16[64,32] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
"fusion_backend_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["64", "32"]}]
}
}
}
fdot.rhs = bf16[32,8]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
"fusion_backend_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["32", "8"]}]
}
}
}
ROOT fdot.root = bf16[64,8]{1,0} dot(fdot.lhs, fdot.rhs),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
ENTRY entry {
entry.p0 = bf16[64,32] parameter(0)
entry.p1 = bf16[32,8] parameter(1)
ROOT fusion = bf16[64,8] fusion(entry.p0, entry.p1),
kind=kCustom, calls=fdot, backend_config={
"fusion_backend_config":{
"kind":"__triton_nested_gemm_fusion",
"block_level_fusion_config":{
"output_tiles":[{"sizes":["64","8"]}],
"num_warps":"4",
"num_ctas":"1",
"num_stages":"1"}}}
})";
EXPECT_TRUE(RunAndCompareNoHloPasses(
kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2}));
}
struct ScaleDotTestParams {
std::string lhs_type;
std::string rhs_type;

@@ -812,7 +812,7 @@ absl::Status ValidateMatMulConfig(const TritonGemmConfig& config,
TF_RET_CHECK(config.split_k >= 1);
TF_RET_CHECK(config.block_m >= 16);
TF_RET_CHECK(config.block_k >= 16);
TF_RET_CHECK(config.block_n >= 16);
TF_RET_CHECK(config.block_n >= 8);
const auto& dims = dot.dot_dimension_numbers();
int num_batch_dims =

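For readers skimming the hunk above: the only functional change is that the block_n lower bound drops from 16 to 8, while block_m and block_k keep their minimum of 16. A minimal sketch of the resulting bounds, simplified to a plain bool check over a hypothetical struct (not the actual `ValidateMatMulConfig` signature):

```cpp
// Hypothetical, simplified view of the relaxed tile-size bounds.
struct TileSizes {
  int block_m, block_n, block_k, split_k;
};

bool TileSizesAreValid(const TileSizes& config) {
  return config.split_k >= 1 &&
         config.block_m >= 16 &&
         config.block_k >= 16 &&
         config.block_n >= 8;  // Was >= 16 before this change.
}
```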
@@ -329,12 +329,11 @@ bool TritonDotFusionSearchSpace::ShouldOptimizeForOccupancy() const {
TritonDotFusionSearchSpace::OutputTile
TritonDotFusionSearchSpace::GetMinOutputTile() const {
// Triton currently doesn't support tiles smaller than 16x16.
// TODO: b/395572776 - Lift this restriction, and calculate a smaller tile
// based on the requested algorithm (e.g., if we want to use wgmma vs mma
// vs fma, the minimal reasonable tile size is different).
constexpr OutputTile kMinSupportedTile = {16, 16};
constexpr OutputTile kMinWgmmaTile = {64, 16};
// TODO: b/395572776 - Calculate tile sizes based on the requested algorithm
// (e.g., if we want to use wgmma vs mma vs fma, the minimal reasonable tile
// size is different).
constexpr OutputTile kMinSupportedTile = {16, 8};
constexpr OutputTile kMinWgmmaTile = {64, 8};
if (device_description_.cuda_compute_capability().IsAtLeastHopper() &&
!should_optimize_for_occupancy_) {
VLOG(5) << "Computing output_tile: Want to use wgmma, so output_tile >= "
@@ -656,6 +655,14 @@ void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs(
ConfigWithNotes last_config = configs.back(); // Largest split.
auto has_too_few_tiles = [](const ConfigWithNotes& config) {
// Small dots frequently lead to large split_k values that are not
// compatible with codegen. We skip occupancy optimization for them to be
// able to consider smaller splits in non-exhaustive mode.
// The value of 4 was found by running exhaustive autotuning and noting that
// the majority of optimal configs with block_n == 8 had split_k <= 4.
if (config.config.block_n == 8 && config.config.split_k <= 4) {
return false;
}
if (config.not_enough_tiles) {
VLOG(10) << "Skipping due to fewer tiles than cores, config = "
<< config.ToString();

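The hunk above only shows the `has_too_few_tiles` predicate; the call site that consumes it lies outside the excerpt. As a rough sketch of the overall pattern (simplified types and hypothetical names, not the actual `ConfigWithNotes` or `EliminateLowOccupancyConfigs` code; it only assumes the predicate is used to drop candidates from the config list), the exemption keeps small-split block_n = 8 configs in the pool:

```cpp
#include <algorithm>
#include <vector>

// Hypothetical, simplified stand-in for the real ConfigWithNotes type.
struct SmallConfig {
  int block_n = 16;
  int split_k = 1;
  bool not_enough_tiles = false;
};

void DropLowOccupancyConfigs(std::vector<SmallConfig>& configs) {
  auto has_too_few_tiles = [](const SmallConfig& config) {
    // Exemption mirroring the change above: small-N tiles with modest
    // split_k are kept even if they underutilize the GPU, because larger
    // splits are often incompatible with codegen for small dots.
    if (config.block_n == 8 && config.split_k <= 4) {
      return false;
    }
    return config.not_enough_tiles;
  };
  configs.erase(
      std::remove_if(configs.begin(), configs.end(), has_too_few_tiles),
      configs.end());
}
```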
@@ -45,6 +45,7 @@ void PrintTo(const TritonGemmConfig& config, std::ostream* os) {
namespace {
using ::testing::AllOf;
using ::testing::Contains;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::Eq;
@@ -195,7 +196,7 @@ TEST_F(DotSearchSpaceTest, SerializesSearchSpace) {
EXPECT_EQ(search_space.ToString(),
"problem_size_BxMxNxKxE: 1x1024x1024x1024x(16->16) "
"tile_range_SxMxNxK: [1-64]x[16-256]x[16-512]x[16-?] "
"tile_range_SxMxNxK: [1-64]x[16-256]x[8-512]x[16-?] "
"desired_total_warps: 2640 occupancy_optimization: 1 "
"warps_per_cta: [2-?]");
}
@@ -306,16 +307,15 @@ TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForLowOccupancyProblem) {
Contains(AllOf(BlockMIs(Ge(32)), SplitKIs(Ge(2)))));
}
TEST_F(DotSearchSpaceTest,
FindsUniqueOccupancyMaximizingTilingForSmallProblem) {
TEST_F(DotSearchSpaceTest, FindsOccupancyMaximizingTilingForSmallProblem) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<VerifiedHloModule> module,
GetDefaultDotModule(/*lhs_parallel_dim=*/64, /*rhs_parallel_dim=*/64,
/*contracting_dim=*/64));
TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
EXPECT_THAT(search_space.GenerateConfigs(),
AllOf(SizeIs(1), Each(AllOf(BlockMIs(Eq(16)), BlockNIs(Eq(16)),
SplitKIs(Eq(4))))));
EXPECT_THAT(
search_space.GenerateConfigs(),
Contains(AllOf(BlockMIs(Eq(16)), BlockNIs(Eq(8)), SplitKIs(Eq(4)))));
}
TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForForcedHugeSplit) {
@@ -348,7 +348,7 @@ TEST_F(DotSearchSpaceTest, HonorsMinimumOutputTileSizeForTinyProblem) {
EXPECT_THAT(
search_space.GenerateConfigs(),
AllOf(Not(IsEmpty()), Each(BlockMIs(Ge(16))), Each(BlockNIs(Ge(16)))));
AllOf(Not(IsEmpty()), Each(BlockMIs(Ge(16))), Each(BlockNIs(Ge(8)))));
}
TEST_F(DotSearchSpaceTest, AssignsEnoughWarpsPerScheduler) {

@@ -582,7 +582,7 @@ ENTRY e {
MatchOptimizedHlo(kHloText, R"(
; CHECK: reduce
; CHECK: ENTRY
; CHECK: f32[16,7,18]{2,1,0} fusion({{.*}})
; CHECK: f32[{{.*}},7,18]{2,1,0} fusion({{.*}})
; CHECK: ROOT {{.*}} f16[7,18]{1,0} fusion({{.*}})
)");

@@ -65,7 +65,7 @@ static const std::vector<TritonGemmConfig>* const kHopperAmpereConfigs =
Config(128, 16, 32, 8, 4, 2), Config(128, 16, 64, 16, 3, 2),
Config(128, 16, 64, 16, 1, 4), Config(128, 32, 32, 8, 4, 2),
Config(128, 128, 32, 8, 4, 8), Config(128, 256, 32, 1, 4, 8),
Config(128, 256, 64, 1, 4, 8)});
Config(128, 256, 64, 1, 4, 8), Config(64, 8, 128, 2, 3, 4, 1)});
static const std::vector<TritonGemmConfig>* const kDefaultCudaConfigs =
new std::vector<TritonGemmConfig>(
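The entry added to kHopperAmpereConfigs is easier to read with its fields spelled out. A sketch of the same values with a hypothetical struct (field names assumed to follow the usual positional order block_m, block_n, block_k, split_k, num_stages, num_warps, num_ctas; not verified against the actual TritonGemmConfig declaration):

```cpp
// Hypothetical struct mirroring the positional order used by Config(...)
// above; the added Hopper/Ampere entry reads as follows.
struct GemmTilingParams {
  int block_m, block_n, block_k, split_k, num_stages, num_warps, num_ctas;
};

constexpr GemmTilingParams kNewHopperAmpereEntry = {
    /*block_m=*/64, /*block_n=*/8, /*block_k=*/128, /*split_k=*/2,
    /*num_stages=*/3, /*num_warps=*/4, /*num_ctas=*/1};
```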