#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDAGuard.h>

// Two warnings in Cutlass included header files
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")

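// The CUTLASS-based kernel below targets SM90 and requires CUDA >= 12.0; it
// is not built for ROCm or Windows. When BUILD_ASYNC_MM_KERNEL is not
// defined, async_input_mm_out raises an error at runtime instead.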
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \
    CUDA_VERSION >= 12000
#define BUILD_ASYNC_MM_KERNEL
#endif

#if defined(BUILD_ASYNC_MM_KERNEL)

#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/half.h>
#include <cutlass/numeric_types.h>
#include <cutlass/trace.h>
#include <cutlass/util/host_tensor.h>
#include <cute/tensor.hpp>

#include <cutlass/version.h>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/epilogue/collective/collective_builder.hpp>

#include <cute/atom/mma_atom.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>

#include <torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh>

C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()

namespace {

using namespace cute;

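// async_input_mm_impl runs a bf16 SM90 warp-specialized CUTLASS GEMM over `a`
// and `b` with `out` as the output. Instead of the stock persistent tile
// scheduler it plugs in PersistentAsyncInputScheduler, which (roughly
// speaking) only schedules output tiles covering a row chunk of `a` once that
// chunk has been signaled as ready via `a_chunk_signals`, starting from the
// chunk selected by `a_chunk_pivot`. See persistent_async_input_scheduler.cuh
// for the authoritative semantics.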
template <typename LayoutB, typename TileShape_MNK, typename ClusterShape_MNK>
at::Tensor async_input_mm_impl(
    at::Tensor a,
    at::Tensor b,
    at::Tensor a_chunk_signals,
    int64_t a_chunk_pivot,
    at::Tensor out) {
  c10::cuda::CUDAGuard guard(a.device());

  using ElementA = cutlass::bfloat16_t;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
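  // 128 / cutlass::sizeof_bits<bf16> = 128 / 16 = 8 elements, i.e. 16-byte
  // (128-bit) aligned accesses.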

  using ElementB = cutlass::bfloat16_t;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementC = cutlass::bfloat16_t;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  using ElementAccumulator = float;

  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90,
          cutlass::arch::OpClassTensorOp,
          TileShape_MNK,
          ClusterShape_MNK,
          cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator,
          ElementAccumulator,
          ElementC,
          LayoutC,
          AlignmentC,
          ElementC,
          LayoutC,
          AlignmentC,
          EpilogueSchedule>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90,
          cutlass::arch::OpClassTensorOp,
          ElementA,
          LayoutA,
          AlignmentA,
          ElementB,
          LayoutB,
          AlignmentB,
          ElementAccumulator,
          TileShape_MNK,
          ClusterShape_MNK,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int>,
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;

  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && out.dim() == 2);
  TORCH_CHECK(a.is_contiguous() && out.is_contiguous());

  if constexpr (std::is_same_v<LayoutB, cutlass::layout::RowMajor>) {
    TORCH_CHECK(b.is_contiguous());
  } else {
    TORCH_CHECK(b.stride(1) == b.size(0));
    TORCH_CHECK(b.stride(0) == 1);
  }
  TORCH_CHECK_EQ(a.scalar_type(), at::kBFloat16);
  TORCH_CHECK_EQ(b.scalar_type(), at::kBFloat16);
  TORCH_CHECK_EQ(out.scalar_type(), at::kBFloat16);

  int M = static_cast<int>(a.sizes()[0]);
  int N = static_cast<int>(b.sizes()[1]);
  int K = static_cast<int>(a.sizes()[1]);
  TORCH_CHECK_EQ(b.sizes()[0], K);
  TORCH_CHECK_EQ(out.sizes()[0], M);
  TORCH_CHECK_EQ(out.sizes()[1], N);

  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
  auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});

  Gemm gemm;

  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K},
      {
          reinterpret_cast<ElementA*>(a.data_ptr<at::BFloat16>()),
          stride_A,
          reinterpret_cast<ElementB*>(b.data_ptr<at::BFloat16>()),
          stride_B,
      },
      {{1, 1},
       reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),
       stride_C,
       reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),
       stride_C},
  };

  TORCH_CHECK(
      a_chunk_signals.sizes().size() == 1,
      "async_input_mm: `a_chunk_signals` must be a 1D tensor.");
  size_t num_chunks_M = a_chunk_signals.numel();

  TORCH_CHECK(
      M % num_chunks_M == 0,
      "async_input_mm: `a.shape(0)` must be an integer multiple of `a_chunk_signals.numel()`");
  size_t chunk_size_M = M / num_chunks_M;
  size_t tile_size_M = cute::get<0>(TileShape_MNK{});

  TORCH_CHECK(chunk_size_M % tile_size_M == 0);

  // We want to swizzle within a chunk
  arguments.scheduler.max_swizzle_size = chunk_size_M / tile_size_M;
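  // chunk_size_M / tile_size_M is the number of output tiles along M that
  // cover one row chunk of `a`, so the swizzle pattern stays within a chunk.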

  // PersistentAsyncInputScheduler currently only supports rastering along N
  using RasterOrderOptions = typename cutlass::gemm::kernel::detail::
      PersistentTileSchedulerSm90::RasterOrderOptions;
  arguments.scheduler.raster_order = RasterOrderOptions::AlongN;

  // Convert the pivot from a number of chunks to a number of M tile indices
  arguments.scheduler.tile_idx_pivot_m =
      a_chunk_pivot * (chunk_size_M / tile_size_M);
  arguments.scheduler.tiles_per_chunk_m = chunk_size_M / tile_size_M;
  arguments.scheduler.chunk_signals = a_chunk_signals.data_ptr<uint32_t>();
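  // E.g. with M = 1024, 8 chunks, and a TileShape_MNK whose M mode is 128 (as
  // instantiated in async_input_mm_out below): chunk_size_M = 128,
  // tiles_per_chunk_m = 1, and a_chunk_pivot = 3 maps to tile_idx_pivot_m = 3.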

  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  TORCH_CHECK(gemm.can_implement(arguments) == cutlass::Status::kSuccess);
  TORCH_CHECK(
      gemm.initialize(arguments, workspace.get()) == cutlass::Status::kSuccess);
  TORCH_CHECK(
      gemm(at::cuda::getCurrentCUDAStream()) == cutlass::Status::kSuccess);

  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}

} // namespace

#endif

namespace c10d::cuda::detail {

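// Runs the trailing lambda with `LayoutB` bound to cutlass::layout::RowMajor
// or ColumnMajor, depending on `is_b_row_major`.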
#define DISPATCH_LAYOUT_B(is_b_row_major, ...)        \
  if (is_b_row_major) {                               \
    using LayoutB = cutlass::layout::RowMajor;        \
    __VA_ARGS__();                                    \
  } else {                                            \
    using LayoutB = cutlass::layout::ColumnMajor;     \
    __VA_ARGS__();                                    \
  }

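// Expected inputs (enforced by the checks below):
//   a:               [M, K] bf16, row-major
//   b:               [K, N] bf16, row-major or column-major
//   out:             [M, N] bf16, row-major
//   a_chunk_signals: 1D uint32 tensor with one signal value per row chunk of
//                    `a`; M must be divisible by its number of elements.
// `a_chunk_pivot` is specified in units of chunks (see the pivot conversion
// in async_input_mm_impl).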
at::Tensor async_input_mm_out(
    at::Tensor a,
    at::Tensor b,
    at::Tensor a_chunk_signals,
    int64_t a_chunk_pivot,
    at::Tensor out) {
  TORCH_CHECK(
      a.dim() == 2 && b.dim() == 2 && out.dim() == 2,
      "async_input_mm: `a`, `b` and `out` must be matrices");
  TORCH_CHECK(
      a.is_contiguous() && out.is_contiguous(),
      "async_input_mm: `a` and `out` must be in row-major layout");

  if (!b.is_contiguous()) {
    TORCH_CHECK(b.stride(1) == b.size(0));
    TORCH_CHECK(b.stride(0) == 1);
  }
  TORCH_CHECK_EQ(a.scalar_type(), at::kBFloat16);
  TORCH_CHECK_EQ(b.scalar_type(), at::kBFloat16);
  TORCH_CHECK_EQ(out.scalar_type(), at::kBFloat16);

  int64_t M = a.sizes()[0];
  int64_t N = b.sizes()[1];
  int64_t K = a.sizes()[1];
  TORCH_CHECK_EQ(b.sizes()[0], K);
  TORCH_CHECK_EQ(out.sizes()[0], M);
  TORCH_CHECK_EQ(out.sizes()[1], N);

#if defined(BUILD_ASYNC_MM_KERNEL)
  const bool is_b_row_major = b.is_contiguous();
  DISPATCH_LAYOUT_B(is_b_row_major, [&]() {
    // TODO(yifu): tuning
    async_input_mm_impl<LayoutB, Shape<_128, _256, _64>, Shape<_2, _1, _1>>(
        a, b, a_chunk_signals, a_chunk_pivot, out);
  });
#else
  TORCH_CHECK(
      false, "async_input_mm is not currently supported on your device");
#endif
  return out;
}

at::Tensor async_input_mm(
    at::Tensor a,
    at::Tensor b,
    at::Tensor a_chunk_signals,
    int64_t a_chunk_pivot) {
  TORCH_CHECK(
      a.dim() == 2 && b.dim() == 2,
      "async_input_mm: `a` and `b` must be matrices");

  int64_t M = a.sizes()[0];
  int64_t N = b.sizes()[1];
  auto out = a.new_empty({M, N});
  return async_input_mm_out(a, b, a_chunk_signals, a_chunk_pivot, out);
}
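// Illustrative call pattern (hypothetical shapes; a producer is assumed to
// fill `a` one row chunk at a time and then set the matching entry of
// `signals`):
//
//   at::Tensor a = ...;        // [4096, 1024] bf16, filled chunk by chunk
//   at::Tensor b = ...;        // [1024, 2048] bf16
//   at::Tensor signals = ...;  // [8] uint32, one per 512-row chunk of `a`
//   at::Tensor out =
//       async_input_mm(a, b, signals, /*a_chunk_pivot=*/0);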

} // namespace c10d::cuda::detail