Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-06 12:20:11 +01:00
[XLA:GPU] Add verbose kernel scheduling tracing for debugging

PiperOrigin-RevId: 818918076

parent 631a48b8da
commit 43f9e0789c
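Every hunk below applies the same pattern: a stack-allocated TraceMe whose lifetime bounds the traced region, with the event name built lazily in a lambda so the encoding cost is paid only while a profiler session is recording. A minimal sketch of the pattern, using the same tsl profiler API that the diff itself uses (DoWork is a hypothetical function, for illustration only):

#include "tsl/profiler/lib/traceme.h"
#include "tsl/profiler/lib/traceme_encode.h"

using tsl::profiler::TraceMe;
using tsl::profiler::TraceMeEncode;
using tsl::profiler::TraceMeLevel;

// Hypothetical function used only to illustrate the tracing pattern.
void DoWork() {
  // The lambda runs only when tracing is active at verbose level, so an
  // untraced execution pays almost nothing for the instrumentation.
  TraceMe trace([] { return TraceMeEncode("DoWork", {}); },
                /*level=*/TraceMeLevel::kVerbose);
  // ... traced work ...
}  // The event closes when `trace` goes out of scope.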
@@ -932,6 +932,8 @@ cc_library(
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
+        "@local_tsl//tsl/profiler/lib:traceme",
+        "@local_tsl//tsl/profiler/lib:traceme_encode",
     ],
 )
@@ -50,6 +50,12 @@ limitations under the License.
 #include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "xla/tsl/platform/statusor.h"
+#include "tsl/profiler/lib/traceme.h"
+#include "tsl/profiler/lib/traceme_encode.h"
+
+using tsl::profiler::TraceMe;
+using tsl::profiler::TraceMeEncode;
+using tsl::profiler::TraceMeLevel;
 
 namespace xla {
 namespace gpu {
@@ -223,43 +229,71 @@ static void PrintBufferContents(
 }
 
 absl::Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) {
+  TraceMe trace(
+      [] { return TraceMeEncode("KernelThunk::ExecuteOnStream", {}); },
+      /*level=*/TraceMeLevel::kVerbose);
+
   // Load the kernel.
   se::StreamExecutor* executor = params.stream->parent();
   se::Kernel* kernel = nullptr;
 
-  TF_ASSIGN_OR_RETURN(
-      se::Stream * stream,
-      GetStreamForExecution(Thunk::execution_stream_id(), params));
+  se::Stream* stream = nullptr;
+  {
+    TraceMe trace(
+        [] {
+          return TraceMeEncode(
+              "KernelThunk::ExecuteOnStream/GetStreamForExecution", {});
+        },
+        /*level=*/TraceMeLevel::kVerbose);
+    TF_ASSIGN_OR_RETURN(
+        stream, GetStreamForExecution(Thunk::execution_stream_id(), params));
+  }
 
   {
+    TraceMe trace(
+        [] { return TraceMeEncode("KernelThunk::ExecuteOnStream/mutex", {}); },
+        /*level=*/TraceMeLevel::kVerbose);
     absl::MutexLock lock(mutex_);
+    TraceMe trace_find(
+        [] {
+          return TraceMeEncode("KernelThunk::ExecuteOnStream/mutex/find", {});
+        },
+        /*level=*/TraceMeLevel::kVerbose);
     auto it = kernel_cache_.find(executor);
     CHECK(it != kernel_cache_.end())
         << "Initialize() not called for StreamExecutor " << executor;
     kernel = it->second.get();
   }
 
-  int device_ordinal = executor->device_ordinal();
-  VLOG(3) << "[" << device_ordinal << "] Launching " << kernel->name();
   absl::InlinedVector<se::KernelArgument, 4> kernel_args;
-  for (const auto& [idx, arg] : llvm::enumerate(args_)) {
-    se::DeviceMemoryBase buf = params.buffer_allocations->GetDeviceAddress(arg);
-    VLOG(3) << "[" << device_ordinal << "] Arg: alloc #" << arg.index()
-            << ", offset: " << arg.offset() << ": " << buf.opaque() << " ("
-            << buf.size() << "B)";
+  {
+    TraceMe trace(
+        [] {
+          return TraceMeEncode("KernelThunk::ExecuteOnStream/kernel_args", {});
+        },
+        /*level=*/TraceMeLevel::kVerbose);
+    int device_ordinal = executor->device_ordinal();
+    VLOG(3) << "[" << device_ordinal << "] Launching " << kernel->name();
+    for (const auto& [idx, arg] : llvm::enumerate(args_)) {
+      se::DeviceMemoryBase buf =
+          params.buffer_allocations->GetDeviceAddress(arg);
+      VLOG(3) << "[" << device_ordinal << "] Arg: alloc #" << arg.index()
+              << ", offset: " << arg.offset() << ": " << buf.opaque() << " ("
+              << buf.size() << "B)";
 
-    if (auto it = tma_metadata_.arg_index_to_tma_info.find(idx);
-        it != tma_metadata_.arg_index_to_tma_info.end()) {
-      // TMA descriptor argument.
-      const se::gpu::TmaDescriptor& tma_desc = it->second;
-      TF_ASSIGN_OR_RETURN(se::TensorMap tensor_map,
-                          executor->CreateTensorMap(tma_desc, buf.opaque()));
-      VLOG(3) << "[" << device_ordinal << "] Using TensorMap for arg #" << idx
-              << ": " << tma_desc.ToString();
-      kernel_args.push_back(std::move(tensor_map));
-    } else {
-      // Buffer argument.
-      kernel_args.push_back(buf);
+      if (auto it = tma_metadata_.arg_index_to_tma_info.find(idx);
+          it != tma_metadata_.arg_index_to_tma_info.end()) {
+        // TMA descriptor argument.
+        const se::gpu::TmaDescriptor& tma_desc = it->second;
+        TF_ASSIGN_OR_RETURN(se::TensorMap tensor_map,
+                            executor->CreateTensorMap(tma_desc, buf.opaque()));
+        VLOG(3) << "[" << device_ordinal << "] Using TensorMap for arg #"
+                << idx << ": " << tma_desc.ToString();
+        kernel_args.push_back(std::move(tensor_map));
+      } else {
+        // Buffer argument.
+        kernel_args.push_back(buf);
+      }
     }
   }
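The scoping above is the point of the change: stream lookup, mutex acquisition, cache lookup, and argument marshalling each get their own verbose event, so a profile shows where time inside ExecuteOnStream goes. The commit passes an empty annotation list to every TraceMeEncode call; the same API also accepts key/value pairs. A hedged variant that attaches metadata to the event (illustration only, not what the commit does; device_ordinal and kernel_name stand in for values available inside ExecuteOnStream):

// Illustrative variant: annotate the trace event with metadata.
// device_ordinal and kernel_name are placeholders for values in scope.
TraceMe trace(
    [&] {
      return TraceMeEncode("KernelThunk::ExecuteOnStream",
                           {{"device_ordinal", device_ordinal},
                            {"kernel", kernel_name}});
    },
    /*level=*/TraceMeLevel::kVerbose);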
third_party/xla/xla/service/gpu/BUILD (vendored, 2 lines changed)
@@ -2625,6 +2625,8 @@ cc_library(
         "@local_tsl//tsl/platform:ml_dtypes",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/profiler/lib:traceme",
+        "@local_tsl//tsl/profiler/lib:traceme_encode",
     ],
 )
@@ -67,6 +67,12 @@ limitations under the License.
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/ml_dtypes.h"
+#include "tsl/profiler/lib/traceme.h"
+#include "tsl/profiler/lib/traceme_encode.h"
+
+using tsl::profiler::TraceMe;
+using tsl::profiler::TraceMeEncode;
+using tsl::profiler::TraceMeLevel;
 
 namespace xla {
 namespace gpu {
@@ -405,9 +411,19 @@ absl::Status ExecuteKernelOnStream(
     se::Kernel& kernel, absl::Span<const se::KernelArgument> args,
     const LaunchDimensions& dims,
     const std::optional<se::ClusterDim>& cluster_dim, se::Stream* stream) {
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<se::KernelArgsPackedArrayBase> kernel_args,
-      se::PackKernelArgs(args, kernel.metadata()));
+  TraceMe trace([] { return TraceMeEncode("ExecuteKernelOnStream", {}); },
+                /*level=*/TraceMeLevel::kVerbose);
+
+  std::unique_ptr<se::KernelArgsPackedArrayBase> kernel_args;
+  {
+    TraceMe trace(
+        [] {
+          return TraceMeEncode("ExecuteKernelOnStream/PackKernelArgs", {});
+        },
+        /*level=*/TraceMeLevel::kVerbose);
+    TF_ASSIGN_OR_RETURN(kernel_args,
+                        se::PackKernelArgs(args, kernel.metadata()));
+  }
 
   return kernel.Launch(dims.thread_counts_per_block(), dims.block_counts(),
                        cluster_dim, stream, *kernel_args);
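One detail worth noting in this hunk: kernel_args is hoisted out of the new block because the old TF_ASSIGN_OR_RETURN declared it in place, and a declaration inside the block would be destroyed before kernel.Launch runs. The general shape of tracing one sub-step while keeping its result alive, sketched with hypothetical names (Result, ComputeStep, Use):

// Sketch: trace one sub-step, keep its result alive afterwards.
// Result, ComputeStep(), and Use() are hypothetical.
absl::Status Caller() {
  Result result;  // Declared outside so it outlives the traced scope.
  {
    TraceMe trace([] { return TraceMeEncode("Caller/ComputeStep", {}); },
                  /*level=*/TraceMeLevel::kVerbose);
    TF_ASSIGN_OR_RETURN(result, ComputeStep());  // assigns, returns on error
  }  // The trace event ends here; `result` is still usable below.
  Use(result);
  return absl::OkStatus();
}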
@@ -699,6 +699,8 @@ cc_library(
         "@com_google_absl//absl/strings:str_format",
         "@local_config_cuda//cuda:cuda_headers",
         "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/profiler/lib:traceme",
+        "@local_tsl//tsl/profiler/lib:traceme_encode",
     ],
 )
@@ -1380,6 +1382,8 @@ cc_library(
         "@com_google_absl//absl/synchronization",
         "@local_config_cuda//cuda:cuda_headers",
         "@local_tsl//tsl/profiler/lib:nvtx_utils",
+        "@local_tsl//tsl/profiler/lib:traceme",
+        "@local_tsl//tsl/profiler/lib:traceme_encode",
     ],
 )
@@ -34,6 +34,12 @@ limitations under the License.
 #include "xla/stream_executor/stream.h"
 #include "xla/tsl/platform/errors.h"
 #include "xla/tsl/platform/statusor.h"
+#include "tsl/profiler/lib/traceme.h"
+#include "tsl/profiler/lib/traceme_encode.h"
+
+using tsl::profiler::TraceMe;
+using tsl::profiler::TraceMeEncode;
+using tsl::profiler::TraceMeLevel;
 
 namespace stream_executor {
 namespace gpu {
@@ -84,11 +90,17 @@ absl::Status CudaKernel::Launch(const ThreadDim& thread_dims,
                                 const BlockDim& block_dims,
                                 const std::optional<ClusterDim>& cluster_dims,
                                 Stream* stream, const KernelArgs& args) {
+  TraceMe trace([] { return TraceMeEncode("CudaKernel::Launch", {}); },
+                /*level=*/TraceMeLevel::kVerbose);
+
   CUfunction function = gpu_function();
 
   // Launch kernels with packed arguments.
   auto launch = [this, stream, &cluster_dims, &thread_dims, &block_dims,
                  function](const KernelArgsPackedArrayBase& packed) {
+    TraceMe trace([] { return TraceMeEncode("CudaKernel::Launch/launch", {}); },
+                  /*level=*/TraceMeLevel::kVerbose);
+
     int32_t expected_number_of_arguments =
         Arity() + (packed.number_of_shared_bytes() > 0);
@@ -48,6 +48,12 @@ limitations under the License.
 #include "xla/tsl/platform/errors.h"
 #include "xla/tsl/platform/statusor.h"
 #include "tsl/profiler/lib/nvtx_utils.h"
+#include "tsl/profiler/lib/traceme.h"
+#include "tsl/profiler/lib/traceme_encode.h"
+
+using tsl::profiler::TraceMe;
+using tsl::profiler::TraceMeEncode;
+using tsl::profiler::TraceMeLevel;
 
 namespace stream_executor {
 namespace gpu {
@@ -363,6 +369,9 @@ absl::Status LaunchCudaKernel(
     unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y,
     unsigned int block_dim_z, unsigned int shared_mem_bytes, CUstream stream,
     void** kernel_params, void** extra) {
+  TraceMe trace([] { return TraceMeEncode("LaunchCudaKernel", {}); },
+                /*level=*/TraceMeLevel::kVerbose);
+
   std::unique_ptr<ActivateContext> activation = executor->Activate();
   VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x
           << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
@@ -383,15 +392,20 @@ absl::Status LaunchCudaKernel(
         cuFuncSetCacheConfig(function, CU_FUNC_CACHE_PREFER_SHARED)));
   }
 
-  return cuda::ToStatus(
-      cuLaunchKernel(function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x,
-                     block_dim_y, block_dim_z, shared_mem_bytes, stream,
-                     kernel_params, extra),
-      absl::StrCat("Failed to launch CUDA kernel: ", kernel_name,
-                   "; block dims: ", block_dim_x, "x", block_dim_y, "x",
-                   block_dim_z, "; grid dims: ", grid_dim_x, "x", grid_dim_y,
-                   "x", grid_dim_z,
-                   "; shared memory size: ", shared_mem_bytes));
+  {
+    TraceMe trace(
+        [&] { return TraceMeEncode("LaunchCudaKernel/cuLaunchKernel", {}); },
+        /*level=*/TraceMeLevel::kVerbose);
+    return cuda::ToStatus(
+        cuLaunchKernel(function, grid_dim_x, grid_dim_y, grid_dim_z,
+                       block_dim_x, block_dim_y, block_dim_z, shared_mem_bytes,
+                       stream, kernel_params, extra),
+        absl::StrCat("Failed to launch CUDA kernel: ", kernel_name,
+                     "; block dims: ", block_dim_x, "x", block_dim_y, "x",
+                     block_dim_z, "; grid dims: ", grid_dim_x, "x", grid_dim_y,
+                     "x", grid_dim_z,
+                     "; shared memory size: ", shared_mem_bytes));
+  }
 }
 
 absl::Status LaunchCudaKernel(
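cuda::ToStatus converts the CUresult returned by the driver call into an absl::Status carrying the formatted failure message. A freestanding analogue built only on the public driver API; the helper name and shape here are assumptions, not XLA's actual implementation:

#include <cuda.h>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

// Hypothetical stand-in for cuda::ToStatus: map a CUresult to absl::Status,
// appending the driver's own name and description for the error code.
absl::Status ToStatus(CUresult result, const std::string& detail) {
  if (result == CUDA_SUCCESS) return absl::OkStatus();
  const char* name = nullptr;
  const char* explanation = nullptr;
  cuGetErrorName(result, &name);           // e.g. "CUDA_ERROR_INVALID_VALUE"
  cuGetErrorString(result, &explanation);  // human-readable description
  return absl::InternalError(
      absl::StrCat(detail, ": ", name ? name : "UNKNOWN", ": ",
                   explanation ? explanation : "unknown error"));
}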
@@ -402,6 +416,8 @@ absl::Status LaunchCudaKernel(
     unsigned int block_dim_y, unsigned int block_dim_z,
     unsigned int shared_mem_bytes, CUstream stream, void** kernel_params,
     void** extra) {
+  TraceMe trace([] { return TraceMeEncode("LaunchCudaKernel", {}); },
+                /*level=*/TraceMeLevel::kVerbose);
   std::unique_ptr<ActivateContext> activation = executor->Activate();
   VLOG(2) << "launching kernel: " << kernel_name << "; cdx: " << cluster_dim_x
           << " cdy: " << cluster_dim_y << " cdz: " << cluster_dim_z
@@ -444,14 +460,19 @@ absl::Status LaunchCudaKernel(
   launch_config.attrs = &cluster_dims;
   launch_config.numAttrs = 1;
 
-  return cuda::ToStatus(
-      cuLaunchKernelEx(&launch_config, function, kernel_params, extra),
-      absl::StrCat("Failed to launch CUDA kernel: ", kernel_name,
-                   "; cluster dims: ", cluster_dim_x, "x", cluster_dim_y, "x",
-                   cluster_dim_z, "; block dims: ", block_dim_x, "x",
-                   block_dim_y, "x", block_dim_z, "; grid dims: ", grid_dim_x,
-                   "x", grid_dim_y, "x", grid_dim_z,
-                   "; shared memory size: ", shared_mem_bytes));
+  {
+    TraceMe trace(
+        [] { return TraceMeEncode("LaunchCudaKernel/cuLaunchKernelEx", {}); },
+        /*level=*/TraceMeLevel::kVerbose);
+    return cuda::ToStatus(
+        cuLaunchKernelEx(&launch_config, function, kernel_params, extra),
+        absl::StrCat("Failed to launch CUDA kernel: ", kernel_name,
+                     "; cluster dims: ", cluster_dim_x, "x", cluster_dim_y, "x",
+                     cluster_dim_z, "; block dims: ", block_dim_x, "x",
+                     block_dim_y, "x", block_dim_z, "; grid dims: ", grid_dim_x,
+                     "x", grid_dim_y, "x", grid_dim_z,
+                     "; shared memory size: ", shared_mem_bytes));
+  }
 }
 
 }  // namespace
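For context on the launch_config lines kept as context in this hunk: cuLaunchKernelEx takes a CUlaunchConfig, and thread-block clusters are requested through a CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION attribute (CUDA 12 driver API). A sketch of that setup with placeholder cluster dimensions; the grid/block variables mirror the ones in the hunk:

// Request a 2x1x1 thread-block cluster (placeholder values).
CUlaunchAttribute cluster_attr{};
cluster_attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
cluster_attr.value.clusterDim.x = 2;
cluster_attr.value.clusterDim.y = 1;
cluster_attr.value.clusterDim.z = 1;

CUlaunchConfig launch_config{};
launch_config.gridDimX = grid_dim_x;  // grid/block dims as in the hunk above
launch_config.gridDimY = grid_dim_y;
launch_config.gridDimZ = grid_dim_z;
launch_config.blockDimX = block_dim_x;
launch_config.blockDimY = block_dim_y;
launch_config.blockDimZ = block_dim_z;
launch_config.sharedMemBytes = shared_mem_bytes;
launch_config.hStream = stream;
launch_config.attrs = &cluster_attr;
launch_config.numAttrs = 1;

CUresult result =
    cuLaunchKernelEx(&launch_config, function, kernel_params, extra);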
@@ -460,6 +481,9 @@ absl::Status CudaStream::LaunchKernel(
     const ThreadDim& thread_dims, const BlockDim& block_dims,
     const std::optional<ClusterDim>& cluster_dims, void* function,
     absl::string_view name, void** args, int64_t shmem_bytes) {
+  TraceMe trace([] { return TraceMeEncode("CudaStream::LaunchKernel", {}); },
+                /*level=*/TraceMeLevel::kVerbose);
+
   if (cluster_dims.has_value()) {
     return LaunchCudaKernel(executor_, name, static_cast<CUfunction>(function),
                             cluster_dims->x, cluster_dims->y, cluster_dims->z,
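Since every new event is tagged TraceMeLevel::kVerbose, none of this appears under default profiling; the host tracer has to run at verbose level. A hedged sketch of collecting such a trace. The ProfilerSession and ProfileOptions names and the host_tracer_level value are assumptions about the tsl profiler API; verify against the headers before relying on them:

#include <memory>

#include "absl/status/status.h"
#include "tsl/profiler/lib/profiler_session.h"
#include "tsl/profiler/protobuf/xplane.pb.h"

// Assumed API: raise host_tracer_level so kVerbose TraceMe events are kept.
absl::Status CollectVerboseTrace() {
  tensorflow::ProfileOptions options = tsl::ProfilerSession::DefaultOptions();
  options.set_host_tracer_level(3);  // 3 == verbose; the default is lower.
  std::unique_ptr<tsl::ProfilerSession> session =
      tsl::ProfilerSession::Create(options);
  // ... run the GPU workload to be inspected ...
  tensorflow::profiler::XSpace xspace;
  return session->CollectData(&xspace);  // events land in the XSpace proto
}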