[XLA:GPU] Enable generic triton emitter for all gemms, second attempt.

According to benchmarks we have reached the neutrality with the legacy emitter. Switching to the new emitter by default. Legacy emitter will be kept for some time but is considered depricated and should not be used. It will be deleted in the near future.

Reverts 85c99b1ecb

PiperOrigin-RevId: 823475406
This commit is contained in:
Christian Sigg 2025-10-24 04:38:22 -07:00 committed by TensorFlower Gardener
parent 1f1049a06b
commit c8cc7f2fbb
5 changed files with 19 additions and 18 deletions

View File

@ -325,6 +325,10 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
// default value of the command line flag in `MakeDebugOptionsFlags`.
opts.add_xla_gpu_unsupported_generic_triton_emitter_features(
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM);
opts.add_xla_gpu_unsupported_generic_triton_emitter_features(
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_GEMM_SHAPES);
opts.add_xla_gpu_unsupported_generic_triton_emitter_features(
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION);
opts.set_xla_gpu_unsupported_enable_triton_multi_output_fusion(true);
opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
opts.set_xla_gpu_triton_gemm_any(true);

View File

@ -390,10 +390,13 @@ TEST(ParseRepeatedEnumFlagsTest, GenericTritonEmitterFeatures) {
const auto& enabled_features =
debug_options.xla_gpu_unsupported_generic_triton_emitter_features();
// Check that the default setting is empty.
// Check default setting.
ASSERT_THAT(
enabled_features,
ElementsAre(DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM));
testing::UnorderedElementsAre(
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM,
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_GEMM_SHAPES,
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION));
// Initialize the flag objects.
std::vector<tsl::Flag> flag_objects;
@ -401,24 +404,23 @@ TEST(ParseRepeatedEnumFlagsTest, GenericTritonEmitterFeatures) {
// Adding options.
SetXlaFlagsEnvVar(
"--xla_gpu_unsupported_generic_triton_emitter_features=+allow_all_gemm_"
"shapes");
"--xla_gpu_unsupported_generic_triton_emitter_features="
"-allow_all_gemm_shapes");
ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects);
EXPECT_EQ(enabled_features.size(), 2);
EXPECT_THAT(
enabled_features,
ElementsAre(DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM,
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_GEMM_SHAPES));
testing::UnorderedElementsAre(
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM,
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION));
// Overwriting options.
SetXlaFlagsEnvVar(
"--xla_gpu_unsupported_generic_triton_emitter_features=disable_legacy_"
"gemm,allow_all_ops_in_gemm_fusion");
ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects);
EXPECT_EQ(enabled_features.size(), 2);
EXPECT_THAT(
enabled_features,
ElementsAre(
testing::UnorderedElementsAre(
DebugOptions::GENERIC_TRITON_EMITTER_DISABLE_LEGACY_GEMM,
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION));
@ -427,10 +429,9 @@ TEST(ParseRepeatedEnumFlagsTest, GenericTritonEmitterFeatures) {
"--xla_gpu_unsupported_generic_triton_emitter_features=-disable_legacy_"
"gemm,-unspecified,+enable_nested_gemm,+allow_all_ops_in_gemm_fusion");
ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects);
EXPECT_EQ(enabled_features.size(), 2);
EXPECT_THAT(
enabled_features,
ElementsAre(
testing::UnorderedElementsAre(
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION,
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM));
}

View File

@ -1068,8 +1068,7 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
if (code == absl::StatusCode::kInternal ||
code == absl::StatusCode::kFailedPrecondition ||
code == absl::StatusCode::kUnimplemented ||
(debug_options_.xla_gpu_exhaustive_tiling_search() &&
code == absl::StatusCode::kInvalidArgument)) {
code == absl::StatusCode::kInvalidArgument) {
VLOG(5) << "Compilation failed with status " << executable_or.status()
<< " that is ignored";
return nullptr;

View File

@ -50,10 +50,6 @@ class DeterminismTest : public GpuCodegenTest {
public:
DeterminismTest() : debug_options_(HloTestBase::GetDebugOptionsForTest()) {
debug_options_.set_xla_gpu_exclude_nondeterministic_ops(true);
// TODO(b/393299275): remove when the flag is enabled by default.
debug_options_.clear_xla_gpu_unsupported_generic_triton_emitter_features();
debug_options_.add_xla_gpu_unsupported_generic_triton_emitter_features(
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM);
}
se::CudaComputeCapability get_cuda_cc() const {

View File

@ -192,7 +192,8 @@ ENTRY e {
GmockMatch(match::Concatenate(match::Fusion(), match::Fusion())));
}
TEST_F(NestGemmFusionTest, UnsupportedComputationsAreNotChanged) {
// TODO(b/393299275): update test to use a unsupported operation.
TEST_F(NestGemmFusionTest, DISABLED_UnsupportedComputationsAreNotChanged) {
// Fusions other than kTritonNestedGemmFusionKind are not supported.
// In this case pass should only change the supported fusions.
absl::string_view hlo = R"(