Mirror of https://github.com/zebrajr/tensorflow.git (synced 2025-12-06 12:20:11 +01:00)
PR #32719: [XLA:GPU] Command buffer DynamicSliceFusionCmd supports CUDA graph loop unrolling
Imported from GitHub PR https://github.com/openxla/xla/pull/32719

📝 Summary of Changes: This PR enables the command buffer DynamicSliceFusion command to be recorded into an unrolled CUDA graph when it is surrounded by a WhileCmd.

🎯 Justification: This feature is required to fully record a WhileCmd into an unrolled CUDA graph as a command buffer.

🚀 Kind of Contribution: ✨ New Feature

🧪 Unit Tests: xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc

Copybara import of the project:

-- daa975804cbffcc3a6bc5b37e3494b51a2dbe2ca by Shawn Wang <shawnw@nvidia.com>:

DynamicSliceFusionCmd supports unrolling

Merging this change closes #32719

PiperOrigin-RevId: 822071751
parent 2d4dd83773 · commit 8c169d147d
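For orientation before the diff: the unrolled path is driven entirely by DebugOptions. The sketch below is a minimal, hedged summary of the options that matter for this feature; every field name is taken from the test changes further down, while the standalone setup (a bare DebugOptions and HloModuleConfig) is illustrative rather than copied from the PR.

    // Minimal sketch, assuming the DebugOptions fields used by the test below.
    // The full test-fixture configuration appears in
    // GetModuleConfigWithCommandBufferUnrollLoops() later in this diff.
    DebugOptions debug_options;
    debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true);
    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::WHILE);
    debug_options.add_xla_gpu_enable_command_buffer(
        DebugOptions::DYNAMIC_SLICE_FUSION);
    debug_options.set_xla_gpu_command_buffer_unroll_loops(true);
    HloModuleConfig config;
    config.set_debug_options(debug_options);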
@@ -218,7 +218,10 @@ xla_test(
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:platform_manager",
        "//xla/stream_executor:stream",
        "//xla/stream_executor:stream_executor_h",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:hlo_test_base",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
@@ -47,8 +47,11 @@ limitations under the License.
#include "xla/service/platform_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
@@ -73,6 +76,28 @@ MATCHER_P(ThunkKindIs, kind, "") {
  return ExplainMatchResult(::testing::Eq(kind), arg->kind(), result_listener);
}

se::StreamExecutor* GpuExecutor() {
  auto name =
      absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value());
  auto* platform = se::PlatformManager::PlatformWithName(name).value();
  return platform->ExecutorForDevice(0).value();
}

bool IsAtLeastCuda12900(const se::StreamExecutor* stream_executor) {
  const auto& device_description = stream_executor->GetDeviceDescription();
  const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(
      &device_description.gpu_compute_capability());
  if (cuda_cc != nullptr) {
    if (device_description.driver_version() >=
            stream_executor::SemanticVersion(12, 9, 0) &&
        device_description.runtime_version() >=
            stream_executor::SemanticVersion(12, 9, 0)) {
      return true;
    }
  }
  return false;
}

class DynamicSliceFusionTest : public HloTestBase {
 public:
  HloModuleConfig GetModuleConfigWithoutCommandBuffer() {
@@ -105,6 +130,30 @@ class DynamicSliceFusionTest : public HloTestBase {
    return config;
  }

  HloModuleConfig GetModuleConfigWithCommandBufferUnrollLoops() {
    DebugOptions debug_options = GetDebugOptionsForTest();
    debug_options.set_xla_gpu_enable_cublaslt(false);
    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
    debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true);
    debug_options.set_xla_gpu_triton_gemm_any(false);
    debug_options.set_xla_gpu_enable_cublaslt(false);
    debug_options.set_xla_gpu_cublas_fallback(true);
    debug_options.set_xla_gpu_graph_min_graph_size(1);
    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS);
    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT);
    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL);
    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::WHILE);
    debug_options.add_xla_gpu_enable_command_buffer(
        DebugOptions::DYNAMIC_SLICE_FUSION);
    debug_options.set_xla_gpu_command_buffer_unroll_loops(true);
    HloModuleConfig config;
    config.set_debug_options(debug_options);
    return config;
  }

  HloModuleConfig GetModuleConfigWithDeterministicOps() {
    DebugOptions debug_options = GetDebugOptionsForTest();
    debug_options.set_xla_gpu_exclude_nondeterministic_ops(true);
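A brief usage note on the new fixture method above: the config is attached to a module when the HLO text is parsed, so the same HLO can be compiled once with and once without unrolling and the outputs compared, which is what the updated test does below. A hedged sketch, assuming the two-argument ParseAndReturnVerifiedModule overload from HloTestBase and reusing the test's hlo string:

    // Illustrative only, not part of the PR: parse the test's HLO text with the
    // unroll-enabled config attached.
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<HloModule> unrolled_module,
        ParseAndReturnVerifiedModule(
            hlo, GetModuleConfigWithCommandBufferUnrollLoops()));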
@@ -3454,10 +3503,6 @@ TEST_F(DynamicSliceFusionTest,
    ROOT %while = while(%initial_tuple), condition=%Cond, body=%Body, backend_config={"known_trip_count":{"n":"11"}}
  })";

  // Run the same HLO with and without command buffer and compare results.
  HloModuleConfig with_cmd_buffer = GetModuleConfigWithCommandBuffer();
  HloModuleConfig without_cmd_buffer = GetModuleConfigWithoutCommandBuffer();

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> cmd_buffer_module,
                          ParseAndReturnVerifiedModule(hlo));
@@ -3469,6 +3514,15 @@ TEST_F(DynamicSliceFusionTest,
      RunAndCompareTwoModules(hlo, hlo, GetModuleConfigWithCommandBuffer(),
                              GetModuleConfigWithoutCommandBuffer(), error_spec,
                              /*run_hlo_passes=*/true));

  se::StreamExecutor* stream_executor = GpuExecutor();
  if (!IsAtLeastCuda12900(stream_executor)) {
    GTEST_SKIP() << "While loop unrolling is not supported for CUDA < 12.9";
  }
  EXPECT_TRUE(RunAndCompareTwoModules(
      hlo, hlo, GetModuleConfigWithCommandBufferUnrollLoops(),
      GetModuleConfigWithoutCommandBuffer(), error_spec,
      /*run_hlo_passes=*/true));
}

TEST_F(DynamicSliceFusionTest, MultipleOffsetsAsFunctionOfInductionVariable) {
@@ -1493,6 +1493,14 @@ absl::Status WhileCmd::Initialize(const Thunk::InitializeParams& params,
  return absl::OkStatus();
}

absl::Status WhileCmd::Prepare(
    const Thunk::PrepareParams& params,
    Thunk::ResourceRequestsInterface& resource_requests) {
  TF_RETURN_IF_ERROR(cond_commands_.Prepare(params, resource_requests));
  TF_RETURN_IF_ERROR(body_commands_.Prepare(params, resource_requests));
  return absl::OkStatus();
}

absl::StatusOr<const se::CommandBuffer::Command*> WhileCmd::Record(
    const Thunk::ExecuteParams& execute_params,
    const RecordParams& record_params, RecordAction record_action,
@@ -869,6 +869,10 @@ class WhileCmd : public CommandBufferCmd {
  absl::Status Initialize(const Thunk::InitializeParams& params,
                          StateManager& state) override;

  absl::Status Prepare(
      const Thunk::PrepareParams& params,
      Thunk::ResourceRequestsInterface& resource_requests) override;

  absl::StatusOr<const se::CommandBuffer::Command*> Record(
      const Thunk::ExecuteParams& execute_params,
      const RecordParams& record_params, RecordAction record_action,
@@ -1240,7 +1244,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd {

  bool requires_initialization() override;

-  bool support_loop_unroll() override { return false; }
+  bool support_loop_unroll() override { return true; }

  bool IsNestedCommandBuffer() const final { return true; }
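The one-line flip above is the heart of the PR: DynamicSliceFusionCmd now advertises that it can be replayed inside an unrolled loop. As a conceptual illustration of how such a capability flag can gate an unrolling decision (this is not the actual WhileCmd logic; CanUnrollBody and its span parameter are hypothetical, while CommandBufferCmd and support_loop_unroll() come from the diff above):

    // Conceptual sketch only: a while loop body is eligible for unrolling when
    // every command in it reports support_loop_unroll() == true. The real
    // decision lives in the command buffer implementation.
    bool CanUnrollBody(absl::Span<CommandBufferCmd* const> body_commands) {
      for (CommandBufferCmd* cmd : body_commands) {
        if (!cmd->support_loop_unroll()) return false;
      }
      return true;
    }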