From fb26b843906bbad5e28d1edccf298c74b8e00492 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 19 Aug 2024 11:36:47 -0700 Subject: [PATCH] Update fused kernels and call _safe_softmax from SDPA (#133882) # UPDATE: This is take 3 of https://github.com/pytorch/pytorch/pull/131863 which was landed via co dev but not applying correclty # Summary Changes the stance of SDPA on what to do for fully masked out rows ## Current Behavior Several PyTorch users have expressed frustration over this issue: - https://github.com/pytorch/pytorch/issues/41508 - https://github.com/pytorch/pytorch/issues/103749 - https://github.com/pytorch/pytorch/issues/103963 These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here: https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617 Can be paraphrased as follows: When passing in fully masked out rows, attention becomes ambiguous. We have two main options: 1. Uniformly attend to all values: ```python scores[masked_out_rows] = 1 / len(row) out[masked_out_rows] = 1 / len(row) * value ``` 2. Decide that attention between no queries (masked) and no keys (masked) is meaningless: ```python output[fully_masked_rows] = NaN ``` We went with option 2. Partially because it was easier to implement, but also people argued that users can slice the output to remove the NaNs: ``` Python >fill_value = -float("inf") >row0 = torch.randn(4) >row1 = torch.tensor([(fill_value for _ in range(4)]) >matrix = torch.stack([row0, row1]).requires_grad_(True) >out = torch.softmax(matrix, 1) >out = out[0] >print(out) tensor([0.5377, 0.2729, 0.0692, 0.1201]) ``` Cool, problem solved. But what happends when you call backwards.. ```Python >out.backward(torch.ones_like(out)) >print(matrix.grad) tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08], [ nan, nan, nan, nan]]) ``` Those pesky NaNs are back! ## Why do we see NaNs today? The core of the problem revolves around using softmax function in sdpa: ```python > row = torch.tensor([(-float("inf")) for _ in range(4)]) > torch.softmax(row, 0) tensor([nan, nan, nan, nan]) ``` ## Quick Aside: Masking in Attention Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs. We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values. ## Alternative Approaches If we use a very large negative number instead of -inf: ```python > row = torch.tensor([(-1e6) for _ in range(4)]) > torch.softmax(row, 0) tensor([0.2500, 0.2500, 0.2500, 0.2500]) ``` However if users always remembered to "slice" out their outputs i.e.: ```Python >fill_value = -1e6 >... >out.backward(torch.ones_like(out)) >print(matrix.grad) tensor([[-0.0563, -0.0564, 0.1613, -0.0486], [ 0.0000, 0.0000, 0.0000, 0.0000]]) ``` This would bring us back into a better state. ## A Third Option We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation. This PR implements the new semantic for masking w/ attention in fully masked-out rows: ```python out[masked_out_rows] = 0 ``` **Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption. ## Details This PR stack does 3 things: 1. Adds a PRIVATE _safe_softmax op 2. Updates semantic for flash_cpu fused kernel 3. Updates semantic for efficient_cuda fused kernel _safe_softmax is not supposed to be used generically and is only meant to be used within the context of SDPA. Due to this fact instead of decomposing softmax and checking for -inf rows we instead "cheat" and use nan_to_num. Why I think this is okay? (please find a counter point if avail) There are multiple ways NaNs can emerge. For the fully masked out rows case nan_to_num works. But what if there were other NaNs, wouldn't this silently remove them? The only case that this can happen is if the input itself had a NaN or an Inf For example: ```Python a = torch.ones([4], requires_grad=False, dtype=torch.float16) a[1] = torch.finfo(torch.float16).max print(a.softmax(-1)) ``` Will return `tensor([0., 1., 0., 0.], dtype=torch.float16)` Where ```Python a = torch.ones([4], requires_grad=False, dtype=torch.float16) a[1] = float("inf") a.softmax(-1) ``` returns: `tensor([nan, nan, nan, nan], dtype=torch.float16)` If we dont want to even allow for the possibility of "inf" or "NaN" attention scores to be converted to 0 then we can implemented it something like this ```Python max = torch.max(a, dim=-1, keepdim=True) exp = torch.exp(a - max.values) denom = torch.sum(exp, dim=-1, keepdim=True) softmax = exp / denom softmax = torch.where(max.values == float('-inf'), 0.0, softmax) ``` however we would be paying for this in math performance. ## Why Now I think one point that has substantially changed where PyTorch should lie on this argument is the fact that we have fused implementations for SDPA now. And these fused implementations allow us to easily and performantly support this new semantic. Differential Revision: [D61418679](https://our.internmc.facebook.com/intern/diff/D61418679) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133882 Approved by: https://github.com/soulitzer --- .../ATen/native/cpu/FlashAttentionKernel.cpp | 6 + aten/src/ATen/native/native_functions.yaml | 1 + .../native/nested/NestedTensorBinaryOps.cpp | 9 ++ .../ATen/native/transformers/attention.cpp | 8 +- .../epilogue/epilogue_rescale_output.h | 10 +- .../cuda/mem_eff_attention/kernel_forward.h | 4 + test/functorch/test_ops.py | 6 - test/test_nn.py | 14 +- test/test_transformers.py | 129 +++++++++++++++++- tools/autograd/derivatives.yaml | 1 + torch/csrc/autograd/FunctionsManual.cpp | 16 +++ torch/csrc/autograd/FunctionsManual.h | 5 + .../_internal/common_methods_invocations.py | 4 +- 13 files changed, 194 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp index 9d5575e998a..601b098d37c 100644 --- a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp +++ b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp @@ -452,9 +452,15 @@ void cpu_flash_attention( dst_data, headSize); } + // dst <- dst / sum[row] // reorder MHA output with strides for (int64_t row = 0; row < qBlockSize; ++row) { + // Row sums for full masked out rows are 0, we set them to 1 + // in order to avoid NaNs in the output and instead set fully + // masked out rows to 0 + qk_max_data[row] = qk_max_data[row] == -std::numeric_limits::infinity() ? 0 : qk_max_data[row]; + qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row]; accum_t sum_reciprocal = 1 / qk_sum_data[row]; vec::map( [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); }, diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 29760d010a2..f92b4a3d5ba 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8890,6 +8890,7 @@ variants: method, function dispatch: QuantizedCPU: eq_quantized_cpu + NestedTensorCPU, NestedTensorCUDA: eq_tensor_nested tags: [core, pointwise] - func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) diff --git a/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp b/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp index ce3754bb73c..a22466a8e0d 100644 --- a/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp +++ b/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp @@ -321,4 +321,13 @@ Tensor eq_scalar_nested(const Tensor& self, const Scalar& other) { }); } +Tensor eq_tensor_nested(const Tensor& self, const Tensor& other) { + TORCH_CHECK(!other.is_nested(), "eq does not support nested tensor as other value."); + return NestedTensor_elementwise_Tensor( + self, other, "eq", false /*supports_striding*/, + [](const Tensor& b1, const Tensor& b2) { + return b1.eq(b2); + }); +} + } // namespace at::native diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 9e4cd357ade..5369e87d58b 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -647,9 +647,11 @@ Tensor _safe_softmax( int64_t dim, std::optional dtype) { auto out = at::softmax(self, dim, dtype); - const auto masked = self.eq(-std::numeric_limits::infinity()); + const auto neg_inf = at::scalar_tensor(-std::numeric_limits::infinity(), at::TensorOptions().dtype(out.dtype()).device(out.device())); + const auto masked = self.eq(neg_inf); const auto masked_rows = all(masked, dim, true); - return at::where(masked_rows, at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device())), out); + const auto zero = at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device())); + return at::where(masked_rows, zero, out); } // Computes scaled dot product attention on query, key and value tensors, using // an optional attention mask if passed, and applying dropout if a probability @@ -837,7 +839,7 @@ std::tuple _scaled_dot_product_attention_math( attn.add_(*attn_mask); } } - attn = at::softmax(attn, -1); + attn = at::_safe_softmax(attn, -1); if (dropout_p > 0.0) { if (dropout_mask.has_value()) { // In order to validate the correctness of the fused kernels, we need to diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h index 9633d286bb6..fd7982e5f69 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h @@ -144,7 +144,10 @@ class MemoryEfficientAttentionNormalize { multiplies mul_add_source; multiply_add mul_add_accumulator; - ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + // Row sums for full masked out rows are 0, we set them to 1 + // In order to avoid NaNs in the output and instead sem them to 0. + ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row]; + ElementCompute alpha = isLast ? (1 / denom) : 1; ElementCompute beta = alpha * m_prime_[row]; intermediate = mul_add_source(beta, converted_source); // X = beta * C @@ -174,7 +177,10 @@ class MemoryEfficientAttentionNormalize { ComputeFragment intermediate; multiplies mul_accumulator; - ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + // Row sums for full masked out rows are 0, we set them to 1 + // In order to avoid NaNs in the output and instead sem them to 0. + ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row]; + ElementCompute alpha = isLast ? (1 / denom) : 1; intermediate = mul_accumulator( alpha, converted_accumulator); // X = alpha * C + uniform diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h index a10e5a9c44a..466b6013c9d 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h @@ -1166,6 +1166,10 @@ struct AttentionKernel { auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E if (thread_id() < p.num_queries) { + // We set fully masked out rows to 0, the sumexp for masked out rows will be 0 + // We update it to be 1 prior to calling log so that log(1) = 0 + s_prime[thread_id()] = (s_prime[thread_id()] == 0) ? 1: s_prime[thread_id()]; + mi[thread_id()] = (mi[thread_id()] == -cutlass::platform::numeric_limits::infinity()) ? 0: mi[thread_id()]; p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) + cutlass::fast_log(accum_t(s_prime[thread_id()])); } else if (thread_id() < lse_dim) { diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 64c3a1706b5..03744b7a8ef 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1791,9 +1791,6 @@ class TestOperators(TestCase): ), # NYI: forward-AD for soft_margin_loss_backward xfail("nn.functional.ctc_loss", ""), # NYI: forward-AD for _ctc_loss xfail("nn.functional.pdist", ""), # NYI: forward-AD with _pdist_forward - xfail( - "torch.ops.aten._safe_softmax.default" - ), # NYI: forward-AD for _safe_softmax skip("nn.functional.scaled_dot_product_attention"), xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints xfail( @@ -1976,9 +1973,6 @@ class TestOperators(TestCase): xfail( "nn.functional.ctc_loss" ), # ForwardAD not implemented and no decomposition - xfail( - "torch.ops.aten._safe_softmax.default" - ), # ForwardAD not implemented xfail("nn.functional.dropout2d"), # calls random op xfail("nn.functional.dropout3d"), # calls random op xfail("nn.functional.dropout"), # calls random op diff --git a/test/test_nn.py b/test/test_nn.py index 0b165ca8d51..eb4ccd76515 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -12387,12 +12387,20 @@ if __name__ == '__main__': result = model(encoder_input, src_key_padding_mask=mask) self.assertEqual(result.shape, ref_output.shape) torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol) - # 1 values are masked. Since there is only 1 input embedding this - # will result in nan. mask = torch.tensor([[1]], device=device) == 1 result = model(encoder_input, src_key_padding_mask=mask) + fast_path_device = result.is_cuda or result.is_cpu result = result.cpu().detach().numpy() - self.assertTrue(np.isnan(result).all()) + # Non Fast Paths + if training or not batch_first or TEST_WITH_CROSSREF or not fast_path_device: + # We changed the semenatic, on the non fast path so that fully masked out rows return + # 0 from attention thus NaNs should no longer be present and the output should be nonzero + # due to skip connections + self.assertTrue(not np.isnan(result).any()) + else: + # Fast Paths + self.assertTrue(np.isnan(result).all()) + # deterministic input encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]], diff --git a/test/test_transformers.py b/test/test_transformers.py index f402a01a0f8..68f46330a12 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -347,6 +347,7 @@ class TestTransformers(NNTestCase): @parametrize("key_padding_mask_dim", [2, None]) @parametrize("mask_dtype", [torch.bool, torch.float32]) def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype): + # MHA converts all with torch.no_grad(): B = 2 L = 4 @@ -356,7 +357,7 @@ class TestTransformers(NNTestCase): if attn_mask_dim == 2: attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device) elif attn_mask_dim == 3: - attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device) + attn_mask = make_tensor((B, 1, L, L), dtype=mask_dtype, device=device).expand(B, H, L, L).reshape(B * H, L, L) elif attn_mask_dim is None: attn_mask = None @@ -372,7 +373,9 @@ class TestTransformers(NNTestCase): out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) mha.eval() # enable fast path out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) - self.assertEqual(out, out_fp) + # The FP kernel will return NaNs while the sdpa kernel which is ran when the fast path is turned off returns 0 instead + # of NaNs for fully masked rows + torch.testing.assert_close(out, out_fp.nan_to_num()) @parametrize("nhead", [1, 4, 8]) def test_transformerencoderlayer_src_mask(self, device, nhead): @@ -1156,6 +1159,25 @@ class TestTransformers(NNTestCase): else: actual = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask, dropout_p, is_causal) + # This test the fully masked out rows case + if torch.isnan(expected).any(): + row_sums = attn_mask.sum(dim=-1) + masked_out_rows = (row_sums == 0) + + for _ in range((input_dim - attn_mask_dim) - 1): + masked_out_rows = masked_out_rows.unsqueeze(0) + + masked_out_rows = masked_out_rows.expand(expected.shape[:-1]) + # Slice out the fully masked rows from expected and actual + expected_masked_out = expected[masked_out_rows] + actual_masked_out = actual[masked_out_rows] + + expected_all_nan = torch.isnan(expected_masked_out).all() + actual_all_zero = (actual_masked_out.abs().sum() == 0) + + self.assertTrue(expected_all_nan) + self.assertTrue(actual_all_zero) + return self.assertEqual(actual, expected) @@ -1961,7 +1983,7 @@ class TestSDPACpuOnly(NNTestCase): @parametrize("n_head", [1, 3]) @parametrize("head_dim", [8]) @parametrize("mask_dim", [2, 4]) - @parametrize("bool_mask", [0, 1]) + @parametrize("bool_mask", [False, True]) @parametrize("train", [True, False]) @parametrize("casual", [True, False]) @parametrize("set_attn_mask", [True, False]) @@ -2036,6 +2058,9 @@ class TestSDPACpuOnly(NNTestCase): if dtype in [torch.bfloat16, torch.float16]: math_ref = math_ref.to(dtype) + self.assertFalse(torch.isnan(math_ref).any()) + self.assertFalse(torch.isnan(actual).any()) + self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol) if train: @@ -2064,6 +2089,104 @@ class TestSDPACpuOnly(NNTestCase): actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) self.assertEqual(math_ref, actual) + @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") + @parametrize("backend", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]) + @parametrize("seq_len", [32, 64, 128]) + @parametrize("head_dim", [16, 32]) + @parametrize("dtype", [torch.float32, torch.float16]) + def test_fully_masked_out_rows(self, backend, device, seq_len, head_dim, dtype): + def attention_inputs(seq_len, head_dim, device, dtype, mask_every_n_rows=4): + query = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype) + key = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype) + value = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype) + + # Create a mask with deterministic row masking + mask = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=device) + + # Mask every nth row + mask[0, 0, ::mask_every_n_rows, :] = False + + # Create a fixed pattern for element-wise masking + element_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device) + element_mask[torch.arange(seq_len)[:, None] % 5 == torch.arange(seq_len) % 5] = True + + # Combine row masking and element-wise masking + mask = mask & element_mask.unsqueeze(0).unsqueeze(0) + + return query, key, value, mask + + def compute_output_and_grads(query, key, value, mask, backend): + with sdpa_kernel(backend): + masked_out = scaled_dot_product_attention(query, key, value, attn_mask=mask) + loss = masked_out.sum() + grads = torch.autograd.grad(loss, [query, key, value]) + return masked_out, grads + + if backend == SDPBackend.FLASH_ATTENTION and "cuda" in str(device): + unittest.skip("FlashAttention does not support masks on cuda") + return + if backend == SDPBackend.EFFICIENT_ATTENTION and "cpu" in str(device): + unittest.skip("EfficientAttention does not support masks on cpu") + return + query, key, value, mask = attention_inputs(seq_len, head_dim, device, dtype) + + # Compute results for the tested backend + backend_out, backend_grads = compute_output_and_grads(query, key, value, mask, backend) + + # Compute results for the Math backend + math_out, math_grads = compute_output_and_grads(query, key, value, mask, SDPBackend.MATH) + + # Compare outputs + torch.testing.assert_close(backend_out, math_out, atol=5e-3, rtol=0) + self.assertFalse(backend_out.isnan().any()) + self.assertFalse(math_out.isnan().any()) + # Compare gradients + for bg, mg in zip(backend_grads, math_grads): + torch.testing.assert_close(bg, mg, atol=3e-3, rtol=0) + self.assertFalse(bg.isnan().any()) + self.assertFalse(mg.isnan().any()) + + # Check if masked rows are zero in output + mask_sum = mask.sum(dim=-1, keepdim=True) + masked_rows = (mask_sum == 0).expand_as(backend_out) + self.assertTrue((mask_sum == 0).sum() > 0, "No fully masked out rows found") + assert torch.all(backend_out[masked_rows] == 0), \ + f"Non-zero values in fully masked rows for {backend=}" + + # Check if gradients for masked rows are zero + grad_query = backend_grads[0] + assert torch.all(grad_query[masked_rows] == 0), f"Non-zero gradients in fully masked rows for {backend=}" + + @parametrize("dtype", [torch.float32, torch.float16]) + @parametrize("fill_val", [float("inf")]) + def test_non_masked_rows_nan_props(self, device, dtype, fill_val): + query = torch.randn(1, 2, 4, 16, device=device, dtype=dtype) + # a single NaN in the query input + query[0, 1, 2, 3] = fill_val + query = query.detach().requires_grad_(True) + key = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True) + value = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True) + + out = torch.nn.functional.scaled_dot_product_attention(query, key, value) + self.assertTrue(torch.isnan(out).any()) + out.sum().backward() + self.assertTrue(torch.isnan(query.grad).any()) + + @parametrize("kernel", [SDPBackend.MATH]) + def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend): + # https://github.com/pytorch/pytorch/issues/105190. + def ref(x): + v1 = torch.matmul(x, x.transpose(-1, -2)) + v2 = v1 / -0.0001 + v3 = v2.softmax(dim=-1) + v4 = torch.matmul(v3, x) + return v4 + + x = torch.randn(1, 3, 64, 64, device=device) + ref_result = ref(x) + with sdpa_kernel(backends=[kernel]): + sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001) + self.assertEqual(ref_result, sdp_math) class TestSDPACudaOnly(NNTestCase): """ Used to test CUDA only functionality of scaled_dot_product_attention diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index e3e2291083a..96c55c666b6 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2845,6 +2845,7 @@ # Transformer - name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor self: _softmax_backward_data(grad, result, dim, self.scalar_type()) + result: result * (self_t - safe_logsumexp_jvp(self_p, self_t, {dim}, true)) - name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) output_differentiability: [True, False, False, False] diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index f7be4845807..37b1a7bd598 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -6715,6 +6715,22 @@ Tensor logsumexp_jvp( } } +Tensor safe_logsumexp_jvp( + const Tensor& self_p, + const Tensor& self_t, + IntArrayRef dim, + bool keepdim) { + auto lse_jvp = logsumexp_jvp(self_p, self_t, dim, keepdim); + const auto neg_inf = at::scalar_tensor( + -std::numeric_limits::infinity(), + at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device())); + const auto masked = self_p.eq(neg_inf); + const auto masked_rows = all(masked, dim, true); + const auto zero = at::scalar_tensor( + 0.0, at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device())); + return at::where(masked_rows, zero, lse_jvp); +} + Tensor warn_backwards(const Tensor& grad_output) { TORCH_WARN("Warn from backward"); return grad_output; diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 44a00ee525c..cacfc6e9015 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -229,6 +229,11 @@ at::Tensor logsumexp_jvp( const at::Tensor& self_t, IntArrayRef dim, bool keepdim); +at::Tensor safe_logsumexp_jvp( + const at::Tensor& self_p, + const at::Tensor& self_t, + IntArrayRef dim, + bool keepdim); at::Tensor logcumsumexp_backward( at::Tensor grad, const at::Tensor& self, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 78f99a9012f..43ffd6dbbf4 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -16210,8 +16210,8 @@ op_db: List[OpInfo] = [ sample_inputs_func=sample_inputs_safe_softmax, assert_jit_shape_analysis=True, assert_autodiffed=True, - supports_forward_ad=False, - supports_fwgrad_bwgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, supports_out=False, supports_cow_input_no_materialize_backward=False, decorators=[],