[XLA:GPU/WS] Adding test coverage for auto warp specialization via Triton.
PiperOrigin-RevId: 820637611
parent cc58fb18fd
commit 097f587e4e
@@ -499,6 +499,7 @@ xla_cc_test(
         "//xla/hlo/testlib:hlo_hardware_independent_test_base",
         "//xla/hlo/testlib:verified_hlo_module",
         "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu/model:block_level_parameters",
         "//xla/service/gpu/model/experimental:symbolic_expr",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor/cuda:cuda_compute_capability",
@@ -857,10 +858,10 @@ xla_test(
         "//xla:xla_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/testlib:filecheck",
-        "//xla/hlo/testlib:verified_hlo_module",
         "//xla/service:algorithm_util",
         "//xla/service/gpu:backend_configs_cc",
         "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu/model:block_level_parameters",
         "//xla/service/gpu/model/experimental:symbolic_expr",
         "//xla/service/gpu/tests:gpu_codegen_test",
         "//xla/stream_executor:device_description",
@@ -15,13 +15,10 @@ limitations under the License.
 
 #include <array>
 #include <cstdint>
-#include <memory>
 #include <ostream>
 #include <string>
-#include <tuple>
 #include <utility>
 #include <variant>
-#include <vector>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
@@ -47,13 +44,13 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/testlib/filecheck.h"
-#include "xla/hlo/testlib/verified_hlo_module.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/primitive_util.h"
 #include "xla/service/algorithm_util.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/model/block_level_parameters.h"
 #include "xla/service/gpu/model/experimental/symbolic_expr.h"
 #include "xla/service/gpu/tests/gpu_codegen_test.h"
 #include "xla/shape.h"
@@ -116,6 +113,17 @@ INSTANTIATE_TEST_SUITE_P(TmaParameterizedTritonEmitterTestSuite,
                           return info.param ? "tma_allowed" : "tma_disabled";
                         });
 
+class WarpSpecializationTritonEmitterTest : public TritonEmitterTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() const override {
+    DebugOptions debug_options = TritonEmitterTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_experimental_enable_triton_tma(true);
+    debug_options.set_xla_gpu_experimental_enable_triton_warp_specialization(
+        true);
+    return debug_options;
+  }
+};
+
 struct TmaAndDotLayoutTestParams {
   std::vector<int64_t> lhs_layout;
   std::vector<int64_t> rhs_layout;
@@ -3244,6 +3252,72 @@ CHECK-COUNT-1: triton_xla.insert
       hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
 }
 
+TEST_F(WarpSpecializationTritonEmitterTest,
+       DotAccumulationLoopUsesWarpSpecialization) {
+  if (!GetCudaComputeCapability().IsAtLeastBlackwell()) {
+    GTEST_SKIP() << "Currently only supported on Blackwell and newer.";
+  }
+
+  const std::string hlo_text = R"(
+flhs {
+  ROOT flhs.p0 = f16[256,256] parameter(0)
+}
+
+frhs {
+  ROOT frhs.p0 = f16[256,256] parameter(0)
+}
+
+fdot {
+  fdot.p0 = f16[256,256] parameter(0)
+  fdot.p1 = f16[256,256] parameter(1)
+  fdot.lhs = f16[256,256] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
+    "fusion_backend_config":{
+      "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
+        "output_tiles":[{"sizes":["128", "64"]}],
+        "is_tma_allowed":"1"
+      }
+    }
+  }
+  fdot.rhs = f16[256,256]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
+    "fusion_backend_config":{
+      "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
+        "output_tiles":[{"sizes":["64", "128"]}],
+        "is_tma_allowed":"1"
+      }
+    }
+  }
+  ROOT fdot.root = f16[256,256]{1,0} dot(fdot.lhs, fdot.rhs),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0},
+    algorithm=dot_f16_f16_f32
+}
+
+ENTRY entry {
+  entry.p0 = f16[256,256] parameter(0)
+  entry.p1 = f16[256,256] parameter(1)
+  ROOT fusion = f16[256,256] fusion(entry.p0, entry.p1),
+    kind=kCustom, calls=fdot, backend_config={
+      "fusion_backend_config":{
+        "kind":"__triton_nested_gemm_fusion",
+        "block_level_fusion_config":{
+          "output_tiles":[{"sizes":["128", "128"]}],
+          "num_warps":"8",
+          "num_ctas":"1",
+          "num_stages":"1",
+          "is_tma_allowed":"1"}}}
+})";
+
+  // Check that the IR attribute is set correctly.
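+  // (tt.warp_specialize is the Triton IR op that hands work to dedicated
+  // groups of warps; matching it below confirms WS actually kicked in.)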
+  TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, hlo_text, "fdot", R"(
+// CHECK: scf.for
+// CHECK: scf.yield
+// CHECK-NEXT: tt.warp_specialize
+)"));
+
+  // Make sure it runs correctly.
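+  // (dot_f16_f16_f32 accumulates f16 operands in f32, hence the looser
+  // tolerances here than the 1e-4/1e-6 used elsewhere in this file.)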
+  EXPECT_TRUE(RunAndCompareNoHloPasses(
+      hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
 TEST_F(TritonEmitterTest, MaskedDotIsEmittedCorrectly) {
   const std::string kHloText = R"(
 flhs {
@@ -30,6 +30,7 @@ limitations under the License.
 #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
 #include "xla/hlo/testlib/verified_hlo_module.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/model/block_level_parameters.h"
 #include "xla/service/gpu/model/experimental/symbolic_expr.h"
 #include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
@@ -43,6 +44,8 @@ namespace {
 using ::tsl::testing::IsOkAndHolds;
 using ::xla::gpu::ir_emitter_triton_internal::DumpTritonIR;
 
+using TritonEmitterDevicelessTest = HloHardwareIndependentTestBase;
+
 class AnnotationsTest : public HloHardwareIndependentTestBase {
  public:
   DebugOptions GetDebugOptionsForTest() const override {
@@ -53,6 +56,18 @@ class AnnotationsTest : public HloHardwareIndependentTestBase {
   }
 };
 
+class WarpSpecializationTritonEmitterTest : public TritonEmitterDevicelessTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() const override {
+    DebugOptions debug_options =
+        TritonEmitterDevicelessTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_experimental_enable_triton_tma(true);
+    debug_options.set_xla_gpu_experimental_enable_triton_warp_specialization(
+        true);
+    return debug_options;
+  }
+};
+
 TEST_F(AnnotationsTest, Annotations) {
   static constexpr absl::string_view kHloText = R"(
 HloModule Annotations
@@ -111,8 +126,6 @@ ENTRY e {
 }
 }
 
-using TritonEmitterDevicelessTest = HloHardwareIndependentTestBase;
-
 TEST_F(TritonEmitterDevicelessTest, FailsGracefullyIfNumWarpsIsMissing) {
   constexpr absl::string_view kHloText = R"(
 triton_computation {
@@ -155,5 +168,89 @@ ENTRY entry {
                        "(num_warps, num_ctas, num_stages) must be positive")));
 }
 
+TEST_F(WarpSpecializationTritonEmitterTest,
+       ExtraWarpsAreRequestedForWarpSpecialization) {
+  const std::string hlo_text = R"(
+flhs {
+  ROOT flhs.p0 = f16[256,256] parameter(0)
+}
+
+frhs {
+  ROOT frhs.p0 = f16[256,256] parameter(0)
+}
+
+fdot {
+  fdot.p0 = f16[256,256] parameter(0)
+  fdot.p1 = f16[256,256] parameter(1)
+  fdot.lhs = f16[256,256] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
+    "fusion_backend_config":{
+      "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
+        "output_tiles":[{"sizes":["128", "64"]}],
+        "is_tma_allowed":"1"
+      }
+    }
+  }
+  fdot.rhs = f16[256,256]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
+    "fusion_backend_config":{
+      "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
+        "output_tiles":[{"sizes":["64", "128"]}],
+        "is_tma_allowed":"1"
+      }
+    }
+  }
+  ROOT fdot.root = f16[256,256]{1,0} dot(fdot.lhs, fdot.rhs),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0},
+    algorithm=dot_f16_f16_f32
+}
+
+ENTRY entry {
+  entry.p0 = f16[256,256] parameter(0)
+  entry.p1 = f16[256,256] parameter(1)
+  ROOT fusion = f16[256,256] fusion(entry.p0, entry.p1),
+    kind=kCustom, calls=fdot, backend_config={
+      "fusion_backend_config":{
+        "kind":"__triton_nested_gemm_fusion",
+        "block_level_fusion_config":{
+          "output_tiles":[{"sizes":["128", "128"]}],
+          "num_warps":"8",
+          "num_ctas":"1",
+          "num_stages":"1",
+          "is_tma_allowed":"1"}}}
+})";
+
+  // Check that we extract the launch configuration correctly when warp
+  // specialization is used.
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
+  auto* fusion = Cast<HloFusionInstruction>(
+      module->entry_computation()->root_instruction());
+  const se::DeviceDescription dev_info =
+      TestGpuDeviceInfo::RTXB200SXMDeviceInfo();
+  llvm::LLVMContext llvm_ctx;
+  llvm::Module llvm_module("module", llvm_ctx);
+  mlir::MLIRContext mlir_context;
+  SymbolicExprContext symbolic_expr_context(&mlir_context);
+  TF_ASSERT_OK_AND_ASSIGN(
+      TritonWrapperResult result,
+      TritonWrapper("test_fn", fusion, se::CudaComputeCapability::Blackwell(),
+                    dev_info,
+                    BlockLevelParameters::FromBlockLevelFusionConfig(
+                        fusion->backend_config<GpuBackendConfig>()
+                            ->fusion_backend_config()
+                            .block_level_fusion_config()),
+                    &llvm_module, symbolic_expr_context));
+
+  // Warp specialization influences the total number of threads we end up
+  // using. Usually we would expect num_warps * warp_size threads per block, but
+  // Triton allocates extra "worker warps" when WS is used.
+  //
+  // NOTE: The value used here is based on inspecting the value in the IR.
+  // Hopefully this is stable across different Triton versions. If it starts
+  // failing, we could modify the value here to match and try to understand why
+  // it changed.
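+  // (For reference: 8 num_warps * 32 threads/warp = 256, so 384 threads
+  // suggests Triton added 4 extra worker warps of 32 threads each.)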
+  EXPECT_EQ(result.thread_dims.x, 384);
+  EXPECT_EQ(result.thread_dims.y, 1);
+  EXPECT_EQ(result.thread_dims.z, 1);
+}
+
 }  // namespace
 }  // namespace xla::gpu