mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA:GPU/WS] Adding test coverage for auto warp specialization via Triton.
PiperOrigin-RevId: 820637611
This commit is contained in:
parent
cc58fb18fd
commit
097f587e4e
|
|
@ -499,6 +499,7 @@ xla_cc_test(
|
|||
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
|
||||
"//xla/hlo/testlib:verified_hlo_module",
|
||||
"//xla/service/gpu:gpu_device_info_for_tests",
|
||||
"//xla/service/gpu/model:block_level_parameters",
|
||||
"//xla/service/gpu/model/experimental:symbolic_expr",
|
||||
"//xla/stream_executor:device_description",
|
||||
"//xla/stream_executor/cuda:cuda_compute_capability",
|
||||
|
|
@ -857,10 +858,10 @@ xla_test(
|
|||
"//xla:xla_proto_cc",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/hlo/testlib:filecheck",
|
||||
"//xla/hlo/testlib:verified_hlo_module",
|
||||
"//xla/service:algorithm_util",
|
||||
"//xla/service/gpu:backend_configs_cc",
|
||||
"//xla/service/gpu:gpu_device_info_for_tests",
|
||||
"//xla/service/gpu/model:block_level_parameters",
|
||||
"//xla/service/gpu/model/experimental:symbolic_expr",
|
||||
"//xla/service/gpu/tests:gpu_codegen_test",
|
||||
"//xla/stream_executor:device_description",
|
||||
|
|
|
|||
|
|
@ -15,13 +15,10 @@ limitations under the License.
|
|||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
|
@ -47,13 +44,13 @@ limitations under the License.
|
|||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
#include "xla/hlo/ir/hlo_opcode.h"
|
||||
#include "xla/hlo/testlib/filecheck.h"
|
||||
#include "xla/hlo/testlib/verified_hlo_module.h"
|
||||
#include "xla/literal.h"
|
||||
#include "xla/literal_util.h"
|
||||
#include "xla/primitive_util.h"
|
||||
#include "xla/service/algorithm_util.h"
|
||||
#include "xla/service/gpu/backend_configs.pb.h"
|
||||
#include "xla/service/gpu/gpu_device_info_for_tests.h"
|
||||
#include "xla/service/gpu/model/block_level_parameters.h"
|
||||
#include "xla/service/gpu/model/experimental/symbolic_expr.h"
|
||||
#include "xla/service/gpu/tests/gpu_codegen_test.h"
|
||||
#include "xla/shape.h"
|
||||
|
|
@ -116,6 +113,17 @@ INSTANTIATE_TEST_SUITE_P(TmaParameterizedTritonEmitterTestSuite,
|
|||
return info.param ? "tma_allowed" : "tma_disabled";
|
||||
});
|
||||
|
||||
class WarpSpecializationTritonEmitterTest : public TritonEmitterTest {
|
||||
public:
|
||||
DebugOptions GetDebugOptionsForTest() const override {
|
||||
DebugOptions debug_options = TritonEmitterTest::GetDebugOptionsForTest();
|
||||
debug_options.set_xla_gpu_experimental_enable_triton_tma(true);
|
||||
debug_options.set_xla_gpu_experimental_enable_triton_warp_specialization(
|
||||
true);
|
||||
return debug_options;
|
||||
}
|
||||
};
|
||||
|
||||
struct TmaAndDotLayoutTestParams {
|
||||
std::vector<int64_t> lhs_layout;
|
||||
std::vector<int64_t> rhs_layout;
|
||||
|
|
@ -3244,6 +3252,72 @@ CHECK-COUNT-1: triton_xla.insert
|
|||
hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
|
||||
}
|
||||
|
||||
TEST_F(WarpSpecializationTritonEmitterTest,
|
||||
DotAccumulationLoopUsesWarpSpecialization) {
|
||||
if (!GetCudaComputeCapability().IsAtLeastBlackwell()) {
|
||||
GTEST_SKIP() << "Currently only supported on Blackwell and newer.";
|
||||
}
|
||||
|
||||
const std::string hlo_text = R"(
|
||||
flhs {
|
||||
ROOT flhs.p0 = f16[256,256] parameter(0)
|
||||
}
|
||||
|
||||
frhs {
|
||||
ROOT frhs.p0 = f16[256,256] parameter(0)
|
||||
}
|
||||
|
||||
fdot {
|
||||
fdot.p0 = f16[256,256] parameter(0)
|
||||
fdot.p1 = f16[256,256] parameter(1)
|
||||
fdot.lhs = f16[256,256] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
|
||||
"fusion_backend_config":{
|
||||
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
|
||||
"output_tiles":[{"sizes":["128", "64"]}],
|
||||
"is_tma_allowed":"1"
|
||||
}
|
||||
}
|
||||
}
|
||||
fdot.rhs = f16[256,256]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
|
||||
"fusion_backend_config":{
|
||||
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
|
||||
"output_tiles":[{"sizes":["64", "128"]}],
|
||||
"is_tma_allowed":"1"
|
||||
}
|
||||
}
|
||||
}
|
||||
ROOT fdot.root = f16[256,256]{1,0} dot(fdot.lhs, fdot.rhs),
|
||||
lhs_contracting_dims={1}, rhs_contracting_dims={0},
|
||||
algorithm=dot_f16_f16_f32
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
entry.p0 = f16[256,256] parameter(0)
|
||||
entry.p1 = f16[256,256] parameter(1)
|
||||
ROOT fusion = f16[256,256] fusion(entry.p0, entry.p1),
|
||||
kind=kCustom, calls=fdot, backend_config={
|
||||
"fusion_backend_config":{
|
||||
"kind":"__triton_nested_gemm_fusion",
|
||||
"block_level_fusion_config":{
|
||||
"output_tiles":[{"sizes":["128", "128"]}],
|
||||
"num_warps":"8",
|
||||
"num_ctas":"1",
|
||||
"num_stages":"1",
|
||||
"is_tma_allowed":"1"}}}
|
||||
})";
|
||||
|
||||
// Check that the IR attribute is set correctly.
|
||||
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, hlo_text, "fdot", R"(
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.yield
|
||||
// CHECK-NEXT: tt.warp_specialize
|
||||
// )"));
|
||||
|
||||
// Make sure it runs correctly.
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(
|
||||
hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
|
||||
}
|
||||
|
||||
TEST_F(TritonEmitterTest, MaskedDotIsEmittedCorrectly) {
|
||||
const std::string kHloText = R"(
|
||||
flhs {
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
|
||||
#include "xla/hlo/testlib/verified_hlo_module.h"
|
||||
#include "xla/service/gpu/gpu_device_info_for_tests.h"
|
||||
#include "xla/service/gpu/model/block_level_parameters.h"
|
||||
#include "xla/service/gpu/model/experimental/symbolic_expr.h"
|
||||
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
|
||||
#include "xla/stream_executor/device_description.h"
|
||||
|
|
@ -43,6 +44,8 @@ namespace {
|
|||
using ::tsl::testing::IsOkAndHolds;
|
||||
using ::xla::gpu::ir_emitter_triton_internal::DumpTritonIR;
|
||||
|
||||
using TritonEmitterDevicelessTest = HloHardwareIndependentTestBase;
|
||||
|
||||
class AnnotationsTest : public HloHardwareIndependentTestBase {
|
||||
public:
|
||||
DebugOptions GetDebugOptionsForTest() const override {
|
||||
|
|
@ -53,6 +56,18 @@ class AnnotationsTest : public HloHardwareIndependentTestBase {
|
|||
}
|
||||
};
|
||||
|
||||
class WarpSpecializationTritonEmitterTest : public TritonEmitterDevicelessTest {
|
||||
public:
|
||||
DebugOptions GetDebugOptionsForTest() const override {
|
||||
DebugOptions debug_options =
|
||||
TritonEmitterDevicelessTest::GetDebugOptionsForTest();
|
||||
debug_options.set_xla_gpu_experimental_enable_triton_tma(true);
|
||||
debug_options.set_xla_gpu_experimental_enable_triton_warp_specialization(
|
||||
true);
|
||||
return debug_options;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(AnnotationsTest, Annotations) {
|
||||
static constexpr absl::string_view kHloText = R"(
|
||||
HloModule Annotations
|
||||
|
|
@ -111,8 +126,6 @@ ENTRY e {
|
|||
}
|
||||
}
|
||||
|
||||
using TritonEmitterDevicelessTest = HloHardwareIndependentTestBase;
|
||||
|
||||
TEST_F(TritonEmitterDevicelessTest, FailsGracefullyIfNumWarpsIsMissing) {
|
||||
constexpr absl::string_view kHloText = R"(
|
||||
triton_computation {
|
||||
|
|
@ -155,5 +168,89 @@ ENTRY entry {
|
|||
"(num_warps, num_ctas, num_stages) must be positive")));
|
||||
}
|
||||
|
||||
TEST_F(WarpSpecializationTritonEmitterTest,
|
||||
ExtraWarpsAreRequestedForWarpSpecialization) {
|
||||
const std::string hlo_text = R"(
|
||||
flhs {
|
||||
ROOT flhs.p0 = f16[256,256] parameter(0)
|
||||
}
|
||||
|
||||
frhs {
|
||||
ROOT frhs.p0 = f16[256,256] parameter(0)
|
||||
}
|
||||
|
||||
fdot {
|
||||
fdot.p0 = f16[256,256] parameter(0)
|
||||
fdot.p1 = f16[256,256] parameter(1)
|
||||
fdot.lhs = f16[256,256] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
|
||||
"fusion_backend_config":{
|
||||
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
|
||||
"output_tiles":[{"sizes":["128", "64"]}],
|
||||
"is_tma_allowed":"1"
|
||||
}
|
||||
}
|
||||
}
|
||||
fdot.rhs = f16[256,256]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
|
||||
"fusion_backend_config":{
|
||||
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
|
||||
"output_tiles":[{"sizes":["64", "128"]}],
|
||||
"is_tma_allowed":"1"
|
||||
}
|
||||
}
|
||||
}
|
||||
ROOT fdot.root = f16[256,256]{1,0} dot(fdot.lhs, fdot.rhs),
|
||||
lhs_contracting_dims={1}, rhs_contracting_dims={0},
|
||||
algorithm=dot_f16_f16_f32
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
entry.p0 = f16[256,256] parameter(0)
|
||||
entry.p1 = f16[256,256] parameter(1)
|
||||
ROOT fusion = f16[256,256] fusion(entry.p0, entry.p1),
|
||||
kind=kCustom, calls=fdot, backend_config={
|
||||
"fusion_backend_config":{
|
||||
"kind":"__triton_nested_gemm_fusion",
|
||||
"block_level_fusion_config":{
|
||||
"output_tiles":[{"sizes":["128", "128"]}],
|
||||
"num_warps":"8",
|
||||
"num_ctas":"1",
|
||||
"num_stages":"1",
|
||||
"is_tma_allowed":"1"}}}
|
||||
})";
|
||||
|
||||
// Check that we extract the launch configuration correctly when warp
|
||||
// specialization is used.
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
|
||||
auto* fusion = Cast<HloFusionInstruction>(
|
||||
module->entry_computation()->root_instruction());
|
||||
const se::DeviceDescription dev_info =
|
||||
TestGpuDeviceInfo::RTXB200SXMDeviceInfo();
|
||||
llvm::LLVMContext llvm_ctx;
|
||||
llvm::Module llvm_module("module", llvm_ctx);
|
||||
mlir::MLIRContext mlir_context;
|
||||
SymbolicExprContext symbolic_expr_context(&mlir_context);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
TritonWrapperResult result,
|
||||
TritonWrapper("test_fn", fusion, se::CudaComputeCapability::Blackwell(),
|
||||
dev_info,
|
||||
BlockLevelParameters::FromBlockLevelFusionConfig(
|
||||
fusion->backend_config<GpuBackendConfig>()
|
||||
->fusion_backend_config()
|
||||
.block_level_fusion_config()),
|
||||
&llvm_module, symbolic_expr_context));
|
||||
|
||||
// Warp specialization influences the total number of threads we end up
|
||||
// using. Usually we would expect num_warps * warp_size threads per block, but
|
||||
// Triton allocates extra "worker warps" when WS is used.
|
||||
//
|
||||
// NOTE: The value used here is based on inspecting the value in the IR.
|
||||
// Hopefully this is stable across different Triton versions. If it starts
|
||||
// failing, we could modify the value here to match and try to understand why
|
||||
// it changed.
|
||||
EXPECT_EQ(result.thread_dims.x, 384);
|
||||
EXPECT_EQ(result.thread_dims.y, 1);
|
||||
EXPECT_EQ(result.thread_dims.z, 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla::gpu
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user