mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA:GPU] Enable generic triton emitter for all gemms, second attempt.
According to benchmarks we have reached neutrality with the legacy emitter. Switching to the new emitter by default. The legacy emitter will be kept for some time but is considered deprecated and should not be used. It will be deleted in the near future.
Reverts 85c99b1ecb
PiperOrigin-RevId: 823475406
This commit is contained in:
parent
1f1049a06b
commit
c8cc7f2fbb
4
third_party/xla/xla/debug_options_flags.cc
vendored
4
third_party/xla/xla/debug_options_flags.cc
vendored
|
|
@ -325,6 +325,10 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
|
|||
// default value of the command line flag in `MakeDebugOptionsFlags`.
|
||||
opts.add_xla_gpu_unsupported_generic_triton_emitter_features(
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM);
|
||||
opts.add_xla_gpu_unsupported_generic_triton_emitter_features(
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_GEMM_SHAPES);
|
||||
opts.add_xla_gpu_unsupported_generic_triton_emitter_features(
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION);
|
||||
opts.set_xla_gpu_unsupported_enable_triton_multi_output_fusion(true);
|
||||
opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
|
||||
opts.set_xla_gpu_triton_gemm_any(true);
|
||||
|
|
|
|||
|
|
@ -390,10 +390,13 @@ TEST(ParseRepeatedEnumFlagsTest, GenericTritonEmitterFeatures) {
|
|||
const auto& enabled_features =
|
||||
debug_options.xla_gpu_unsupported_generic_triton_emitter_features();
|
||||
|
||||
// Check that the default setting is empty.
|
||||
// Check default setting.
|
||||
ASSERT_THAT(
|
||||
enabled_features,
|
||||
ElementsAre(DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM));
|
||||
testing::UnorderedElementsAre(
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM,
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_GEMM_SHAPES,
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION));
|
||||
|
||||
// Initialize the flag objects.
|
||||
std::vector<tsl::Flag> flag_objects;
|
||||
|
|
@ -401,24 +404,23 @@ TEST(ParseRepeatedEnumFlagsTest, GenericTritonEmitterFeatures) {
|
|||
|
||||
// Adding options.
|
||||
SetXlaFlagsEnvVar(
|
||||
"--xla_gpu_unsupported_generic_triton_emitter_features=+allow_all_gemm_"
|
||||
"shapes");
|
||||
"--xla_gpu_unsupported_generic_triton_emitter_features="
|
||||
"-allow_all_gemm_shapes");
|
||||
ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects);
|
||||
EXPECT_EQ(enabled_features.size(), 2);
|
||||
EXPECT_THAT(
|
||||
enabled_features,
|
||||
ElementsAre(DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM,
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_GEMM_SHAPES));
|
||||
testing::UnorderedElementsAre(
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM,
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION));
|
||||
|
||||
// Overwriting options.
|
||||
SetXlaFlagsEnvVar(
|
||||
"--xla_gpu_unsupported_generic_triton_emitter_features=disable_legacy_"
|
||||
"gemm,allow_all_ops_in_gemm_fusion");
|
||||
ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects);
|
||||
EXPECT_EQ(enabled_features.size(), 2);
|
||||
EXPECT_THAT(
|
||||
enabled_features,
|
||||
ElementsAre(
|
||||
testing::UnorderedElementsAre(
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_DISABLE_LEGACY_GEMM,
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION));
|
||||
|
||||
|
|
@ -427,10 +429,9 @@ TEST(ParseRepeatedEnumFlagsTest, GenericTritonEmitterFeatures) {
|
|||
"--xla_gpu_unsupported_generic_triton_emitter_features=-disable_legacy_"
|
||||
"gemm,-unspecified,+enable_nested_gemm,+allow_all_ops_in_gemm_fusion");
|
||||
ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects);
|
||||
EXPECT_EQ(enabled_features.size(), 2);
|
||||
EXPECT_THAT(
|
||||
enabled_features,
|
||||
ElementsAre(
|
||||
testing::UnorderedElementsAre(
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ALLOW_ALL_OPS_IN_GEMM_FUSION,
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1068,8 +1068,7 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
|
|||
if (code == absl::StatusCode::kInternal ||
|
||||
code == absl::StatusCode::kFailedPrecondition ||
|
||||
code == absl::StatusCode::kUnimplemented ||
|
||||
(debug_options_.xla_gpu_exhaustive_tiling_search() &&
|
||||
code == absl::StatusCode::kInvalidArgument)) {
|
||||
code == absl::StatusCode::kInvalidArgument) {
|
||||
VLOG(5) << "Compilation failed with status " << executable_or.status()
|
||||
<< " that is ignored";
|
||||
return nullptr;
|
||||
|
|
|
|||
|
|
@ -50,10 +50,6 @@ class DeterminismTest : public GpuCodegenTest {
|
|||
public:
|
||||
DeterminismTest() : debug_options_(HloTestBase::GetDebugOptionsForTest()) {
|
||||
debug_options_.set_xla_gpu_exclude_nondeterministic_ops(true);
|
||||
// TODO(b/393299275): remove when the flag is enabled by default.
|
||||
debug_options_.clear_xla_gpu_unsupported_generic_triton_emitter_features();
|
||||
debug_options_.add_xla_gpu_unsupported_generic_triton_emitter_features(
|
||||
DebugOptions::GENERIC_TRITON_EMITTER_ENABLE_NESTED_GEMM);
|
||||
}
|
||||
|
||||
se::CudaComputeCapability get_cuda_cc() const {
|
||||
|
|
|
|||
|
|
@ -192,7 +192,8 @@ ENTRY e {
|
|||
GmockMatch(match::Concatenate(match::Fusion(), match::Fusion())));
|
||||
}
|
||||
|
||||
TEST_F(NestGemmFusionTest, UnsupportedComputationsAreNotChanged) {
|
||||
// TODO(b/393299275): update test to use a unsupported operation.
|
||||
TEST_F(NestGemmFusionTest, DISABLED_UnsupportedComputationsAreNotChanged) {
|
||||
// Fusions other than kTritonNestedGemmFusionKind are not supported.
|
||||
// In this case pass should only change the supported fusions.
|
||||
absl::string_view hlo = R"(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user