Revert "Refine alignment check along dynamic dimension for grouped MMs (#155466)"
This reverts commit 830a335a7d.
Reverted https://github.com/pytorch/pytorch/pull/155466 on behalf of https://github.com/atalman because it breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/155466#issuecomment-2988285117))
This commit is contained in:
parent fec8af8b98
commit 0b62465b99
@@ -36,7 +36,6 @@
 #include <ATen/ops/copy_native.h>
 #include <ATen/ops/dot_native.h>
 #include <ATen/ops/empty.h>
-#include <ATen/ops/empty_strided.h>
 #include <ATen/ops/gelu.h>
 #include <ATen/ops/max.h>
 #include <ATen/ops/mm_native.h>
@@ -1482,49 +1481,29 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 }
 
 namespace {
-at::Tensor create_grouped_gemm_output_tensor(const Tensor& mat_a,
+c10::SmallVector<int64_t, 3> compute_grouped_gemm_output_size(const Tensor& mat_a,
     const Tensor& mat_b,
-    const std::optional<at::Tensor>& offs,
-    std::optional<c10::ScalarType> out_dtype
+    const std::optional<at::Tensor>& offs
 ) {
-  c10::SmallVector<int64_t, 3> out_size;
   const bool a_is_2d = mat_a.dim() == 2;
   const bool b_is_2d = mat_b.dim() == 2;
   if (a_is_2d) {
     if (b_is_2d) {
-      out_size = {offs->size(0), mat_a.size(0), mat_b.size(1)};
+      return {offs->size(0), mat_a.size(0), mat_b.size(1)};
     } else {
       TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match");
-      out_size = {mat_a.size(0), mat_b.size(-1)};
+      return {mat_a.size(0), mat_b.size(-1)};
     }
   } else {
     if (b_is_2d) {
       // this case is not actually encountered for MoE gemms
       TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match");
-      out_size = {mat_a.size(1), mat_b.size(1)};
+      return {mat_a.size(1), mat_b.size(1)};
     } else { // regular bmm
       TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match");
-      out_size = {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
+      return {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
     }
   }
-
-  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
-  TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
-
-  // For TMA transfers, strides of output tensor have to be either
-  // 1, or aligned to 16 bytes.
-  const auto last_dim = out_size.size() - 1;
-  const auto alignment = 16 / c10::elementSize(out_dtype_);
-  const int64_t size_padded = (out_size[last_dim] + alignment - 1) / alignment * alignment;
-  std::vector<int64_t> out_stride;
-  if (a_is_2d != b_is_2d) {
-    out_stride = {size_padded, 1};
-  } else {
-    out_stride = {out_size[1] * size_padded, size_padded, 1};
-  }
-  auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
-
-  return out;
 }
 
 bool check_valid_strides_and_return_transposed(const Tensor& mat) {
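The helper removed above allocated the grouped-mm output with a padded last-dimension stride so that every row starts on a 16-byte boundary (the TMA requirement its comments describe), whereas the restored helper only computes the output size. A rough Python sketch of that padding arithmetic, with a hypothetical function name rather than the ATen implementation:

```python
import torch

def padded_grouped_mm_output(out_size, out_dtype=torch.bfloat16, device="cpu"):
    # 16-byte alignment expressed in elements: 8 for bf16 (2 bytes per element).
    alignment = 16 // out_dtype.itemsize
    # Round the innermost dimension up to the next multiple of the alignment.
    size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
    if len(out_size) == 2:
        out_stride = [size_padded, 1]
    else:  # 3D output: [groups, rows, cols]
        out_stride = [out_size[1] * size_padded, size_padded, 1]
    return torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=device)

# A (4, 16, 30) bf16 output gets its rows padded to 32 elements:
# padded_grouped_mm_output([4, 16, 30]).stride() == (512, 32, 1)
```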
@@ -1540,7 +1519,7 @@ namespace {
     TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes");
     return false;
   } else {
-    TORCH_CHECK(false, "Invalid strides/sizes, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes");
+    TORCH_CHECK(false, "Tensor should have a contiguous dimension and not be self-overlapping, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes");
   }
 }
 
@@ -1648,7 +1627,11 @@ bool use_fast_accum) {
   check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
   check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
 
-  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
+  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
+  TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
+
+  const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
+  Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
 
   at::cuda::detail::f8f8bf16_grouped_mm(
       mat_a,
@@ -1684,7 +1667,6 @@ std::optional<c10::ScalarType> out_dtype) {
   TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
   const bool a_is_2d = mat_a.dim() == 2;
   const bool b_is_2d = mat_b.dim() == 2;
-
   // check that the strides are valid, the fn will throw an error if not
   check_valid_strides_and_return_transposed(mat_a);
   check_valid_strides_and_return_transposed(mat_b);
@@ -1694,10 +1676,12 @@ std::optional<c10::ScalarType> out_dtype) {
     TORCH_CHECK(offs->dim() == 1, "offs has to be 1D");
     TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
   }
+  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
+  TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high output type is supported for grouped gemm");
   TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
 
-  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
-
+  const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
+  Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
   at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
   return out;
 #else
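The reworded TORCH_CHECK above fires when a matrix has neither a contiguous last dimension nor a contiguous second-to-last dimension with a large-enough outer stride. A simplified sketch of that validity test for a 2D tensor (a hypothetical helper mirroring the shape of the check, not the exact ATen logic):

```python
import torch

def has_valid_grouped_mm_strides(mat: torch.Tensor) -> bool:
    # True if either the last dim is contiguous (row-major-like) or the
    # second-to-last dim is contiguous (column-major-like) without self-overlap.
    s0, s1 = mat.stride(-2), mat.stride(-1)
    d0, d1 = mat.size(-2), mat.size(-1)
    row_major_like = s1 == 1 and s0 >= max(1, d1)
    col_major_like = s0 == 1 and s1 >= max(1, d0)
    return row_major_like or col_major_like

# has_valid_grouped_mm_strides(torch.empty(16, 64))          -> True  (row-major)
# has_valid_grouped_mm_strides(torch.empty(64, 16).t())      -> True  (column-major)
# has_valid_grouped_mm_strides(torch.empty(16, 64)[:, ::2])  -> False (neither dim contiguous)
```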
@@ -47,42 +47,10 @@ __global__ void prepare_grouped_gemm_data(
   if (offs != nullptr) {
     int32_t start = tid == 0 ? 0 : offs[tid - 1];
     delta = offs[tid] - start;
-    if (K < 0) {
-      // CUTLASS cannot handle delta=0 here.
-      CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n");
-    } else {
-      CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
-    }
-
-    // TMA transfers require global memory tensor addresses to be
-    // aligned to 16 bytes.
-    if (tid < blockDim.x - 1) {
-      // Check this requirement for input tensors, in case group
-      // addresses are increased along the dynamic dimension.
-      if ((K < 0 && a_row_major) || // 2D/2D: check along K dimension
-          (M < 0 && !a_row_major)) { // 3D/2D: check along N dimension
-        int align = 128 / cutlass::sizeof_bits<DtypeA>::value;
-        CUDA_KERNEL_ASSERT(
-            delta % align == 0 &&
-            "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n");
-      }
-      if ((K < 0 && !b_row_major) || // 2D/2D: check along K dimension
-          (N < 0 && b_row_major)) { // 3D/2D: check along N dimension
-        int align = 128 / cutlass::sizeof_bits<DtypeB>::value;
-        CUDA_KERNEL_ASSERT(
-            delta % align == 0 &&
-            "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n");
-      }
-
-      // Check the same requirement for output tensor (that is always
-      // contiguous, and in row-major layout).
-      if (N < 0) {
-        int align = 128 / cutlass::sizeof_bits<DtypeOutput>::value;
-        CUDA_KERNEL_ASSERT(
-            delta % align == 0 &&
-            "expected output tensor dynamic dimension byte size to be non-negative multiple of 16\n");
-      }
-    }
+    int align = 16 / sizeof(DtypeA);
+    CUDA_KERNEL_ASSERT(
+        delta >=0 && delta % align == 0 &&
+        "expected dynamic dimension byte size to be non-negative multiple of 16 \n");
   }
   int64_t lda, ldb, ldoutput;
   if (M < 0) {
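The restored device-side assert requires every group's extent along the dynamic dimension to be a non-negative multiple of 16 bytes of the input dtype; the reverted refinement had split this into a positivity check plus alignment checks only where group base addresses actually advance along that dimension. A host-side sketch of the restored condition (hypothetical helper, assuming `offs` holds prefix sums of group sizes):

```python
import torch

def check_group_offsets(offs: torch.Tensor, dtype=torch.bfloat16) -> None:
    # Alignment in elements: 16 bytes / element size, e.g. 8 for bf16, 16 for fp8.
    align = 16 // dtype.itemsize
    prev = 0
    for off in offs.tolist():
        delta = off - prev  # this group's size along the dynamic dimension
        assert delta >= 0 and delta % align == 0, (
            "expected dynamic dimension byte size to be non-negative multiple of 16"
        )
        prev = off

# check_group_offsets(torch.tensor([64, 128, 192, 256], dtype=torch.int32))            # passes
# check_group_offsets(torch.tensor([8, 16], dtype=torch.int32), torch.float8_e4m3fn)   # fails: 8 % 16 != 0
```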
@@ -113,6 +81,7 @@ __global__ void prepare_grouped_gemm_data(
   } else if (K < 0) {
     // A, B is 2d, output is 3d
     K = delta;
+    CUDA_KERNEL_ASSERT(delta > 0 && "can't handle K=0");
     lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
     ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1];
     ldoutput = tensor_StrideOutput[1];
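The added assert restores the guard against zero-size groups in the 2D/2D case, where the per-group K comes from the offset deltas and the removed comment notes CUTLASS cannot handle a zero delta. A zero-size group is exactly what the tests create by repeating an offset, for example:

```python
import torch

offs = torch.tensor([64, 64, 128, 192], dtype=torch.int32)
deltas = torch.diff(offs, prepend=torch.tensor([0], dtype=torch.int32))
print(deltas.tolist())  # [64, 0, 64, 64] -> the second group would run with K == 0
```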
@@ -315,7 +315,7 @@ class TestMatmulCuda(TestCase):
     def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, use_torch_compile):
         device = "cuda"
         dtype = torch.bfloat16
-        m, n, k, n_groups = 16, 32, 64, 4
+        m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16
         if a_row_major:
             a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups]
         else:
@@ -382,9 +382,6 @@ class TestMatmulCuda(TestCase):
         b_contig = b if b_row_major else b.transpose(-2, -1)
         self.assertTrue(b_contig.is_contiguous() is not strided)
         for check_zero_size in (False, True):
-            if check_zero_size and n_groups <= 1:
-                continue
-
             a.grad = None
             b.grad = None
             offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
@@ -487,9 +484,6 @@ class TestMatmulCuda(TestCase):
         b_contig = b if b_row_major else b.transpose(-2, -1)
         self.assertTrue(b_contig.is_contiguous() is not strided)
         for check_zero_size in (False, True):
-            if check_zero_size and n_groups <= 1:
-                continue
-
             offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
@@ -1651,27 +1645,17 @@ class TestFP8Matmul(TestCase):
         for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
             out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
                                        out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
-            self.assertEqual(out, out_ref, atol=5e-2, rtol=5e-4)
+            self.assertEqual(out, out_ref, atol=8e-2, rtol=8e-4)
 
-    # Testing only _scaled_grouped_mm() with multiple shapes, as
-    # _scaled_mm() already has more combinations of parameters than
-    # _scaled_grouped_mm(), for supporing more than one inputs layout
-    # combinations.
-
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
-    @parametrize(
-        "n_groups, m, n, k",
-        [(2, 1, 16, 16),
-         (4, 16, 16, 16)],
-        name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
-    )
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
     @parametrize("use_torch_compile", [False, True])
-    def test_scaled_grouped_gemm_2d_2d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
+    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
+        m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16
         a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
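In the 2D/2D tests above, `offs` holds prefix sums that split the shared K dimension into groups. A reference result can be built by slicing each group out and doing a per-group matmul, similar in spirit to how the test compares against per-group `torch._scaled_mm` calls (plain float tensors here; a simplified sketch, not the test code):

```python
import torch

def grouped_mm_2d_2d_reference(a, b, offs):
    # a: (m, total_k), b: (n, total_k); offs: prefix sums of group sizes along K.
    outs, start = [], 0
    for end in offs.tolist():
        outs.append(a[:, start:end] @ b[:, start:end].t())  # (m, n) for this group
        start = end
    return torch.stack(outs)  # (n_groups, m, n), the grouped-mm output shape

m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(m, k * n_groups)
b = torch.randn(n, k * n_groups)
offs = torch.arange(k, k * n_groups + 1, k, dtype=torch.int32)
ref = grouped_mm_2d_2d_reference(a, b, offs)  # shape (4, 16, 32)
```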
@@ -1701,26 +1685,18 @@ class TestFP8Matmul(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
-    @parametrize(
-        "n_groups, m, n, k",
-        [(2, 1, 16, 16),
-         (4, 16, 16, 16)],
-        name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
-    )
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
     @parametrize("use_torch_compile", [False, True])
-    def test_scaled_grouped_gemm_2d_3d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
+    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
         self.assertTrue(b.is_contiguous() is not strided)
         for check_zero_size in (True, False):
-            if check_zero_size and n_groups <= 1:
-                continue
-
             offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
@@ -1751,18 +1727,13 @@ class TestFP8Matmul(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
-    @parametrize(
-        "n_groups, m, n, k",
-        [(2, 1, 16, 16),
-         (4, 16, 16, 16)],
-        name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
-    )
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
     @parametrize("use_torch_compile", [False, True])
-    def test_scaled_grouped_gemm_3d_3d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
+    def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
@@ -1786,18 +1757,13 @@ class TestFP8Matmul(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
-    @parametrize(
-        "n_groups, m, n, k",
-        [(2, 1, 16, 16),
-         (4, 16, 16, 16)],
-        name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
-    )
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
     @parametrize("use_torch_compile", [False, True])
-    def test_scaled_grouped_gemm_3d_2d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
+    def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
         self.assertTrue(a.is_contiguous() is not strided)
@@ -1805,9 +1771,6 @@ class TestFP8Matmul(TestCase):
         scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
         scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32)
         for check_zero_size in (True, False):
-            if check_zero_size and n_groups <= 1:
-                continue
-
             offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
@@ -120,6 +120,7 @@ def early_config_prune(g, m, configs, named_args):
     return pruned_configs
 
 
+# Copied from fbgemm grouped_gemm.py
 triton_grouped_mm_source = r"""
 {%- if SCALED %}
 {%- if A_IS_2D or B_IS_2D %}
@@ -670,7 +671,7 @@ def _tuned_grouped_mm_common(
     )
 
 
-@register_lowering(aten._grouped_mm.default, type_promotion_kind=None)
+@register_lowering(aten._grouped_mm, type_promotion_kind=None)
 def tuned_grouped_mm(
     mat_a: TensorBox,
     mat_b: TensorBox,
@@ -682,7 +683,7 @@ def tuned_grouped_mm(
     """Auto-tuning for _grouped_mm() operator."""
 
     return _tuned_grouped_mm_common(
-        "aten._grouped_mm.default",
+        "aten._grouped_mm",
         "grouped_mm",
         aten__grouped_mm,
         triton_grouped_mm_template,
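The inductor lowering is re-registered on the aten._grouped_mm overload packet rather than on its .default overload; both resolve to the same operator, they are just different registration keys. For reference (assuming a build where aten::_grouped_mm is present):

```python
import torch

packet = torch.ops.aten._grouped_mm            # OpOverloadPacket: all overloads of the op
overload = torch.ops.aten._grouped_mm.default  # OpOverload: one specific schema
print(type(packet).__name__, type(overload).__name__)  # OpOverloadPacket OpOverload
```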
@@ -7467,39 +7467,28 @@ def sigmoid(self: Tensor) -> Tensor:
     return torch.empty_like(self, dtype=result_dtype)
 
 
-def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype):
+def _compute_grouped_mm_output_size(mat1, mat2, offs):
     mat1_is_2d = mat1.dim() == 2
     mat2_is_2d = mat2.dim() == 2
     if mat1_is_2d:
         if mat2_is_2d:
-            out_size = [offs.size(0), mat1.size(0), mat2.size(1)]
+            return offs.size(0), mat1.size(0), mat2.size(1)
         else:
             torch._check(
                 offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
             )
-            out_size = [mat1.size(0), mat2.size(-1)]
+            return mat1.size(0), mat2.size(-1)
     else:
         if mat2_is_2d:
            torch._check(
                 offs.size(0) == mat1.size(0), "matrix batch sizes have to match"
             )
-            out_size = [mat1.size(1), mat2.size(1)]
+            return mat1.size(1), mat2.size(1)
        else:
            # regular bmm
            torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match")
-            out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)]
+            return mat1.size(0), mat1.size(1), mat2.size(-1)
-
-    out_dtype = out_dtype or mat1.dtype
-
-    alignment = 16 // out_dtype.itemsize
-    size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
-    if mat1_is_2d == mat2_is_2d:
-        out_stride = [out_size[1] * size_padded, size_padded, 1]
-    else:
-        out_stride = [size_padded, 1]
-    out = torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=mat1.device)
-    return out
 
 
 def _meta_grouped_mm_common(
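The restored helper returns only a size; which size depends on whether each operand is 2D or 3D. A plain-Python mirror of that dispatch on shape tuples (hypothetical function, same formulas as the code above):

```python
def grouped_mm_output_size(mat1_shape, mat2_shape, num_groups):
    mat1_is_2d = len(mat1_shape) == 2
    mat2_is_2d = len(mat2_shape) == 2
    if mat1_is_2d and mat2_is_2d:
        return (num_groups, mat1_shape[0], mat2_shape[1])
    if mat1_is_2d:
        return (mat1_shape[0], mat2_shape[-1])
    if mat2_is_2d:
        return (mat1_shape[1], mat2_shape[1])
    return (mat1_shape[0], mat1_shape[1], mat2_shape[-1])

g, m, n, k = 4, 16, 32, 64
print(grouped_mm_output_size((m, k * g), (k * g, n), g))  # 2D x 2D -> (4, 16, 32)
print(grouped_mm_output_size((g, m, k), (g, k, n), g))    # 3D x 3D -> (4, 16, 32)
```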
@@ -7542,6 +7531,15 @@ def _meta_grouped_mm_common(
     mat_a_is_2d = mat_a.dim() == 2
     mat_b_is_2d = mat_b.dim() == 2
 
+    torch._check(
+        mat_a.shape[-1] % 16 == 0,
+        lambda: f"Expected mat_a.shape[-1] to be divisible by 16, but got mat_a.shape[-1]={mat_a.shape[1]}",
+    )
+    torch._check(
+        mat_b.shape[-2] % 16 == 0 and mat_b.shape[-1] % 16 == 0,
+        lambda: f"Expected mat_b.shape[-2] and mat_b.shape[-1] to be both divisble by 16 but got {mat_b.shape[-2]} and {mat_b.shape[-1]}",  # noqa: B950
+    )
+
     if scaled:
 
         def is_row_major(mat):
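The checks restored above use torch._check, which raises RuntimeError when its condition is false; with the sizes used in the tests (k = 64, etc.) they pass, while a 40-wide inner dimension would not:

```python
import torch

torch._check(64 % 16 == 0, lambda: "fine, 64 is divisible by 16")  # passes silently
try:
    torch._check(40 % 16 == 0, lambda: "Expected shape[-1] to be divisible by 16")
except RuntimeError as err:
    print(err)  # Expected shape[-1] to be divisible by 16
```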
@@ -7563,7 +7561,7 @@ def _meta_grouped_mm_common(
 
         def check_valid_strides(mat_name, mat):
             end_dim = mat.dim() - 1
-            alignment = 16 // mat.element_size()
+            alignment = 16 / mat.element_size()
             mat_stride = mat.stride()
             if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max(
                 1, mat.shape[end_dim - 1]
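Note the direction of this change: the reverted refinement used floor division, while the restored line uses true division, so `alignment` on this meta path becomes a float (8.0 for bf16). The downstream modulo checks behave the same either way:

```python
import torch

print(16 / torch.bfloat16.itemsize)              # 8.0 (restored: true division)
print(16 // torch.bfloat16.itemsize)             # 8   (reverted refinement: floor division)
print(64 % (16 / torch.bfloat16.itemsize) == 0)  # True, alignment checks still pass
```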
@@ -7582,7 +7580,7 @@ def _meta_grouped_mm_common(
             else:
                 torch._check(
                     False,
-                    lambda: f"Invalid strides/sizes, got {mat_stride} for strides and {mat.shape} for sizes.",  # noqa: B950
+                    lambda: f"Expected {mat_name} to have a contiguous dimension and not be mat_a-overlapping, got {mat_stride} for strides and {mat.shape} for sizes.",  # noqa: B950
                 )
 
         check_valid_strides("mat_a", mat_a)
@@ -7667,7 +7665,9 @@ def _meta_grouped_mm_common(
             lambda: "If output dtype provided, it must be torch.bfloat16.",
         )
 
-    return _create_grouped_mm_output_tensor(mat_a, mat_b, offs, out_dtype)
+    out_size = _compute_grouped_mm_output_size(mat_a, mat_b, offs)
+    out_dtype = out_dtype or mat_a.dtype
+    return torch.empty(out_size, dtype=out_dtype, device=mat_a.device)
 
 
 @register_meta(aten._grouped_mm)
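The restored meta kernel reports the result of _grouped_mm as a contiguous meta tensor of the computed size, which is all that shape propagation (for example under torch.compile) needs. A minimal illustration of what such a meta output looks like:

```python
import torch

out = torch.empty((4, 16, 32), dtype=torch.bfloat16, device="meta")
print(out.shape, out.dtype, out.device)  # torch.Size([4, 16, 32]) torch.bfloat16 meta
print(out.is_contiguous())               # True: no stride padding after the revert
```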