Add meta reg for addcdiv/addcmul ScalarList (#123486)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123486
Approved by: https://github.com/awgu
This commit is contained in:
Jane Xu 2024-04-09 11:59:56 -07:00 committed by PyTorch MergeBot
parent b287dbbc24
commit adcfc2b582
2 changed files with 42 additions and 8 deletions

View File

@ -3337,6 +3337,29 @@ def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
)
@register_meta(
[
aten._foreach_addcdiv_.ScalarList,
aten._foreach_addcmul_.ScalarList,
]
)
def meta__foreach_addcop__scalarlist(self, tensor1, tensor2, scalars):
torch._check(
all(isinstance(l, List) for l in [self, tensor1, tensor2, scalars]),
lambda: (
"_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], List[Scalar], "
f"but got {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}"
),
)
torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
torch._check(
len(self) == len(tensor1)
and len(self) == len(tensor2)
and len(self) == len(scalars),
lambda: "All input tensor lists must have the same length",
)
@register_meta([aten._fused_adam_.default])
def meta__fused_adam_(
self,

View File

@ -9847,14 +9847,19 @@ foreach_pointwise_op_db: List[ForeachFuncInfo] = [
dtypes=all_types_and_complex(),
dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
skips=(
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
# Samples have complex types and inplace only works if the dtype is complex.
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
),
),
ForeachFuncInfo(
@ -9863,13 +9868,19 @@ foreach_pointwise_op_db: List[ForeachFuncInfo] = [
dtypes=all_types_and_complex(),
dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
skips=(
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
# Samples have complex types and inplace only works if the dtype is complex.
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
# fails with div_cpu is not implemented with ComplexHalf
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
),
),