From 70c9146c40497130a57f3e15573d2be02cedacea Mon Sep 17 00:00:00 2001 From: Raghavan Raman Date: Wed, 19 Jan 2022 15:14:41 -0800 Subject: [PATCH] [nnc] Update block and thread extents in cuda_codegen to use int64_t (#71428) Summary: The block and thread extent calculations in `cuda_codegen` should be using `int64_t` instead of `int`. The updated test, `test_dynamic_shapes`, fails without this change. Pull Request resolved: https://github.com/pytorch/pytorch/pull/71428 Reviewed By: samdow Differential Revision: D33640374 Pulled By: navahgar fbshipit-source-id: 64c340ad2a9a1fa1fe066cf1c5dfc3b546b7be6d (cherry picked from commit 6ea546ce116fc05d9d7e225bc29f7fe86be439de) --- test/cpp/tensorexpr/test_cuda.cpp | 22 +++++++++++----------- test/test_jit_fuser_te.py | 3 +-- torch/csrc/jit/tensorexpr/cuda_codegen.cpp | 8 ++++---- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp index bc267d9158a..feca646a657 100644 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -160,7 +160,7 @@ TEST(Cuda, TestVectorAdd01_CUDA) { testCudaTestVectorAdd01_impl(); } -static void testCudaTestVectorAdd02_impl(int N, int block_size) { +static void testCudaTestVectorAdd02_impl(int64_t N, int64_t block_size) { BufHandle a_buf("a", {N}, kFloat); BufHandle b_buf("b", {N}, kFloat); Tensor c = Compute( @@ -378,8 +378,8 @@ TEST(Cuda, TestRand01_CUDA) { } TEST(Cuda, DynamicShapeSplit_CUDA) { - constexpr int N = 4096; - VarHandle n("n", kInt); + constexpr int64_t N = 4096; + VarHandle n("n", kLong); BufHandle a("a", {n}, kFloat); Tensor b = Compute( "b", {{n, "n"}}, [&](const VarHandle& i) { return a.load(i) * 2.0f; }); @@ -1645,9 +1645,9 @@ TEST(Cuda, MaskMultiDim_CUDA) { // In this case both stores must be masked against the extent of the other loop, // incase it is larger. 
TEST(Cuda, MaskMultiDimSymbolic_CUDA) { - VarHandle OUTER_SIZE("OUTER_SIZE", kInt); - VarHandle A_SIZE("A_SIZE", kInt); - VarHandle B_SIZE("B_SIZE", kInt); + VarHandle OUTER_SIZE("OUTER_SIZE", kLong); + VarHandle A_SIZE("A_SIZE", kLong); + VarHandle B_SIZE("B_SIZE", kLong); BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); Tensor c = Compute( @@ -1682,10 +1682,10 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) { const std::string& verification_pattern = R"IR( # CHECK: if (threadIdx.x<A_SIZE # CHECK: if (threadIdx.x<B_SIZE)IR"; ASSERT_TRUE(exprEquals( threadExtents[0], alloc<Max>(A_SIZE.node(), B_SIZE.node(), true))); - int OUTER_EXTENT = 10; - int A_EXTENT = 100; - int B_EXTENT = 50; + int64_t OUTER_EXTENT = 10; + int64_t A_EXTENT = 100; + int64_t B_EXTENT = 50; PaddedBuffer<float> a_v(OUTER_EXTENT, A_EXTENT); PaddedBuffer<float> b_v(OUTER_EXTENT, B_EXTENT); diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 5ebe09d99f4..cb6cfff036b 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -2058,8 +2058,7 @@ class TestTEFuser(JitTestCase): funcs = [foo, fi, fum] with inline_fusion_groups(): - # TODO: cuda ir eval error - for device in ['cpu']: + for device in self.devices: I = partial(torch.randint, 0, 100, device=device) R = partial(torch.randn, device=device) diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index a0e4cfd290b..72fe01f63df 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -1114,21 +1114,21 @@ void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) { // module.
for (size_t i = 0; i < gpu_block_extents.size(); i++) { if (gpu_block_extents[i]->isConstant()) { - gpu_block_extents_v[i] = immediateAs<int>(gpu_block_extents[i]); + gpu_block_extents_v[i] = immediateAs<int64_t>(gpu_block_extents[i]); continue; } ExprEval<SimpleIREvaluator> eval( ExprHandle(gpu_block_extents[i]), buffer_args); - gpu_block_extents_v[i] = eval.value<int>(raw_args); + gpu_block_extents_v[i] = eval.value<int64_t>(raw_args); } for (size_t i = 0; i < gpu_thread_extents.size(); i++) { if (gpu_thread_extents[i]->isConstant()) { - gpu_thread_extents_v[i] = immediateAs<int>(gpu_thread_extents[i]); + gpu_thread_extents_v[i] = immediateAs<int64_t>(gpu_thread_extents[i]); continue; } ExprEval<SimpleIREvaluator> eval( ExprHandle(gpu_thread_extents[i]), buffer_args); - gpu_thread_extents_v[i] = eval.value<int>(raw_args); + gpu_thread_extents_v[i] = eval.value<int64_t>(raw_args); } // Skip launching the kernel if there are no elements to process.