Update fused kernels and call _safe_softmax from SDPA (#133882)
# UPDATE

This is take 3 of https://github.com/pytorch/pytorch/pull/131863, which was landed via co-dev but did not apply correctly.

# Summary

Changes the stance of SDPA on what to do for fully masked-out rows.

## Current Behavior

Several PyTorch users have expressed frustration over this issue:

- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963

These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated in https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617, can be paraphrased as follows: when passing in fully masked-out rows, attention becomes ambiguous, and we have two main options:

1. Uniformly attend to all values:
   ```python
   scores[masked_out_rows] = 1 / len(row)
   out[masked_out_rows] = 1 / len(row) * value
   ```
2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
   ```python
   output[fully_masked_rows] = NaN
   ```

We went with option 2, partially because it was easier to implement, but also because people argued that users can slice the output to remove the NaNs:

```python
>>> fill_value = -float("inf")
>>> row0 = torch.randn(4)
>>> row1 = torch.tensor([fill_value for _ in range(4)])
>>> matrix = torch.stack([row0, row1]).requires_grad_(True)
>>> out = torch.softmax(matrix, 1)
>>> out = out[0]
>>> print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```

Cool, problem solved. But what happens when you call backward?

```python
>>> out.backward(torch.ones_like(out))
>>> print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
        [       nan,        nan,        nan,        nan]])
```

Those pesky NaNs are back!

## Why do we see NaNs today?

The core of the problem revolves around using the softmax function in SDPA:

```python
>>> row = torch.tensor([-float("inf") for _ in range(4)])
>>> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```

## Quick Aside: Masking in Attention

Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details ([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs. We use a large negative number (typically -inf) to decrease the attention weight, since softmax assigns more weight to larger values.

## Alternative Approaches

If we use a very large negative number instead of -inf:

```python
>>> row = torch.tensor([-1e6 for _ in range(4)])
>>> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```

Then, if users always remembered to "slice" out their outputs, i.e.:

```python
>>> fill_value = -1e6
>>> ...
>>> out.backward(torch.ones_like(out))
>>> print(matrix.grad)
tensor([[-0.0563, -0.0564,  0.1613, -0.0486],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
```

this would bring us back to a better state.

## A Third Option

We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation. This PR implements the new semantic for masking with attention in fully masked-out rows:

```python
out[masked_out_rows] = 0
```

A reference sketch of this behavior is shown below.
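For illustration only (this is not part of the PR's diff), a minimal eager-mode sketch of this "safe softmax" semantic might look as follows; the helper name `safe_softmax_reference` is made up for this example:

```python
import torch

def safe_softmax_reference(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Regular softmax; rows that are entirely -inf come out as NaN here.
    out = torch.softmax(scores, dim=dim)
    # A row is "fully masked" when every entry along `dim` is -inf.
    masked_rows = scores.eq(float("-inf")).all(dim=dim, keepdim=True)
    # New semantic: such rows produce 0 instead of NaN.
    return torch.where(masked_rows, torch.zeros_like(out), out)

scores = torch.tensor([[0.5, 1.0, -1.0],
                       [float("-inf")] * 3])
# First row is an ordinary softmax; second (fully masked) row is all zeros, not NaN.
print(safe_softmax_reference(scores, dim=-1))
```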
**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.

## Details

This PR stack does three things:

1. Adds a PRIVATE `_safe_softmax` op
2. Updates the semantic for the flash_cpu fused kernel
3. Updates the semantic for the efficient_cuda fused kernel

`_safe_softmax` is not supposed to be used generically and is only meant to be used within the context of SDPA. Because of this, instead of decomposing softmax and checking for -inf rows, we "cheat" and use `nan_to_num`.

Why do I think this is okay? (Please find a counterpoint if available.) There are multiple ways NaNs can emerge. For the fully masked-out rows case, `nan_to_num` works. But if there were other NaNs, wouldn't this silently remove them? The only way that can happen is if the input itself had a NaN or an Inf. For example:

```python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```

will return `tensor([0., 1., 0., 0.], dtype=torch.float16)`, whereas

```python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```

returns `tensor([nan, nan, nan, nan], dtype=torch.float16)`.

If we don't want to allow even the possibility of "inf" or "NaN" attention scores being converted to 0, we could implement it something like this:

```python
max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
softmax = torch.where(max.values == float('-inf'), 0.0, softmax)
```

However, we would be paying for this in math performance.

## Why Now

One point that has substantially changed where PyTorch should stand on this question is that we now have fused implementations of SDPA, and these fused implementations allow us to easily and performantly support the new semantic.

Differential Revision: [D61418679](https://our.internmc.facebook.com/intern/diff/D61418679)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133882
Approved by: https://github.com/soulitzer
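As a quick usage sketch of the resulting behavior (a minimal example, assuming a PyTorch build that includes this change; shapes and values are illustrative), a fully masked-out query row now yields zeros from `scaled_dot_product_attention` and NaN-free gradients:

```python
import torch
import torch.nn.functional as F

# One batch, one head, 4 queries/keys of dim 8.
q = torch.randn(1, 1, 4, 8, requires_grad=True)
k = torch.randn(1, 1, 4, 8, requires_grad=True)
v = torch.randn(1, 1, 4, 8, requires_grad=True)

# Boolean mask: True means "attend". Query row 0 attends to nothing.
mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
mask[0, 0, 0, :] = False

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 0])                 # zeros for the fully masked row, not NaNs
out.sum().backward()
print(torch.isnan(q.grad).any())    # False: gradients are NaN-free
```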
This commit is contained in:
parent: f1dc3b108a
commit: fb26b84390
@@ -452,9 +452,15 @@ void cpu_flash_attention(
              dst_data,
              headSize);
        }

      // dst <- dst / sum[row]
      // reorder MHA output with strides
      for (int64_t row = 0; row < qBlockSize; ++row) {
        // Row sums for full masked out rows are 0, we set them to 1
        // in order to avoid NaNs in the output and instead set fully
        // masked out rows to 0
        qk_max_data[row] = qk_max_data[row] == -std::numeric_limits<accum_t>::infinity() ? 0 : qk_max_data[row];
        qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
        accum_t sum_reciprocal = 1 / qk_sum_data[row];
        vec::map<scalar_t>(
          [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
@@ -8890,6 +8890,7 @@
  variants: method, function
  dispatch:
    QuantizedCPU: eq_quantized_cpu
    NestedTensorCPU, NestedTensorCUDA: eq_tensor_nested
  tags: [core, pointwise]

- func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
@@ -321,4 +321,13 @@ Tensor eq_scalar_nested(const Tensor& self, const Scalar& other) {
      });
}

Tensor eq_tensor_nested(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!other.is_nested(), "eq does not support nested tensor as other value.");
  return NestedTensor_elementwise_Tensor(
      self, other, "eq", false /*supports_striding*/,
      [](const Tensor& b1, const Tensor& b2) {
        return b1.eq(b2);
      });
}

} // namespace at::native
@@ -647,9 +647,11 @@ Tensor _safe_softmax(
    int64_t dim,
    std::optional<ScalarType> dtype) {
  auto out = at::softmax(self, dim, dtype);
  const auto masked = self.eq(-std::numeric_limits<float>::infinity());
  const auto neg_inf = at::scalar_tensor(-std::numeric_limits<float>::infinity(), at::TensorOptions().dtype(out.dtype()).device(out.device()));
  const auto masked = self.eq(neg_inf);
  const auto masked_rows = all(masked, dim, true);
  return at::where(masked_rows, at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device())), out);
  const auto zero = at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device()));
  return at::where(masked_rows, zero, out);
}
// Computes scaled dot product attention on query, key and value tensors, using
// an optional attention mask if passed, and applying dropout if a probability
@@ -837,7 +839,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
      attn.add_(*attn_mask);
    }
  }
  attn = at::softmax(attn, -1);
  attn = at::_safe_softmax(attn, -1);
  if (dropout_p > 0.0) {
    if (dropout_mask.has_value()) {
      // In order to validate the correctness of the fused kernels, we need to
@@ -144,7 +144,10 @@ class MemoryEfficientAttentionNormalize {
    multiplies<ComputeFragment> mul_add_source;
    multiply_add<ComputeFragment> mul_add_accumulator;

    ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
    // Row sums for full masked out rows are 0, we set them to 1
    // in order to avoid NaNs in the output and instead set them to 0.
    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
    ElementCompute alpha = isLast ? (1 / denom) : 1;
    ElementCompute beta = alpha * m_prime_[row];

    intermediate = mul_add_source(beta, converted_source); // X = beta * C
@@ -174,7 +177,10 @@ class MemoryEfficientAttentionNormalize {
    ComputeFragment intermediate;
    multiplies<ComputeFragment> mul_accumulator;

    ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
    // Row sums for full masked out rows are 0, we set them to 1
    // in order to avoid NaNs in the output and instead set them to 0.
    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
    ElementCompute alpha = isLast ? (1 / denom) : 1;

    intermediate = mul_accumulator(
        alpha, converted_accumulator); // X = alpha * C + uniform
@@ -1166,6 +1166,10 @@ struct AttentionKernel {
    auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
    constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
    if (thread_id() < p.num_queries) {
      // We set fully masked out rows to 0, the sumexp for masked out rows will be 0
      // We update it to be 1 prior to calling log so that log(1) = 0
      s_prime[thread_id()] = (s_prime[thread_id()] == 0) ? 1 : s_prime[thread_id()];
      mi[thread_id()] = (mi[thread_id()] == -cutlass::platform::numeric_limits<accum_t>::infinity()) ? 0 : mi[thread_id()];
      p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
          cutlass::fast_log(accum_t(s_prime[thread_id()]));
    } else if (thread_id() < lse_dim) {
@@ -1791,9 +1791,6 @@ class TestOperators(TestCase):
            ),  # NYI: forward-AD for soft_margin_loss_backward
            xfail("nn.functional.ctc_loss", ""),  # NYI: forward-AD for _ctc_loss
            xfail("nn.functional.pdist", ""),  # NYI: forward-AD with _pdist_forward
            xfail(
                "torch.ops.aten._safe_softmax.default"
            ),  # NYI: forward-AD for _safe_softmax
            skip("nn.functional.scaled_dot_product_attention"),
            xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
            xfail(
@@ -1976,9 +1973,6 @@ class TestOperators(TestCase):
            xfail(
                "nn.functional.ctc_loss"
            ),  # ForwardAD not implemented and no decomposition
            xfail(
                "torch.ops.aten._safe_softmax.default"
            ),  # ForwardAD not implemented
            xfail("nn.functional.dropout2d"),  # calls random op
            xfail("nn.functional.dropout3d"),  # calls random op
            xfail("nn.functional.dropout"),  # calls random op
@@ -12387,12 +12387,20 @@ if __name__ == '__main__':
            result = model(encoder_input, src_key_padding_mask=mask)
            self.assertEqual(result.shape, ref_output.shape)
            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
            # 1 values are masked. Since there is only 1 input embedding this
            # will result in nan.
            mask = torch.tensor([[1]], device=device) == 1
            result = model(encoder_input, src_key_padding_mask=mask)
            fast_path_device = result.is_cuda or result.is_cpu
            result = result.cpu().detach().numpy()
            self.assertTrue(np.isnan(result).all())
            # Non Fast Paths
            if training or not batch_first or TEST_WITH_CROSSREF or not fast_path_device:
                # We changed the semantic on the non fast path so that fully masked out rows return
                # 0 from attention thus NaNs should no longer be present and the output should be nonzero
                # due to skip connections
                self.assertTrue(not np.isnan(result).any())
            else:
                # Fast Paths
                self.assertTrue(np.isnan(result).all())

            # deterministic input
            encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
@@ -347,6 +347,7 @@ class TestTransformers(NNTestCase):
    @parametrize("key_padding_mask_dim", [2, None])
    @parametrize("mask_dtype", [torch.bool, torch.float32])
    def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
        # MHA converts all
        with torch.no_grad():
            B = 2
            L = 4
@@ -356,7 +357,7 @@ class TestTransformers(NNTestCase):
            if attn_mask_dim == 2:
                attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device)
            elif attn_mask_dim == 3:
                attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device)
                attn_mask = make_tensor((B, 1, L, L), dtype=mask_dtype, device=device).expand(B, H, L, L).reshape(B * H, L, L)
            elif attn_mask_dim is None:
                attn_mask = None
@@ -372,7 +373,9 @@ class TestTransformers(NNTestCase):
            out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
            mha.eval()  # enable fast path
            out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
            self.assertEqual(out, out_fp)
            # The FP kernel will return NaNs while the sdpa kernel which is run when the fast path is turned off returns 0 instead
            # of NaNs for fully masked rows
            torch.testing.assert_close(out, out_fp.nan_to_num())

    @parametrize("nhead", [1, 4, 8])
    def test_transformerencoderlayer_src_mask(self, device, nhead):
@@ -1156,6 +1159,25 @@ class TestTransformers(NNTestCase):
        else:
            actual = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask, dropout_p, is_causal)
            # This tests the fully masked out rows case
            if torch.isnan(expected).any():
                row_sums = attn_mask.sum(dim=-1)
                masked_out_rows = (row_sums == 0)

                for _ in range((input_dim - attn_mask_dim) - 1):
                    masked_out_rows = masked_out_rows.unsqueeze(0)

                masked_out_rows = masked_out_rows.expand(expected.shape[:-1])
                # Slice out the fully masked rows from expected and actual
                expected_masked_out = expected[masked_out_rows]
                actual_masked_out = actual[masked_out_rows]

                expected_all_nan = torch.isnan(expected_masked_out).all()
                actual_all_zero = (actual_masked_out.abs().sum() == 0)

                self.assertTrue(expected_all_nan)
                self.assertTrue(actual_all_zero)
                return

        self.assertEqual(actual, expected)
@@ -1961,7 +1983,7 @@ class TestSDPACpuOnly(NNTestCase):
    @parametrize("n_head", [1, 3])
    @parametrize("head_dim", [8])
    @parametrize("mask_dim", [2, 4])
    @parametrize("bool_mask", [0, 1])
    @parametrize("bool_mask", [False, True])
    @parametrize("train", [True, False])
    @parametrize("casual", [True, False])
    @parametrize("set_attn_mask", [True, False])
@@ -2036,6 +2058,9 @@
            if dtype in [torch.bfloat16, torch.float16]:
                math_ref = math_ref.to(dtype)

            self.assertFalse(torch.isnan(math_ref).any())
            self.assertFalse(torch.isnan(actual).any())

            self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)

            if train:
@@ -2064,6 +2089,104 @@ class TestSDPACpuOnly(NNTestCase):
        actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        self.assertEqual(math_ref, actual)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("backend", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION])
    @parametrize("seq_len", [32, 64, 128])
    @parametrize("head_dim", [16, 32])
    @parametrize("dtype", [torch.float32, torch.float16])
    def test_fully_masked_out_rows(self, backend, device, seq_len, head_dim, dtype):
        def attention_inputs(seq_len, head_dim, device, dtype, mask_every_n_rows=4):
            query = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
            key = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
            value = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)

            # Create a mask with deterministic row masking
            mask = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=device)

            # Mask every nth row
            mask[0, 0, ::mask_every_n_rows, :] = False

            # Create a fixed pattern for element-wise masking
            element_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
            element_mask[torch.arange(seq_len)[:, None] % 5 == torch.arange(seq_len) % 5] = True

            # Combine row masking and element-wise masking
            mask = mask & element_mask.unsqueeze(0).unsqueeze(0)

            return query, key, value, mask

        def compute_output_and_grads(query, key, value, mask, backend):
            with sdpa_kernel(backend):
                masked_out = scaled_dot_product_attention(query, key, value, attn_mask=mask)
                loss = masked_out.sum()
            grads = torch.autograd.grad(loss, [query, key, value])
            return masked_out, grads

        if backend == SDPBackend.FLASH_ATTENTION and "cuda" in str(device):
            unittest.skip("FlashAttention does not support masks on cuda")
            return
        if backend == SDPBackend.EFFICIENT_ATTENTION and "cpu" in str(device):
            unittest.skip("EfficientAttention does not support masks on cpu")
            return
        query, key, value, mask = attention_inputs(seq_len, head_dim, device, dtype)

        # Compute results for the tested backend
        backend_out, backend_grads = compute_output_and_grads(query, key, value, mask, backend)

        # Compute results for the Math backend
        math_out, math_grads = compute_output_and_grads(query, key, value, mask, SDPBackend.MATH)

        # Compare outputs
        torch.testing.assert_close(backend_out, math_out, atol=5e-3, rtol=0)
        self.assertFalse(backend_out.isnan().any())
        self.assertFalse(math_out.isnan().any())
        # Compare gradients
        for bg, mg in zip(backend_grads, math_grads):
            torch.testing.assert_close(bg, mg, atol=3e-3, rtol=0)
            self.assertFalse(bg.isnan().any())
            self.assertFalse(mg.isnan().any())

        # Check if masked rows are zero in output
        mask_sum = mask.sum(dim=-1, keepdim=True)
        masked_rows = (mask_sum == 0).expand_as(backend_out)
        self.assertTrue((mask_sum == 0).sum() > 0, "No fully masked out rows found")
        assert torch.all(backend_out[masked_rows] == 0), \
            f"Non-zero values in fully masked rows for {backend=}"

        # Check if gradients for masked rows are zero
        grad_query = backend_grads[0]
        assert torch.all(grad_query[masked_rows] == 0), f"Non-zero gradients in fully masked rows for {backend=}"

    @parametrize("dtype", [torch.float32, torch.float16])
    @parametrize("fill_val", [float("inf")])
    def test_non_masked_rows_nan_props(self, device, dtype, fill_val):
        query = torch.randn(1, 2, 4, 16, device=device, dtype=dtype)
        # a single NaN in the query input
        query[0, 1, 2, 3] = fill_val
        query = query.detach().requires_grad_(True)
        key = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True)
        value = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True)

        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
        self.assertTrue(torch.isnan(out).any())
        out.sum().backward()
        self.assertTrue(torch.isnan(query.grad).any())

    @parametrize("kernel", [SDPBackend.MATH])
    def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
        # https://github.com/pytorch/pytorch/issues/105190.
        def ref(x):
            v1 = torch.matmul(x, x.transpose(-1, -2))
            v2 = v1 / -0.0001
            v3 = v2.softmax(dim=-1)
            v4 = torch.matmul(v3, x)
            return v4

        x = torch.randn(1, 3, 64, 64, device=device)
        ref_result = ref(x)
        with sdpa_kernel(backends=[kernel]):
            sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
        self.assertEqual(ref_result, sdp_math)

class TestSDPACudaOnly(NNTestCase):
    """ Used to test CUDA only functionality of scaled_dot_product_attention
@@ -2845,6 +2845,7 @@
# Transformer
- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
  self: _softmax_backward_data(grad, result, dim, self.scalar_type())
  result: result * (self_t - safe_logsumexp_jvp(self_p, self_t, {dim}, true))

- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
  output_differentiability: [True, False, False, False]
@@ -6715,6 +6715,22 @@ Tensor logsumexp_jvp(
  }
}

Tensor safe_logsumexp_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    IntArrayRef dim,
    bool keepdim) {
  auto lse_jvp = logsumexp_jvp(self_p, self_t, dim, keepdim);
  const auto neg_inf = at::scalar_tensor(
      -std::numeric_limits<float>::infinity(),
      at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device()));
  const auto masked = self_p.eq(neg_inf);
  const auto masked_rows = all(masked, dim, true);
  const auto zero = at::scalar_tensor(
      0.0, at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device()));
  return at::where(masked_rows, zero, lse_jvp);
}

Tensor warn_backwards(const Tensor& grad_output) {
  TORCH_WARN("Warn from backward");
  return grad_output;
@@ -229,6 +229,11 @@ at::Tensor logsumexp_jvp(
    const at::Tensor& self_t,
    IntArrayRef dim,
    bool keepdim);
at::Tensor safe_logsumexp_jvp(
    const at::Tensor& self_p,
    const at::Tensor& self_t,
    IntArrayRef dim,
    bool keepdim);
at::Tensor logcumsumexp_backward(
    at::Tensor grad,
    const at::Tensor& self,
@@ -16210,8 +16210,8 @@ op_db: List[OpInfo] = [
           sample_inputs_func=sample_inputs_safe_softmax,
           assert_jit_shape_analysis=True,
           assert_autodiffed=True,
           supports_forward_ad=False,
           supports_fwgrad_bwgrad=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           supports_out=False,
           supports_cow_input_no_materialize_backward=False,
           decorators=[],