From cecfc7dc53d7b3bf8a2014bc3ec87af54bc104b3 Mon Sep 17 00:00:00 2001
From: Eddie Yan
Date: Wed, 7 May 2025 22:01:18 +0000
Subject: [PATCH] [CUDA][cuDNN] Fix handling of `CPU` side input and target length tensors in `CTCLoss` (#152745)

https://github.com/pytorch/pytorch/pull/128271 migrated to the cuDNN V8 CTCLoss API, which expects the input and target length tensors to be on `CUDA` rather than `CPU`, without adding the logic to account for the edge case of them being on `CPU`.

See also #152421.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152745
Approved by: https://github.com/Skylion007
---
 aten/src/ATen/native/cudnn/LossCTC.cpp | 24 ++++++++++++++++++++----
 test/test_nn.py                        | 32 +++++++++++++++++++++++++++++++-
 2 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/aten/src/ATen/native/cudnn/LossCTC.cpp b/aten/src/ATen/native/cudnn/LossCTC.cpp
index 915fbed0f06..5cd295f17c6 100644
--- a/aten/src/ATen/native/cudnn/LossCTC.cpp
+++ b/aten/src/ATen/native/cudnn/LossCTC.cpp
@@ -151,6 +151,13 @@ bool _use_cudnn_ctc_loss_tensor(
       }
     }
   } else {
+    if (target_lengths.device().type() != at::kCUDA ||
+        input_lengths.device().type() != at::kCUDA) {
+      TORCH_CHECK(
+          false,
+          "CTCLoss cannot be graph captured with CPU length tensors. "
+          "Move CPU length tensors to GPU memory to enable graph capture.")
+    }
     at::_assert_async(at::lt(input_lengths.max(), 256));
     at::_assert_async(at::le(target_lengths, input_lengths).all());
   }
@@ -253,9 +260,18 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
     bool deterministic,
     bool zero_infinity) {
   Tensor targets_t_ = targets_t;
+  Tensor input_lengths_ = input_lengths;
+  Tensor target_lengths_ = target_lengths;
   if (targets_t.device().type() == at::kCPU) {
     targets_t_ = targets_t.to(Device(at::kCUDA));
   }
+  if (input_lengths.device().type() == at::kCPU) {
+    input_lengths_ = input_lengths.to(Device(at::kCUDA));
+  }
+  if (target_lengths.device().type() == at::kCPU) {
+    target_lengths_ = target_lengths.to(Device(at::kCUDA));
+  }
+
   const CheckedFrom c = "cudnn_ctc_loss";
   const TensorArg log_probs{log_probs_t, "log_probs", 1};
   const TensorArg targets{targets_t_, "targets", 2};
@@ -268,9 +284,9 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
   checkBackend(c, {*targets}, Backend::CUDA);
   const auto batch_size = log_probs->size(1);
   int64_t input_lengths_size =
-      input_lengths.sizes().size() ? input_lengths.size(0) : 1;
+      input_lengths_.sizes().size() ? input_lengths_.size(0) : 1;
   int64_t target_lengths_size =
-      target_lengths.sizes().size() ? target_lengths.size(0) : 1;
+      target_lengths_.sizes().size() ? target_lengths_.size(0) : 1;
   TORCH_CHECK(
       input_lengths_size == batch_size,
       "input_lengths needs to have size to match batch_size");
@@ -319,8 +335,8 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
       log_probs_desc.desc(),
       log_probs_t.data_ptr(),
       targets_t_.data_ptr(),
-      target_lengths.data_ptr(),
-      input_lengths.data_ptr(),
+      target_lengths_.data_ptr(),
+      input_lengths_.data_ptr(),
       costs.data_ptr(),
       grad_desc.desc(),
       grad.data_ptr(),
diff --git a/test/test_nn.py b/test/test_nn.py
index 784d641bd4e..4a940cc2f32 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -11523,7 +11523,7 @@ class TestNNDeviceType(NNTestCase):
 
     @onlyCUDA
     @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
-    def test_ctc_loss_cudnn_tensor(self, device):
+    def test_ctc_loss_cudnn_tensor_cuda(self):
         batch_size = 16
         input_length = 30
         num_labels = 101
@@ -11549,6 +11549,36 @@ class TestNNDeviceType(NNTestCase):
         grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
         self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
 
+    @onlyCUDA
+    @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
+    def test_ctc_loss_cudnn_tensor_cpu_length_cuda(self):
+        # batch size
+        N = 50
+        # audio length
+        T = 100
+        # text dimension
+        C = 80
+        # max text length
+        S = 10
+
+        prob_device = torch.device("cuda")
+        other_device = torch.device("cpu")
+        other_dtype = torch.int32
+
+        log_probs = torch.randn(T, N, C).log_softmax(2).to(prob_device)
+
+        input_lengths = torch.full((N,), T, dtype=other_dtype).to(other_device)
+        target_lengths = torch.randint(low=1, high=S, size=(N,), dtype=other_dtype).to(other_device)
+        targets = torch.randint(low=0, high=C, size=(sum(target_lengths),), dtype=other_dtype).to(other_device)
+
+        ctc_loss = torch.nn.functional.ctc_loss(
+            log_probs=log_probs,
+            targets=targets,
+            input_lengths=input_lengths,
+            target_lengths=target_lengths,
+            reduction="sum",
+        )
+
     @expectedFailureMPS
     def test_ctc_loss_error(self, device):
         log_probs = torch.rand(0, 0, 4, device=device)
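
Note (not part of the upstream patch): a minimal sketch of the user-facing case the new test exercises -- `log_probs` on CUDA while the int32 length tensors stay on CPU, which is the combination the cuDNN V8 path previously mishandled. The shapes, dtypes, and variable names below merely mirror the test above and are illustrative only; the snippet assumes a CUDA-enabled PyTorch build.

    import torch

    T, N, C, S = 100, 50, 80, 10  # input length, batch size, classes, max target length

    # Probabilities live on the GPU; the length tensors deliberately stay on the CPU.
    log_probs = torch.randn(T, N, C, device="cuda").log_softmax(2)
    input_lengths = torch.full((N,), T, dtype=torch.int32)         # CPU tensor
    target_lengths = torch.randint(1, S, (N,), dtype=torch.int32)  # CPU tensor
    targets = torch.randint(0, C, (int(target_lengths.sum()),), dtype=torch.int32)

    # With the patch, the cuDNN backend moves the CPU length tensors to CUDA
    # internally (see the _cudnn_ctc_loss_tensor changes above).
    loss = torch.nn.functional.ctc_loss(
        log_probs, targets, input_lengths, target_lengths, reduction="sum"
    )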