mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
PR #32688: [XLA:GPU] Enable command buffer DynamicSliceCopyFusion command unrolling
Imported from GitHub PR https://github.com/openxla/xla/pull/32688 📝 Summary of Changes: This PR enables the command buffer DynamicSliceCopyFusion command to be recorded into an unrolled CUDA graph when it is nested inside a WhileCmd. 🎯 Justification: This feature is required in order to record a WhileCmd fully into an unrolled CUDA graph. 🚀 Kind of Contribution (please remove what does not apply): ✨ New Feature 🧪 Unit Tests: xla/backends/gpu/runtime/command_buffer_cmd_test.cc: CommandBufferCmdTest:DynamicSliceCopyFusionCmd Copybara import of the project: -- feb2902fca397360460f6b9788ac0f7482cb547c by Shawn Wang <shawnw@nvidia.com>: Enable command buffer DynamicSliceCopyFusion command unrolling Merging this change closes #32688 PiperOrigin-RevId: 822104580
This commit is contained in:
parent
c28d80ae66
commit
97c777acc4
|
|
@ -2730,8 +2730,12 @@ DynamicSliceCopyFusionCmd::Record(const Thunk::ExecuteParams& execute_params,
|
|||
[&](const se::CommandBuffer::Command* command) {
|
||||
int64_t iteration_index = 0;
|
||||
if (offsets_.depends_on_loop) {
|
||||
TF_ASSIGN_OR_RETURN(iteration_index,
|
||||
WhileThunk::CurrentLoopIteration());
|
||||
if (WhileThunk::RunningWhileThunkLoop()) {
|
||||
TF_ASSIGN_OR_RETURN(iteration_index,
|
||||
WhileThunk::CurrentLoopIteration());
|
||||
} else {
|
||||
iteration_index = record_params.unroll_iteration;
|
||||
}
|
||||
}
|
||||
int64_t src_offset = offsets_.src_offsets[iteration_index];
|
||||
int64_t dst_offset = offsets_.dst_offsets[iteration_index];
|
||||
|
|
|
|||
|
|
@ -1298,7 +1298,7 @@ class DynamicSliceCopyFusionCmd : public CommandBufferCmd {
|
|||
|
||||
// Returns offsets_.depends_on_loop: an update is forced exactly when the copy
// offsets are selected per loop iteration (Record indexes src/dst offsets by
// the iteration index in that case).
bool force_update() override { return offsets_.depends_on_loop; }
|
||||
|
||||
// Pre-change side of this hunk: the command reported that it could not be
// recorded as part of an unrolled loop.
bool support_loop_unroll() override { return false; }
|
||||
// Post-change side of this hunk: the command now reports that it can be
// recorded as part of an unrolled loop; Record then takes the iteration index
// from record_params.unroll_iteration when no while-thunk loop is running.
bool support_loop_unroll() override { return true; }
|
||||
|
||||
BufferUseVector buffers() const override;
|
||||
|
||||
|
|
|
|||
|
|
@ -61,6 +61,8 @@ static std::list<RunningLoop>& RunningLoops() {
|
|||
return loops;
|
||||
}
|
||||
|
||||
// Reports whether RunningLoops() currently holds at least one entry, i.e.
// whether a while-thunk loop is in progress.
bool WhileThunk::RunningWhileThunkLoop() {
  return !RunningLoops().empty();
}
|
||||
|
||||
absl::StatusOr<int64_t> WhileThunk::CurrentLoopIteration(int64_t depth) {
|
||||
if (depth >= RunningLoops().size()) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ class WhileThunk : public Thunk {
|
|||
//
|
||||
// Implementation relies on thread local storage; be careful when calling it from
|
||||
// code running on multiple threads.
|
||||
static bool RunningWhileThunkLoop();
|
||||
static absl::StatusOr<int64_t> CurrentLoopIteration(int64_t depth = 0);
|
||||
static absl::StatusOr<int64_t> CurrentLoopIteration(
|
||||
const HloInstruction* while_instr);
|
||||
|
|
|
|||
|
|
@ -576,6 +576,28 @@ TEST_P(CommandBufferTest, DynamicSliceCopyFusionCmd) {
|
|||
|
||||
EXPECT_TRUE(
|
||||
RunAndCompareNoHloPasses(std::move(module), ErrorSpec{1e-3, 2e-3}));
|
||||
|
||||
if (!IsAtLeastCuda12900(GpuExecutor())) {
|
||||
GTEST_SKIP() << "While loop unrolling is not supported for CUDA < 12.9";
|
||||
}
|
||||
|
||||
debug_options.add_xla_gpu_enable_command_buffer(
|
||||
DebugOptions::DYNAMIC_SLICE_COPY_FUSION);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::WHILE);
|
||||
debug_options.set_xla_gpu_command_buffer_unroll_loops(true);
|
||||
config.set_debug_options(debug_options);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto unrolled_module,
|
||||
ParseAndReturnVerifiedModule(hlo_text, config));
|
||||
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(unrolled_module),
|
||||
ErrorSpec{1e-3, 2e-3}));
|
||||
}
|
||||
|
||||
TEST_P(CommandBufferUnrollTest, WhileLoop) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user