[nnc] Update block and thread extents in cuda_codegen to use int64_t (#71428)

Summary:
The block and thread extent calculations in `cuda_codegen` should use `int64_t` instead of `int`. The updated test, `test_dynamic_shapes`, fails without this change.
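For illustration only, here is a minimal standalone sketch (not PyTorch code, and not the exact code path exercised by `test_dynamic_shapes`) of one way 32-bit extents break down: once a dynamic dimension exceeds `INT_MAX`, narrowing it to `int` wraps, so any launch extent derived from it is meaningless. The names `N`, `block_size`, and `blocks32` are made up for this example.

```cpp
// Illustration only: a dynamic dimension larger than INT_MAX cannot be
// carried in an int without wrapping, so a block extent computed from it
// is wrong.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t N = (1LL << 31) + 7;  // dynamic size just past INT_MAX
  const int64_t block_size = 512;

  const int64_t blocks64 = (N + block_size - 1) / block_size;  // correct extent: 4194305
  const int n32 = static_cast<int>(N);      // narrows: wraps on two's-complement targets
  const int blocks32 = (n32 + 511) / 512;   // computed from the wrapped value

  std::cout << "int64_t extent: " << blocks64 << "\n";  // 4194305
  std::cout << "int extent:     " << blocks32 << "\n";  // negative / meaningless
  return 0;
}
```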

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71428

Reviewed By: samdow

Differential Revision: D33640374

Pulled By: navahgar

fbshipit-source-id: 64c340ad2a9a1fa1fe066cf1c5dfc3b546b7be6d
(cherry picked from commit 6ea546ce11)
Author: Raghavan Raman (2022-01-19 15:14:41 -08:00), committed by PyTorch MergeBot
Parent: 2dbbb1a921
Commit: 70c9146c40
3 changed files with 16 additions and 17 deletions

test/cpp/tensorexpr/test_cuda.cpp

@@ -160,7 +160,7 @@ TEST(Cuda, TestVectorAdd01_CUDA) {
   testCudaTestVectorAdd01_impl<int64_t>();
 }
-static void testCudaTestVectorAdd02_impl(int N, int block_size) {
+static void testCudaTestVectorAdd02_impl(int64_t N, int64_t block_size) {
   BufHandle a_buf("a", {N}, kFloat);
   BufHandle b_buf("b", {N}, kFloat);
   Tensor c = Compute(
@@ -378,8 +378,8 @@ TEST(Cuda, TestRand01_CUDA) {
 }
 TEST(Cuda, DynamicShapeSplit_CUDA) {
-  constexpr int N = 4096;
-  VarHandle n("n", kInt);
+  constexpr int64_t N = 4096;
+  VarHandle n("n", kLong);
   BufHandle a("a", {n}, kFloat);
   Tensor b = Compute(
       "b", {{n, "n"}}, [&](const VarHandle& i) { return a.load(i) * 2.0f; });
@@ -1645,9 +1645,9 @@ TEST(Cuda, MaskMultiDim_CUDA) {
 // In this case both stores must be masked against the extent of the other loop,
 // incase it is larger.
 TEST(Cuda, MaskMultiDimSymbolic_CUDA) {
-  VarHandle OUTER_SIZE("OUTER_SIZE", kInt);
-  VarHandle A_SIZE("A_SIZE", kInt);
-  VarHandle B_SIZE("B_SIZE", kInt);
+  VarHandle OUTER_SIZE("OUTER_SIZE", kLong);
+  VarHandle A_SIZE("A_SIZE", kLong);
+  VarHandle B_SIZE("B_SIZE", kLong);
   BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat);
   BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat);
   Tensor c = Compute(
@@ -1682,10 +1682,10 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) {
   const std::string& verification_pattern =
       R"IR(
 # CHECK: if (threadIdx.x<A_SIZE
-# CHECK: C[A_SIZE * blockIdx.x + threadIdx.x] =
+# CHECK: C[A_SIZE * int64_t(blockIdx.x) + int64_t(threadIdx.x)] =
 # CHECK: __syncthreads();
 # CHECK: if (threadIdx.x<B_SIZE
-# CHECK: D[B_SIZE * blockIdx.x + threadIdx.x] =)IR";
+# CHECK: D[B_SIZE * int64_t(blockIdx.x) + int64_t(threadIdx.x)] =)IR";
   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
@@ -1695,9 +1695,9 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) {
   ASSERT_TRUE(exprEquals(
       threadExtents[0], alloc<Max>(A_SIZE.node(), B_SIZE.node(), true)));
-  int OUTER_EXTENT = 10;
-  int A_EXTENT = 100;
-  int B_EXTENT = 50;
+  int64_t OUTER_EXTENT = 10;
+  int64_t A_EXTENT = 100;
+  int64_t B_EXTENT = 50;
   PaddedBuffer<float> a_v(OUTER_EXTENT, A_EXTENT);
   PaddedBuffer<float> b_v(OUTER_EXTENT, B_EXTENT);

test/test_jit_fuser_te.py

@@ -2058,8 +2058,7 @@ class TestTEFuser(JitTestCase):
         funcs = [foo, fi, fum]
         with inline_fusion_groups():
-            # TODO: cuda ir eval error
-            for device in ['cpu']:
+            for device in self.devices:
                 I = partial(torch.randint, 0, 100, device=device)
                 R = partial(torch.randn, device=device)

torch/csrc/jit/tensorexpr/cuda_codegen.cpp

@@ -1114,21 +1114,21 @@ void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
   // module.
   for (size_t i = 0; i < gpu_block_extents.size(); i++) {
     if (gpu_block_extents[i]->isConstant()) {
-      gpu_block_extents_v[i] = immediateAs<int>(gpu_block_extents[i]);
+      gpu_block_extents_v[i] = immediateAs<int64_t>(gpu_block_extents[i]);
       continue;
     }
     ExprEval<SimpleIREvaluator> eval(
         ExprHandle(gpu_block_extents[i]), buffer_args);
-    gpu_block_extents_v[i] = eval.value<int>(raw_args);
+    gpu_block_extents_v[i] = eval.value<int64_t>(raw_args);
   }
   for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
     if (gpu_thread_extents[i]->isConstant()) {
-      gpu_thread_extents_v[i] = immediateAs<int>(gpu_thread_extents[i]);
+      gpu_thread_extents_v[i] = immediateAs<int64_t>(gpu_thread_extents[i]);
       continue;
     }
     ExprEval<SimpleIREvaluator> eval(
         ExprHandle(gpu_thread_extents[i]), buffer_args);
-    gpu_thread_extents_v[i] = eval.value<int>(raw_args);
+    gpu_thread_extents_v[i] = eval.value<int64_t>(raw_args);
   }
   // Skip launching the kernel if there are no elements to process.
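
The loops above follow a simple pattern: an extent that is already a constant immediate is read out directly, and a symbolic extent is evaluated against the runtime arguments; with this patch both paths yield `int64_t`. Below is a simplified, hedged sketch of that pattern only; the `Extent` struct and `resolveExtents` helper are invented for illustration and are not part of the PyTorch API.

```cpp
// Simplified sketch (hypothetical types, not PyTorch code): resolve each
// launch extent to an int64_t, using the constant when one is available
// and evaluating the symbolic expression against the runtime args otherwise.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

struct Extent {
  std::optional<int64_t> constant;  // set when the extent is a literal
  std::function<int64_t(const std::vector<void*>&)> eval;  // used otherwise
};

std::vector<int64_t> resolveExtents(
    const std::vector<Extent>& extents,
    const std::vector<void*>& raw_args) {
  std::vector<int64_t> out(extents.size());
  for (std::size_t i = 0; i < extents.size(); ++i) {
    out[i] = extents[i].constant ? *extents[i].constant
                                 : extents[i].eval(raw_args);
  }
  return out;
}
```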