[XLA:GPU] enable generic triton emitter for all gemms

According to benchmarks we have reached the neutrality with the legacy emitter. Switching to the new emitter by default.
Legacy emitter will be kept for some time but is considered depricated and should not be used. It will be deleted in the near future.

PiperOrigin-RevId: 822067921
This commit is contained in:
Mikhail Goncharov 2025-10-21 05:17:14 -07:00 committed by TensorFlower Gardener
parent bd257617f7
commit 2d4dd83773
5 changed files with 19 additions and 18 deletions

View File

@ -329,6 +329,10 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
// default value of the command line flag in `MakeDebugOptionsFlags`.
opts.add_xla_gpu_unsupported_generic_triton_emitter_features(
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM);
opts.add_xla_gpu_unsupported_generic_triton_emitter_features(
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_GEMM_SHAPES);
opts.add_xla_gpu_unsupported_generic_triton_emitter_features(
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION);
opts.set_xla_gpu_unsupported_enable_triton_multi_output_fusion(true);
opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
opts.set_xla_gpu_triton_gemm_any(true);

View File

@ -390,10 +390,13 @@ TEST(ParseRepeatedEnumFlagsTest, GenericTritonEmitterFeatures) {
const auto& enabled_features =
debug_options.xla_gpu_unsupported_generic_triton_emitter_features();
// Check that the default setting is empty.
// Check default setting.
ASSERT_THAT(
enabled_features,
ElementsAre(DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM));
testing::UnorderedElementsAre(
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM,
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_GEMM_SHAPES,
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION));
// Initialize the flag objects.
std::vector<tsl::Flag> flag_objects;
@ -401,24 +404,23 @@ TEST(ParseRepeatedEnumFlagsTest, GenericTritonEmitterFeatures) {
// Adding options.
SetXlaFlagsEnvVar(
"--xla_gpu_unsupported_generic_triton_emitter_features=+allow_all_gemm_"
"shapes");
"--xla_gpu_unsupported_generic_triton_emitter_features="
"-allow_all_gemm_shapes");
ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects);
EXPECT_EQ(enabled_features.size(), 2);
EXPECT_THAT(
enabled_features,
ElementsAre(DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM,
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_GEMM_SHAPES));
testing::UnorderedElementsAre(
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM,
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION));
// Overwriting options.
SetXlaFlagsEnvVar(
"--xla_gpu_unsupported_generic_triton_emitter_features=disable_legacy_"
"gemm,allow_all_ops_in_gemm_fusion");
ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects);
EXPECT_EQ(enabled_features.size(), 2);
EXPECT_THAT(
enabled_features,
ElementsAre(
testing::UnorderedElementsAre(
DebugOptions::GENERIC_TRITON_EMITTER_DISABLE_LEGACY_GEMM,
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION));
@ -427,10 +429,9 @@ TEST(ParseRepeatedEnumFlagsTest, GenericTritonEmitterFeatures) {
"--xla_gpu_unsupported_generic_triton_emitter_features=-disable_legacy_"
"gemm,-unspecified,+enable_nested_gemm,+allow_all_ops_in_gemm_fusion");
ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects);
EXPECT_EQ(enabled_features.size(), 2);
EXPECT_THAT(
enabled_features,
ElementsAre(
testing::UnorderedElementsAre(
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION,
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM));
}

View File

@ -1068,8 +1068,7 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
if (code == absl::StatusCode::kInternal ||
code == absl::StatusCode::kFailedPrecondition ||
code == absl::StatusCode::kUnimplemented ||
(debug_options_.xla_gpu_exhaustive_tiling_search() &&
code == absl::StatusCode::kInvalidArgument)) {
code == absl::StatusCode::kInvalidArgument) {
VLOG(5) << "Compilation failed with status " << executable_or.status()
<< " that is ignored";
return nullptr;

View File

@ -50,10 +50,6 @@ class DeterminismTest : public GpuCodegenTest {
public:
DeterminismTest() : debug_options_(HloTestBase::GetDebugOptionsForTest()) {
debug_options_.set_xla_gpu_exclude_nondeterministic_ops(true);
// TODO(b/393299275): remove when the flag is enabled by default.
debug_options_.clear_xla_gpu_unsupported_generic_triton_emitter_features();
debug_options_.add_xla_gpu_unsupported_generic_triton_emitter_features(
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM);
}
se::CudaComputeCapability get_cuda_cc() const {

View File

@ -192,7 +192,8 @@ ENTRY e {
GmockMatch(match::Concatenate(match::Fusion(), match::Fusion())));
}
TEST_F(NestGemmFusionTest, UnsupportedComputationsAreNotChanged) {
// TODO(b/393299275): update test to use a unsupported operation.
TEST_F(NestGemmFusionTest, DISABLED_UnsupportedComputationsAreNotChanged) {
// Fusions other than kTritonNestedGemmFusionKind are not supported.
// In this case pass should only change the supported fusions.
absl::string_view hlo = R"(