[XLA:GPU/WS] Adding test coverage for auto warp specialization via Triton.

PiperOrigin-RevId: 820637611
Author: Mohammed Anany (2025-10-17 04:41:33 -07:00), committed by TensorFlower Gardener
parent cc58fb18fd
commit 097f587e4e
3 changed files with 179 additions and 7 deletions


@@ -499,6 +499,7 @@ xla_cc_test(
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/testlib:verified_hlo_module",
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/service/gpu/model:block_level_parameters",
"//xla/service/gpu/model/experimental:symbolic_expr",
"//xla/stream_executor:device_description",
"//xla/stream_executor/cuda:cuda_compute_capability",
@@ -857,10 +858,10 @@ xla_test(
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:filecheck",
"//xla/hlo/testlib:verified_hlo_module",
"//xla/service:algorithm_util",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/service/gpu/model:block_level_parameters",
"//xla/service/gpu/model/experimental:symbolic_expr",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",


@@ -15,13 +15,10 @@ limitations under the License.
#include <array>
#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -47,13 +44,13 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/primitive_util.h"
#include "xla/service/algorithm_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/model/block_level_parameters.h"
#include "xla/service/gpu/model/experimental/symbolic_expr.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
#include "xla/shape.h"
@@ -116,6 +113,17 @@ INSTANTIATE_TEST_SUITE_P(TmaParameterizedTritonEmitterTestSuite,
return info.param ? "tma_allowed" : "tma_disabled";
});
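
// Fixture that turns on TMA and automatic warp specialization for the Triton
// emitter; both features are currently gated behind experimental debug
// options.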
class WarpSpecializationTritonEmitterTest : public TritonEmitterTest {
 public:
  DebugOptions GetDebugOptionsForTest() const override {
    DebugOptions debug_options = TritonEmitterTest::GetDebugOptionsForTest();
    debug_options.set_xla_gpu_experimental_enable_triton_tma(true);
    debug_options.set_xla_gpu_experimental_enable_triton_warp_specialization(
        true);
    return debug_options;
  }
};

struct TmaAndDotLayoutTestParams {
  std::vector<int64_t> lhs_layout;
  std::vector<int64_t> rhs_layout;
@@ -3244,6 +3252,72 @@ CHECK-COUNT-1: triton_xla.insert
      hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
}

TEST_F(WarpSpecializationTritonEmitterTest,
       DotAccumulationLoopUsesWarpSpecialization) {
  if (!GetCudaComputeCapability().IsAtLeastBlackwell()) {
    GTEST_SKIP() << "Currently only supported on Blackwell and newer.";
  }
  const std::string hlo_text = R"(
flhs {
  ROOT flhs.p0 = f16[256,256] parameter(0)
}

frhs {
  ROOT frhs.p0 = f16[256,256] parameter(0)
}

fdot {
  fdot.p0 = f16[256,256] parameter(0)
  fdot.p1 = f16[256,256] parameter(1)
  fdot.lhs = f16[256,256] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
    "fusion_backend_config":{
      "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
        "output_tiles":[{"sizes":["128", "64"]}],
        "is_tma_allowed":"1"
      }
    }
  }
  fdot.rhs = f16[256,256]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
    "fusion_backend_config":{
      "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
        "output_tiles":[{"sizes":["64", "128"]}],
        "is_tma_allowed":"1"
      }
    }
  }
  ROOT fdot.root = f16[256,256]{1,0} dot(fdot.lhs, fdot.rhs),
    lhs_contracting_dims={1}, rhs_contracting_dims={0},
    algorithm=dot_f16_f16_f32
}

ENTRY entry {
  entry.p0 = f16[256,256] parameter(0)
  entry.p1 = f16[256,256] parameter(1)
  ROOT fusion = f16[256,256] fusion(entry.p0, entry.p1),
    kind=kCustom, calls=fdot, backend_config={
      "fusion_backend_config":{
        "kind":"__triton_nested_gemm_fusion",
        "block_level_fusion_config":{
          "output_tiles":[{"sizes":["128", "128"]}],
          "num_warps":"8",
          "num_ctas":"1",
          "num_stages":"1",
          "is_tma_allowed":"1"}}}
})";

  // Check that the IR attribute is set correctly.
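  // (FileCheck's CHECK-NEXT below requires `tt.warp_specialize` to appear on
  // the line right after the loop's `scf.yield`, i.e. in the attributes
  // printed on the accumulation loop's closing brace.)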
  TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, hlo_text, "fdot", R"(
// CHECK: scf.for
// CHECK: scf.yield
// CHECK-NEXT: tt.warp_specialize
)"));

  // Make sure it runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(
      hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}

TEST_F(TritonEmitterTest, MaskedDotIsEmittedCorrectly) {
  const std::string kHloText = R"(
flhs {


@@ -30,6 +30,7 @@ limitations under the License.
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/model/block_level_parameters.h"
#include "xla/service/gpu/model/experimental/symbolic_expr.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_description.h"
@@ -43,6 +44,8 @@ namespace {
using ::tsl::testing::IsOkAndHolds;
using ::xla::gpu::ir_emitter_triton_internal::DumpTritonIR;
using TritonEmitterDevicelessTest = HloHardwareIndependentTestBase;
class AnnotationsTest : public HloHardwareIndependentTestBase {
 public:
  DebugOptions GetDebugOptionsForTest() const override {
@@ -53,6 +56,18 @@ class AnnotationsTest : public HloHardwareIndependentTestBase {
  }
};
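
// Enables TMA and automatic warp specialization (both experimental debug
// options) so that the deviceless tests below can check the launch
// configuration produced for warp-specialized kernels.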
class WarpSpecializationTritonEmitterTest : public TritonEmitterDevicelessTest {
 public:
  DebugOptions GetDebugOptionsForTest() const override {
    DebugOptions debug_options =
        TritonEmitterDevicelessTest::GetDebugOptionsForTest();
    debug_options.set_xla_gpu_experimental_enable_triton_tma(true);
    debug_options.set_xla_gpu_experimental_enable_triton_warp_specialization(
        true);
    return debug_options;
  }
};

TEST_F(AnnotationsTest, Annotations) {
  static constexpr absl::string_view kHloText = R"(
HloModule Annotations
@@ -111,8 +126,6 @@ ENTRY e {
}
}
using TritonEmitterDevicelessTest = HloHardwareIndependentTestBase;
TEST_F(TritonEmitterDevicelessTest, FailsGracefullyIfNumWarpsIsMissing) {
  constexpr absl::string_view kHloText = R"(
triton_computation {
@@ -155,5 +168,89 @@ ENTRY entry {
"(num_warps, num_ctas, num_stages) must be positive")));
}
TEST_F(WarpSpecializationTritonEmitterTest,
       ExtraWarpsAreRequestedForWarpSpecialization) {
  const std::string hlo_text = R"(
flhs {
  ROOT flhs.p0 = f16[256,256] parameter(0)
}

frhs {
  ROOT frhs.p0 = f16[256,256] parameter(0)
}

fdot {
  fdot.p0 = f16[256,256] parameter(0)
  fdot.p1 = f16[256,256] parameter(1)
  fdot.lhs = f16[256,256] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
    "fusion_backend_config":{
      "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
        "output_tiles":[{"sizes":["128", "64"]}],
        "is_tma_allowed":"1"
      }
    }
  }
  fdot.rhs = f16[256,256]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
    "fusion_backend_config":{
      "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
        "output_tiles":[{"sizes":["64", "128"]}],
        "is_tma_allowed":"1"
      }
    }
  }
  ROOT fdot.root = f16[256,256]{1,0} dot(fdot.lhs, fdot.rhs),
    lhs_contracting_dims={1}, rhs_contracting_dims={0},
    algorithm=dot_f16_f16_f32
}

ENTRY entry {
  entry.p0 = f16[256,256] parameter(0)
  entry.p1 = f16[256,256] parameter(1)
  ROOT fusion = f16[256,256] fusion(entry.p0, entry.p1),
    kind=kCustom, calls=fdot, backend_config={
      "fusion_backend_config":{
        "kind":"__triton_nested_gemm_fusion",
        "block_level_fusion_config":{
          "output_tiles":[{"sizes":["128", "128"]}],
          "num_warps":"8",
          "num_ctas":"1",
          "num_stages":"1",
          "is_tma_allowed":"1"}}}
})";

  // Check that we extract the launch configuration correctly when warp
  // specialization is used.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  auto* fusion = Cast<HloFusionInstruction>(
      module->entry_computation()->root_instruction());
  const se::DeviceDescription dev_info =
      TestGpuDeviceInfo::RTXB200SXMDeviceInfo();
  llvm::LLVMContext llvm_ctx;
  llvm::Module llvm_module("module", llvm_ctx);
  mlir::MLIRContext mlir_context;
  SymbolicExprContext symbolic_expr_context(&mlir_context);
  TF_ASSERT_OK_AND_ASSIGN(
      TritonWrapperResult result,
      TritonWrapper("test_fn", fusion, se::CudaComputeCapability::Blackwell(),
                    dev_info,
                    BlockLevelParameters::FromBlockLevelFusionConfig(
                        fusion->backend_config<GpuBackendConfig>()
                            ->fusion_backend_config()
                            .block_level_fusion_config()),
                    &llvm_module, symbolic_expr_context));
  // Warp specialization influences the total number of threads we end up
  // using. Usually we would expect num_warps * warp_size threads per block,
  // but Triton allocates extra "worker warps" when WS is used.
  //
  // NOTE: The value used here is based on inspecting the value in the IR.
  // Hopefully this is stable across different Triton versions. If it starts
  // failing, we could modify the value here to match and try to understand why
  // it changed.
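  //
  // For reference: with num_warps=8 and the usual 32-thread warp size, a plain
  // launch would use 8 * 32 = 256 threads; 384 threads corresponds to 12
  // warps, i.e. 4 extra worker warps on top of the 8 requested.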
  EXPECT_EQ(result.thread_dims.x, 384);
  EXPECT_EQ(result.thread_dims.y, 1);
  EXPECT_EQ(result.thread_dims.z, 1);
}
} // namespace
} // namespace xla::gpu