[XLA:GPU] Enable dots with block_n=8 in triton and autotuner
This change makes use of recently added Triton support for smaller block sizes. Skipping occupancy optimization for some configs is essentially a workaround for incompatible split_k values. However, the impact of these configs is limited because they are only present in non-exhaustive mode, so they mostly get filtered out anyway.

PiperOrigin-RevId: 820617352
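To make the occupancy workaround concrete, here is a minimal, self-contained C++ sketch. It is not XLA code: the Config struct, the tile arithmetic, and the SM count are illustrative assumptions. It shows why a small dot such as bf16[64,32] x bf16[32,8] produces too few output tiles to fill the GPU unless split_k grows, and how configs with block_n == 8 and split_k <= 4 can be exempted from the "too few tiles" filter so that smaller splits survive in non-exhaustive mode. The actual exemption lives in the has_too_few_tiles lambda in the EliminateLowOccupancyConfigs hunk further down.

// Minimal sketch, not XLA code: illustrates the "too few tiles" filter and
// the block_n == 8 exemption described in the commit message. All names,
// the tile arithmetic, and the core count are illustrative assumptions.
#include <cstdint>
#include <iostream>
#include <vector>

struct Config {
  int64_t block_m;
  int64_t block_n;
  int64_t split_k;
};

int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  // Problem shape used by the new tests: dot of bf16[64,32] x bf16[32,8].
  constexpr int64_t kM = 64, kN = 8;
  constexpr int64_t kNumCores = 132;  // Assumed SM count (e.g. an H100).

  std::vector<Config> configs = {
      {/*block_m=*/16, /*block_n=*/8, /*split_k=*/4},
      {/*block_m=*/16, /*block_n=*/16, /*split_k=*/64},
      {/*block_m=*/64, /*block_n=*/16, /*split_k=*/1},
  };

  for (const Config& c : configs) {
    // One program instance per output tile, times the split_k factor.
    int64_t tiles = CeilDiv(kM, c.block_m) * CeilDiv(kN, c.block_n) * c.split_k;
    // Mirrors the shape of the new filter: block_n == 8 configs with a modest
    // split_k are kept even when occupancy looks low, because larger splits
    // may not be compatible with codegen.
    bool kept = (c.block_n == 8 && c.split_k <= 4) || tiles >= kNumCores;
    std::cout << "block_m=" << c.block_m << " block_n=" << c.block_n
              << " split_k=" << c.split_k << " tiles=" << tiles
              << (kept ? "  -> kept" : "  -> dropped (too few tiles)") << "\n";
  }
  return 0;
}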
parent abc19d2d20
commit cc58fb18fd
@@ -1889,6 +1889,39 @@ ENTRY e {
       .status());
 }
 
+class RhsLayoutParameterizedTritonGemmTest
+    : public TritonGemmTest,
+      public ::testing::WithParamInterface<absl::string_view> {};
+
+TEST_P(RhsLayoutParameterizedTritonGemmTest,
+       BF16WithSmallRHSOuterDimDoesNotCrash) {
+  std::string hlo_text = absl::Substitute(R"(
+triton_dot {
+  p0 = bf16[64,32] parameter(0)
+  p1 = bf16[32,8]$0 parameter(1)
+  ROOT dot = f32[64,8] dot(p0, p1),
+    lhs_contracting_dims={1},
+    rhs_contracting_dims={0}
+}
+
+ENTRY e {
+  p0 = bf16[64,32] parameter(0)
+  p1 = bf16[32,8]$0 parameter(1)
+  ROOT _ = f32[64,8] fusion(p0, p1), kind=kCustom, calls=triton_dot,
+    backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config:
+      {"block_m":64,"block_n":8,"block_k":32,
+       "split_k":1,"num_stages":1,"num_warps":4,
+       "num_ctas":1}}}
+})",
+                                          GetParam());
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2}));
+}
+
+INSTANTIATE_TEST_SUITE_P(RhsLayoutParameterizedTritonGemmTestSuite,
+                         RhsLayoutParameterizedTritonGemmTest,
+                         ::testing::Values("", "{0, 1}", "{1, 0}"));
+
 TEST_F(TritonGemmTest, BinaryOperationWithSmallInputsIsFused) {
   constexpr absl::string_view kHloText = R"(
 HloModule m
@@ -4148,6 +4148,55 @@ ENTRY entry {
       kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
 }
 
+TEST_F(TritonEmitterTest, BF16WithSmallRHSOuterDimDoesNotCrash) {
+  const std::string kHloText = R"(
+flhs {
+  ROOT flhs.p0 = bf16[64,32] parameter(0)
+}
+
+frhs {
+  ROOT frhs.p0 = bf16[32,8] parameter(0)
+}
+
+fdot {
+  fdot.p0 = bf16[64,32] parameter(0)
+  fdot.p1 = bf16[32,8] parameter(1)
+  fdot.lhs = bf16[64,32] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
+    "fusion_backend_config":{
+      "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
+        "output_tiles":[{"sizes":["64", "32"]}]
+      }
+    }
+  }
+  fdot.rhs = bf16[32,8]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
+    "fusion_backend_config":{
+      "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
+        "output_tiles":[{"sizes":["32", "8"]}]
+      }
+    }
+  }
+  ROOT fdot.root = bf16[64,8]{1,0} dot(fdot.lhs, fdot.rhs),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY entry {
+  entry.p0 = bf16[64,32] parameter(0)
+  entry.p1 = bf16[32,8] parameter(1)
+  ROOT fusion = bf16[64,8] fusion(entry.p0, entry.p1),
+    kind=kCustom, calls=fdot, backend_config={
+      "fusion_backend_config":{
+        "kind":"__triton_nested_gemm_fusion",
+        "block_level_fusion_config":{
+          "output_tiles":[{"sizes":["64","8"]}],
+          "num_warps":"4",
+          "num_ctas":"1",
+          "num_stages":"1"}}}
+})";
+
+  EXPECT_TRUE(RunAndCompareNoHloPasses(
+      kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2}));
+}
+
 struct ScaleDotTestParams {
   std::string lhs_type;
   std::string rhs_type;
@@ -812,7 +812,7 @@ absl::Status ValidateMatMulConfig(const TritonGemmConfig& config,
   TF_RET_CHECK(config.split_k >= 1);
   TF_RET_CHECK(config.block_m >= 16);
   TF_RET_CHECK(config.block_k >= 16);
-  TF_RET_CHECK(config.block_n >= 16);
+  TF_RET_CHECK(config.block_n >= 8);
 
   const auto& dims = dot.dot_dimension_numbers();
   int num_batch_dims =
@@ -329,12 +329,11 @@ bool TritonDotFusionSearchSpace::ShouldOptimizeForOccupancy() const {
 TritonDotFusionSearchSpace::OutputTile
 TritonDotFusionSearchSpace::GetMinOutputTile() const {
-  // Triton currently doesn't support tiles smaller than 16x16.
-  // TODO: b/395572776 - Lift this restriction, and calculate a smaller tile
-  // based on the requested algorithm (e.g., if we want to use wgmma vs mma
-  // vs fma, the minimal reasonable tile size is different).
-  constexpr OutputTile kMinSupportedTile = {16, 16};
-  constexpr OutputTile kMinWgmmaTile = {64, 16};
+  // TODO: b/395572776 - Calculate tile sizes based on the requested algorithm
+  // (e.g., if we want to use wgmma vs mma vs fma, the minimal reasonable tile
+  // size is different).
+  constexpr OutputTile kMinSupportedTile = {16, 8};
+  constexpr OutputTile kMinWgmmaTile = {64, 8};
 
   if (device_description_.cuda_compute_capability().IsAtLeastHopper() &&
       !should_optimize_for_occupancy_) {
     VLOG(5) << "Computing output_tile: Want to use wgmma, so output_tile >= "
@@ -656,6 +655,14 @@ void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs(
 
   ConfigWithNotes last_config = configs.back();  // Largest split.
   auto has_too_few_tiles = [](const ConfigWithNotes& config) {
+    // Small dots frequently lead to large split_k values that are not
+    // compatible with codegen. We skip occupancy optimization for them to be
+    // able to consider smaller splits in non-exhaustive mode.
+    // The value of 4 was found by running exhaustive autotuning and noting that
+    // the majority of optimal configs with block_n == 8 had split_k <= 4.
+    if (config.config.block_n == 8 && config.config.split_k <= 4) {
+      return false;
+    }
     if (config.not_enough_tiles) {
       VLOG(10) << "Skipping due to fewer tiles than cores, config = "
                << config.ToString();
@@ -45,6 +45,7 @@ void PrintTo(const TritonGemmConfig& config, std::ostream* os) {
 namespace {
 
 using ::testing::AllOf;
+using ::testing::Contains;
 using ::testing::ElementsAre;
 using ::testing::ElementsAreArray;
 using ::testing::Eq;
@@ -195,7 +196,7 @@ TEST_F(DotSearchSpaceTest, SerializesSearchSpace) {
 
   EXPECT_EQ(search_space.ToString(),
             "problem_size_BxMxNxKxE: 1x1024x1024x1024x(16->16) "
-            "tile_range_SxMxNxK: [1-64]x[16-256]x[16-512]x[16-?] "
+            "tile_range_SxMxNxK: [1-64]x[16-256]x[8-512]x[16-?] "
             "desired_total_warps: 2640 occupancy_optimization: 1 "
             "warps_per_cta: [2-?]");
 }
@@ -306,16 +307,15 @@ TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForLowOccupancyProblem) {
               Contains(AllOf(BlockMIs(Ge(32)), SplitKIs(Ge(2)))));
 }
 
-TEST_F(DotSearchSpaceTest,
-       FindsUniqueOccupancyMaximizingTilingForSmallProblem) {
+TEST_F(DotSearchSpaceTest, FindsOccupancyMaximizingTilingForSmallProblem) {
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<VerifiedHloModule> module,
       GetDefaultDotModule(/*lhs_parallel_dim=*/64, /*rhs_parallel_dim=*/64,
                           /*contracting_dim=*/64));
   TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
-  EXPECT_THAT(search_space.GenerateConfigs(),
-              AllOf(SizeIs(1), Each(AllOf(BlockMIs(Eq(16)), BlockNIs(Eq(16)),
-                                          SplitKIs(Eq(4))))));
+  EXPECT_THAT(
+      search_space.GenerateConfigs(),
+      Contains(AllOf(BlockMIs(Eq(16)), BlockNIs(Eq(8)), SplitKIs(Eq(4)))));
 }
 
 TEST_F(DotSearchSpaceTest, FindsGoodDataReuseTilesForForcedHugeSplit) {
@@ -348,7 +348,7 @@ TEST_F(DotSearchSpaceTest, HonorsMinimumOutputTileSizeForTinyProblem) {
 
   EXPECT_THAT(
       search_space.GenerateConfigs(),
-      AllOf(Not(IsEmpty()), Each(BlockMIs(Ge(16))), Each(BlockNIs(Ge(16)))));
+      AllOf(Not(IsEmpty()), Each(BlockMIs(Ge(16))), Each(BlockNIs(Ge(8)))));
 }
 
 TEST_F(DotSearchSpaceTest, AssignsEnoughWarpsPerScheduler) {
@@ -582,7 +582,7 @@ ENTRY e {
   MatchOptimizedHlo(kHloText, R"(
 ; CHECK: reduce
 ; CHECK: ENTRY
-; CHECK: f32[16,7,18]{2,1,0} fusion({{.*}})
+; CHECK: f32[{{.*}},7,18]{2,1,0} fusion({{.*}})
 ; CHECK: ROOT {{.*}} f16[7,18]{1,0} fusion({{.*}})
 )");
 
@@ -65,7 +65,7 @@ static const std::vector<TritonGemmConfig>* const kHopperAmpereConfigs =
          Config(128, 16, 32, 8, 4, 2), Config(128, 16, 64, 16, 3, 2),
          Config(128, 16, 64, 16, 1, 4), Config(128, 32, 32, 8, 4, 2),
          Config(128, 128, 32, 8, 4, 8), Config(128, 256, 32, 1, 4, 8),
-         Config(128, 256, 64, 1, 4, 8)});
+         Config(128, 256, 64, 1, 4, 8), Config(64, 8, 128, 2, 3, 4, 1)});
 
 static const std::vector<TritonGemmConfig>* const kDefaultCudaConfigs =
     new std::vector<TritonGemmConfig>(