Move attention kernels from meta_registrations to fake_impls (#120682)
This PR is mostly just code movement to make the code review easier; AFAIK it should not change any functionality. The final goal is to remove the xfails for some of the test_fake opinfos for these ops. The opinfos are failing because the outputs can have mixed devices; we need to move the implementations to fake_impls before we can support mixed-device returns.
This PR:
* Move the `_meta_registrations.py` implementations to `fake_impls.py`
* Change the function signature from taking explicit named variables to taking `{args, kwargs}` and normalizing them
* Wrap all the returned tensors in FakeTensors
Tests: relying on opinfos. I also checked `test_fake_*` for these ops (by removing xfails and patching things until they passed) to verify general correctness.
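The moved implementations all follow the same pattern: normalize the call into kwargs, allocate outputs on the `meta` device, and wrap them in `FakeTensor`s that carry the real device of the inputs. The snippet below is a minimal, standalone sketch of that wrapping step only; it is not part of this PR's diff, the shapes and the `out_device` placeholder are made up for illustration, and the exact import path may differ between PyTorch versions.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

fake_mode = FakeTensorMode()

def convert_tensor(t, device):
    # Same shape of helper as in fake_impls.py: `t` is allocated on the meta
    # device, `device` is the device the returned FakeTensor should report.
    return FakeTensor(fake_mode, t, device)

# `out_device` stands in for `query.device` in the real impls; "cpu" keeps the
# sketch runnable anywhere.
out_device = torch.device("cpu")

logsumexp = convert_tensor(
    torch.empty((2, 8, 128), dtype=torch.float, device="meta"),
    out_device,
)

print(logsumexp.device)  # cpu, even though the storage is a meta allocation
print(type(logsumexp))   # torch._subclasses.fake_tensor.FakeTensor
```

Because the seed/offset outputs stay on `meta` (see Note [Seed and Offset] in the diff) while the attention outputs report the query's device, a single op can return tensors on mixed devices, which is the mixed-device behavior the description above refers to.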
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120682
Approved by: https://github.com/drisspg
Commit d6c202975c (parent 50073248ed)
Changes to `_meta_registrations.py` (the four forward attention meta registrations below are removed; the backward registrations remain):

@@ -5063,67 +5063,6 @@ def meta_scatter_(self, dim, index, src_or_value, reduce=None):
     return self


-@register_meta(
-    [
-        aten._scaled_dot_product_flash_attention,
-    ]
-)
-def meta__scaled_dot_product_flash(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    return_debug_mask: bool = False,
-    scale: Optional[float] = None,
-):
-    batch_size = query.size(0)
-    num_heads = query.size(1)
-    max_seqlen_batch_q = query.size(2)
-    head_dim = query.size(3)
-    max_seqlen_batch_k = key.size(2)
-
-    query_t = query.transpose(1, 2)
-    attention = torch.empty_like(query_t).transpose(1, 2)
-    logsumexp = torch.empty(
-        (batch_size, num_heads, max_seqlen_batch_q),
-        dtype=torch.float,
-        device=query.device,
-    )
-
-    if return_debug_mask:
-        blocksize_c = 128 if head_dim > 64 else 256
-        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
-        if max_seqlen_batch_k <= 128:
-            max_seqlen_k = 128
-        elif max_seqlen_batch_k <= 256:
-            max_seqlen_k = 256
-        debug_mask = torch.empty(
-            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
-            dtype=query.dtype,
-            device=query.device,
-        )
-    else:
-        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
-
-    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
-    # capturing or not, but at the time of tracing we don't know if we
-    # are going to use cudagraphs or not, so we return meta tensors here
-    # it's possible we'll need to have some special handling in inductor for sdpa
-
-    return (
-        attention,
-        logsumexp,
-        None,
-        None,
-        max_seqlen_batch_q,
-        max_seqlen_batch_k,
-        torch.empty((), dtype=torch.long, device="meta"),
-        torch.empty((), dtype=torch.long, device="meta"),
-        debug_mask,
-    )
-
-
 @register_meta(
     [
         aten._scaled_dot_product_flash_attention_backward,

@@ -5238,50 +5177,6 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward(
     return grad_q, grad_k, grad_v


-@register_meta(
-    [
-        aten._scaled_dot_product_efficient_attention,
-    ]
-)
-def meta__scaled_dot_product_efficient(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    attn_bias: Optional[Tensor],
-    compute_log_sumexp: bool,
-    dropout_p=0.0,
-    is_causal: bool = False,
-    scale: Optional[float] = None,
-):
-    query = query.transpose(1, 2)
-    key = key.transpose(1, 2)
-    value = value.transpose(1, 2)
-
-    B = query.size(0)
-    M = query.size(1)
-    N = key.size(1)
-    num_heads = query.size(-2)
-    K = query.size(-1)
-    Kv = value.size(-1)
-
-    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
-
-    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
-    logsum_exp = torch.empty(
-        (B, num_heads, logsumexp_dim),
-        dtype=torch.float,
-        device=query.device,
-    )
-
-    res = res.transpose(1, 2)
-
-    # See Note [Seed and Offset]:
-    seed = torch.empty((), dtype=torch.long, device="meta")
-    offset = torch.empty((), dtype=torch.long, device="meta")
-
-    return res, logsum_exp, seed, offset
-
-
 @register_meta(
     [
         aten._scaled_dot_product_efficient_attention_backward,

@@ -5342,67 +5237,6 @@ def meta__scaled_dot_product_efficient_backward(
     return grad_q, grad_k, grad_v, grad_bias


-@register_meta(
-    [
-        aten._flash_attention_forward,
-    ]
-)
-def meta__flash_attention_forward(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    cum_seq_q: Optional[Tensor],
-    cum_seq_k: Optional[Tensor],
-    max_q: int,
-    max_k: int,
-    dropout_p: float,
-    is_causal: bool,
-    return_debug_mask: bool,
-    scale: Optional[float] = None,
-):
-    # NB: there are two underlying paths:
-    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
-    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
-    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
-    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
-    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
-    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
-    num_heads = query.size(-2)
-    head_dim = query.size(-1)
-
-    # Cuda Path
-    attention = torch.empty_like(query)
-    logsumexp = torch.empty(
-        (batch_size, num_heads, max_seqlen_batch_q),
-        dtype=torch.float,
-        device=query.device,
-    )
-
-    if return_debug_mask:
-        blocksize_c = 128 if head_dim > 64 else 256
-        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
-        if max_seqlen_batch_k <= 128:
-            max_seqlen_k = 128
-        elif max_seqlen_batch_k <= 256:
-            max_seqlen_k = 256
-        debug_mask = torch.empty(
-            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
-            dtype=query.dtype,
-            device=query.device,
-        )
-    else:
-        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
-
-    # See Note [Seed and Offset]:
-    return (
-        attention,
-        logsumexp,
-        torch.empty((), dtype=torch.long, device="meta"),
-        torch.empty((), dtype=torch.long, device="meta"),
-        debug_mask,
-    )
-
-
 @register_meta(
     [
         aten._flash_attention_backward,

@@ -5432,57 +5266,6 @@ def meta__flash_attention_backward(
     return grad_query, grad_key, grad_value


-@register_meta(
-    [
-        aten._efficient_attention_forward,
-    ]
-)
-def meta__efficient_attention_forward(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    bias: Optional[Tensor],
-    cu_seqlens_q: Optional[Tensor],
-    cu_seqlens_k: Optional[Tensor],
-    max_seqlen_q: Optional[int],
-    max_seqlen_k: Optional[int],
-    dropout_p: float,
-    custom_mask_type: int,
-    compute_log_sumexp: bool = False,
-    scale: Optional[float] = None,
-    causal_diagonal: Optional[Tensor] = None,
-    seqlen_k: Optional[Tensor] = None,
-):
-    B = query.size(0)
-    M = query.size(1)
-    N = key.size(1)
-    num_heads = query.size(-2)
-    K = query.size(-1)
-    Kv = value.size(-1)
-
-    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
-
-    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
-    actual_max_seqlen_q = M
-    if cu_seqlens_q is not None:
-        assert max_seqlen_q is not None
-        actual_max_seqlen_q = max_seqlen_q
-    logsumexp_dim = (
-        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
-    )
-    logsum_exp = torch.empty(
-        (logsumexp_batch_dim, num_heads, logsumexp_dim),
-        dtype=torch.float,
-        device=query.device,
-    )
-
-    # See Note [Seed and Offset]:
-    seed = torch.empty((), dtype=torch.long, device="meta")
-    offset = torch.empty((), dtype=torch.long, device="meta")
-
-    return res, logsum_exp, seed, offset, M, N
-
-
 @register_meta(
     [
         aten._efficient_attention_backward,
Changes to `fake_impls.py` (the implementations are added back as fake-tensor op impls; outputs are allocated on the meta device and wrapped in FakeTensors that carry the real device):

@@ -2,6 +2,7 @@

 import functools
 import itertools
+import math
 import sys
 from typing import Callable, Union


@@ -595,6 +596,259 @@ def conv(fake_mode, func, *args, **kwargs):
     )


+@register_op_impl(aten._scaled_dot_product_flash_attention.default)
+def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
+    _, kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    query = kwargs["query"]
+    key = kwargs["key"]
+    return_debug_mask = kwargs["return_debug_mask"]
+    # unused: value, dropout_p, is_causal, scale
+
+    def convert_tensor(t, device):
+        return FakeTensor(fake_mode, t, device)
+
+    batch_size = query.size(0)
+    num_heads = query.size(1)
+    max_seqlen_batch_q = query.size(2)
+    head_dim = query.size(3)
+    max_seqlen_batch_k = key.size(2)
+
+    query_t = query.transpose(1, 2)
+    # empty_like already returns a fake tensor so we don't need to convert it
+    attention = torch.empty_like(query_t).transpose(1, 2)
+    logsumexp = convert_tensor(
+        torch.empty(
+            (batch_size, num_heads, max_seqlen_batch_q),
+            dtype=torch.float,
+            device="meta",
+        ),
+        device=query.device,
+    )
+
+    if return_debug_mask:
+        blocksize_c = 128 if head_dim > 64 else 256
+        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
+        if max_seqlen_batch_k <= 128:
+            max_seqlen_k = 128
+        elif max_seqlen_batch_k <= 256:
+            max_seqlen_k = 256
+        debug_mask = convert_tensor(
+            torch.empty(
+                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
+                dtype=query.dtype,
+                device="meta",
+            ),
+            device=query.device,
+        )
+    else:
+        debug_mask = convert_tensor(
+            torch.empty(0, dtype=query.dtype, device="meta"),
+            query.device,
+        )
+
+    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
+    # capturing or not, but at the time of tracing we don't know if we
+    # are going to use cudagraphs or not, so we return meta tensors here
+    # it's possible we'll need to have some special handling in inductor for sdpa
+
+    return (
+        attention,
+        logsumexp,
+        None,
+        None,
+        max_seqlen_batch_q,
+        max_seqlen_batch_k,
+        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
+        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
+        debug_mask,
+    )
+
+
+@register_op_impl(aten._scaled_dot_product_efficient_attention.default)
+def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs):
+    _, kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    query = kwargs["query"]
+    key = kwargs["key"]
+    value = kwargs["value"]
+    compute_log_sumexp = kwargs["compute_log_sumexp"]
+    # unused: attn_bias, dropout_p, is_causal, scale
+
+    def convert_tensor(t, device):
+        return FakeTensor(fake_mode, t, device)
+
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+
+    B = query.size(0)
+    M = query.size(1)
+    N = key.size(1)
+    num_heads = query.size(-2)
+    K = query.size(-1)
+    Kv = value.size(-1)
+
+    res = convert_tensor(
+        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
+        query.device,
+    )
+
+    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
+    logsum_exp = convert_tensor(
+        torch.empty(
+            (B, num_heads, logsumexp_dim),
+            dtype=torch.float,
+            device="meta",
+        ),
+        query.device,
+    )
+
+    res = res.transpose(1, 2)
+
+    # See Note [Seed and Offset]:
+    seed = convert_tensor(
+        torch.empty((), dtype=torch.long, device="meta"), query.device
+    )
+    offset = convert_tensor(
+        torch.empty((), dtype=torch.long, device="meta"), query.device
+    )
+
+    return res, logsum_exp, seed, offset
+
+
+@register_op_impl(aten._flash_attention_forward.default)
+def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
+    _, kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    query = kwargs["query"]
+    key = kwargs["key"]
+    cum_seq_q = kwargs["cum_seq_q"]
+    cum_seq_k = kwargs["cum_seq_k"]
+    max_q = kwargs["max_q"]
+    max_k = kwargs["max_k"]
+    return_debug_mask = kwargs["return_debug_mask"]
+    # unused: value, dropout_p, is_causal, scale
+
+    def convert_tensor(t, device):
+        return FakeTensor(fake_mode, t, device)
+
+    # NB: there are two underlying paths:
+    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
+    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
+    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
+    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
+    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
+    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
+    num_heads = query.size(-2)
+    head_dim = query.size(-1)
+
+    # Cuda Path
+    # note: empty_like already returns a fake tensor, we don't need to wrap it
+    attention = torch.empty_like(query)
+    logsumexp = convert_tensor(
+        torch.empty(
+            (batch_size, num_heads, max_seqlen_batch_q),
+            dtype=torch.float,
+            device="meta",
+        ),
+        device=query.device,
+    )
+
+    if return_debug_mask:
+        blocksize_c = 128 if head_dim > 64 else 256
+        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
+        if max_seqlen_batch_k <= 128:
+            max_seqlen_k = 128
+        elif max_seqlen_batch_k <= 256:
+            max_seqlen_k = 256
+        debug_mask = convert_tensor(
+            torch.empty(
+                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
+                dtype=query.dtype,
+                device="meta",
+            ),
+            query.device,
+        )
+    else:
+        debug_mask = convert_tensor(
+            torch.empty(0, dtype=query.dtype, device="meta"),
+            query.device,
+        )
+
+    # See Note [Seed and Offset]:
+    return (
+        attention,
+        logsumexp,
+        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
+        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
+        debug_mask,
+    )
+
+
+@register_op_impl(aten._efficient_attention_forward.default)
+def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
+    _, kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    query = kwargs["query"]
+    key = kwargs["key"]
+    value = kwargs["value"]
+    cu_seqlens_q = kwargs["cu_seqlens_q"]
+    max_seqlen_q = kwargs["max_seqlen_q"]
+    compute_log_sumexp = kwargs["compute_log_sumexp"]
+    # unused: bias, cu_seqlens_k, max_seqlen_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k
+
+    def convert_tensor(t, device):
+        return FakeTensor(fake_mode, t, device)
+
+    B = query.size(0)
+    M = query.size(1)
+    N = key.size(1)
+    num_heads = query.size(-2)
+    K = query.size(-1)
+    Kv = value.size(-1)
+
+    res = convert_tensor(
+        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
+        query.device,
+    )
+
+    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
+    actual_max_seqlen_q = M
+    if cu_seqlens_q is not None:
+        assert max_seqlen_q is not None
+        actual_max_seqlen_q = max_seqlen_q
+    logsumexp_dim = (
+        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
+    )
+    logsum_exp = convert_tensor(
+        torch.empty(
+            (logsumexp_batch_dim, num_heads, logsumexp_dim),
+            dtype=torch.float,
+            device="meta",
+        ),
+        query.device,
+    )
+
+    # See Note [Seed and Offset]:
+    seed = convert_tensor(
+        torch.empty((), dtype=torch.long, device="meta"), query.device
+    )
+    offset = convert_tensor(
+        torch.empty((), dtype=torch.long, device="meta"), query.device
+    )
+
+    return res, logsum_exp, seed, offset, M, N
+
+
 FAST_OP_IMPLEMENTATIONS = {}

Changes to the OpInfo database (`common_methods_invocations.py`), updating skips/xfails for the ops whose meta implementations now live in fake_impls.py:

@@ -14428,7 +14428,14 @@ op_db: List[OpInfo] = [
             DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace',
                          device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),
             DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
-                         device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),),
+                         device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),
+            # registered in fake_impls.py instead of _meta_registrations.py
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=(torch.bfloat16, torch.float16, torch.float32)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=(torch.bfloat16, torch.float16, torch.float32)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace_all_strides"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),),
     ),
     OpInfo(
         'torch.ops.aten._flash_attention_forward',

@@ -14447,6 +14454,11 @@ op_db: List[OpInfo] = [
             # Device mismatch due to philox seed and offset
             DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
             DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
+            # meta implementation is in fake_impls.py instead of being a meta registration
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
             # Checking the scalar value of the philox seed and offset
             DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),

@@ -14481,6 +14493,11 @@ op_db: List[OpInfo] = [
             # Device mismatch due to philox seed and offset
             DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
             DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
+            # meta implementation is in fake_impls.py instead of being a meta registration
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
             # Checking the scaler value of the philox seed and offset
             DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),