PR #32719: [XLA:GPU] Command buffer DynamicSliceFusionCmd supports CUDA graph loop unrolling

Imported from GitHub PR https://github.com/openxla/xla/pull/32719

📝 Summary of Changes
This PR enables the command buffer DynamicSliceFusionCmd to be recorded into an unrolled CUDA graph when it is enclosed in a WhileCmd.

🎯 Justification
This feature is required in order to fully record a WhileCmd into a command buffer as an unrolled CUDA graph.

🚀 Kind of Contribution
Please remove what does not apply:
 New Feature

🧪 Unit Tests:
xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
Copybara import of the project:

--
daa975804cbffcc3a6bc5b37e3494b51a2dbe2ca by Shawn Wang <shawnw@nvidia.com>:

DynamicSliceFusionCmd supports unrolling

Merging this change closes #32719

PiperOrigin-RevId: 822071751
This commit is contained in:
Shaogang Wang 2025-10-21 05:31:21 -07:00 committed by TensorFlower Gardener
parent 2d4dd83773
commit 8c169d147d
4 changed files with 74 additions and 5 deletions

View File

@ -218,7 +218,10 @@ xla_test(
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/tests:hlo_test_base",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",

View File

@ -47,8 +47,11 @@ limitations under the License.
#include "xla/service/platform_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
@ -73,6 +76,28 @@ MATCHER_P(ThunkKindIs, kind, "") {
return ExplainMatchResult(::testing::Eq(kind), arg->kind(), result_listener);
}
// Returns the StreamExecutor for device ordinal 0 of the platform whose name
// is the upper-cased canonical name of "gpu". Every intermediate result is
// unwrapped with `.value()` instead of being propagated to the caller.
se::StreamExecutor* GpuExecutor() {
  const auto platform_name = absl::AsciiStrToUpper(
      PlatformUtil::CanonicalPlatformName("gpu").value());
  auto* gpu_platform =
      se::PlatformManager::PlatformWithName(platform_name).value();
  return gpu_platform->ExecutorForDevice(0).value();
}
// Returns true only when the device reports a CUDA compute capability and
// both its driver version and its runtime version are at least 12.9.0.
bool IsAtLeastCuda12900(const se::StreamExecutor* stream_executor) {
  const auto& description = stream_executor->GetDeviceDescription();

  // Devices whose compute capability is not a CUDA one never qualify.
  const bool has_cuda_capability =
      std::get_if<se::CudaComputeCapability>(
          &description.gpu_compute_capability()) != nullptr;
  if (!has_cuda_capability) {
    return false;
  }

  // Driver is checked first; the runtime is only consulted if the driver
  // already satisfies the minimum.
  const stream_executor::SemanticVersion min_version(12, 9, 0);
  return description.driver_version() >= min_version &&
         description.runtime_version() >= min_version;
}
class DynamicSliceFusionTest : public HloTestBase {
public:
HloModuleConfig GetModuleConfigWithoutCommandBuffer() {
@ -105,6 +130,30 @@ class DynamicSliceFusionTest : public HloTestBase {
return config;
}
// Builds a module config, starting from the test's default debug options,
// that turns on dynamic slice fusion, adds the command buffer types listed
// below (including WHILE and DYNAMIC_SLICE_FUSION), and sets
// xla_gpu_command_buffer_unroll_loops to true.
HloModuleConfig GetModuleConfigWithCommandBufferUnrollLoops() {
  DebugOptions debug_options = GetDebugOptionsForTest();

  // GEMM / fusion related settings.
  debug_options.set_xla_gpu_enable_cublaslt(false);
  debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true);
  debug_options.set_xla_gpu_triton_gemm_any(false);
  debug_options.set_xla_gpu_cublas_fallback(true);

  // Command buffer settings: minimum graph size of 1 plus every command type
  // this test wants recorded.
  debug_options.set_xla_gpu_graph_min_graph_size(1);
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS);
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT);
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL);
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::WHILE);
  debug_options.add_xla_gpu_enable_command_buffer(
      DebugOptions::DYNAMIC_SLICE_FUSION);

  // The setting that distinguishes this config: request loop unrolling.
  debug_options.set_xla_gpu_command_buffer_unroll_loops(true);

  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return config;
}
HloModuleConfig GetModuleConfigWithDeterministicOps() {
DebugOptions debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_exclude_nondeterministic_ops(true);
@ -3454,10 +3503,6 @@ TEST_F(DynamicSliceFusionTest,
ROOT %while = while(%initial_tuple), condition=%Cond, body=%Body, backend_config={"known_trip_count":{"n":"11"}}
})";
// Run the same HLO with and without command buffer and compare results.
HloModuleConfig with_cmd_buffer = GetModuleConfigWithCommandBuffer();
HloModuleConfig without_cmd_buffer = GetModuleConfigWithoutCommandBuffer();
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> cmd_buffer_module,
ParseAndReturnVerifiedModule(hlo));
@ -3469,6 +3514,15 @@ TEST_F(DynamicSliceFusionTest,
RunAndCompareTwoModules(hlo, hlo, GetModuleConfigWithCommandBuffer(),
GetModuleConfigWithoutCommandBuffer(), error_spec,
/*run_hlo_passes=*/true));
se::StreamExecutor* stream_executor = GpuExecutor();
if (!IsAtLeastCuda12900(stream_executor)) {
GTEST_SKIP() << "While loop unrolling is not supported for CUDA < 12.9";
}
EXPECT_TRUE(RunAndCompareTwoModules(
hlo, hlo, GetModuleConfigWithCommandBufferUnrollLoops(),
GetModuleConfigWithoutCommandBuffer(), error_spec,
/*run_hlo_passes=*/true));
}
TEST_F(DynamicSliceFusionTest, MultipleOffsetsAsFunctionOfInductionVariable) {

View File

@ -1493,6 +1493,14 @@ absl::Status WhileCmd::Initialize(const Thunk::InitializeParams& params,
return absl::OkStatus();
}
// Prepares the two nested command sequences of the while loop: the condition
// commands first, then the body commands, passing the same `params` and
// `resource_requests` to each. A failure from either one is returned to the
// caller via TF_RETURN_IF_ERROR; otherwise OK is returned.
absl::Status WhileCmd::Prepare(
const Thunk::PrepareParams& params,
Thunk::ResourceRequestsInterface& resource_requests) {
// The condition commands are prepared before the body commands.
TF_RETURN_IF_ERROR(cond_commands_.Prepare(params, resource_requests));
TF_RETURN_IF_ERROR(body_commands_.Prepare(params, resource_requests));
return absl::OkStatus();
}
absl::StatusOr<const se::CommandBuffer::Command*> WhileCmd::Record(
const Thunk::ExecuteParams& execute_params,
const RecordParams& record_params, RecordAction record_action,

View File

@ -869,6 +869,10 @@ class WhileCmd : public CommandBufferCmd {
absl::Status Initialize(const Thunk::InitializeParams& params,
StateManager& state) override;
// Forwards preparation to the loop's nested condition and body commands,
// handing each the same `params` and `resource_requests`.
absl::Status Prepare(
const Thunk::PrepareParams& params,
Thunk::ResourceRequestsInterface& resource_requests) override;
absl::StatusOr<const se::CommandBuffer::Command*> Record(
const Thunk::ExecuteParams& execute_params,
const RecordParams& record_params, RecordAction record_action,
@ -1240,7 +1244,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd {
bool requires_initialization() override;
bool support_loop_unroll() override { return false; }
bool support_loop_unroll() override { return true; }
bool IsNestedCommandBuffer() const final { return true; }