Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Add meta reg for addcdiv/addcmul ScalarList (#123486)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123486
Approved by: https://github.com/awgu
This commit is contained in:
parent b287dbbc24
commit adcfc2b582
@@ -3337,6 +3337,29 @@ def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
    )


@register_meta(
    [
        aten._foreach_addcdiv_.ScalarList,
        aten._foreach_addcmul_.ScalarList,
    ]
)
def meta__foreach_addcop__scalarlist(self, tensor1, tensor2, scalars):
    torch._check(
        all(isinstance(l, List) for l in [self, tensor1, tensor2, scalars]),
        lambda: (
            "_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], List[Scalar], "
            f"but got {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}"
        ),
    )
    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
    torch._check(
        len(self) == len(tensor1)
        and len(self) == len(tensor2)
        and len(self) == len(scalars),
        lambda: "All input tensor lists must have the same length",
    )


@register_meta([aten._fused_adam_.default])
def meta__fused_adam_(
    self,
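A minimal usage sketch of what the new meta kernel enables (not part of the commit; it assumes a PyTorch build that includes this registration, and the variable names are illustrative): with the ScalarList overloads registered, the in-place foreach ops can run on "meta" tensors, so shapes and dtypes are propagated without any real compute.

import torch

# One scalar per tensor -> dispatches to the ScalarList overload.
self_list = [torch.empty(3, 4, device="meta") for _ in range(2)]
tensor1 = [torch.empty(3, 4, device="meta") for _ in range(2)]
tensor2 = [torch.empty(3, 4, device="meta") for _ in range(2)]
scalars = [0.5, 2.0]

torch._foreach_addcdiv_(self_list, tensor1, tensor2, scalars)
torch._foreach_addcmul_(self_list, tensor1, tensor2, scalars)
print(self_list[0].shape)  # torch.Size([3, 4]) -- only metadata is touched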
@@ -9847,14 +9847,19 @@ foreach_pointwise_op_db: List[ForeachFuncInfo] = [
        dtypes=all_types_and_complex(),
        dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
            # Samples have complex types and inplace only works if the dtype is complex.
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
        ),
    ),
    ForeachFuncInfo(
@@ -9863,13 +9868,19 @@ foreach_pointwise_op_db: List[ForeachFuncInfo] = [
        dtypes=all_types_and_complex(),
        dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
            # Samples have complex types and inplace only works if the dtype is complex.
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
            # fails with div_cpu is not implemented with ComplexHalf
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
        ),
    ),
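For context on the dtype-restricted inplace expectedFailures above, an illustrative sketch rather than anything taken from the commit: the foreach pointwise samples mix in complex operands, and an in-place op cannot store a complex result into a real-dtype self tensor, which is why the inplace TestMeta variants stay expected-to-fail for the non-complex dtypes. A hypothetical repro:

import torch

real_self = [torch.ones(2)]                        # float32 destination
complex_op = [torch.ones(2, dtype=torch.cfloat)]   # complex operand, as in the samples

try:
    torch._foreach_addcmul_(real_self, complex_op, complex_op, [1.0])
except RuntimeError as err:
    # Roughly: "result type ComplexFloat can't be cast to the desired output type Float"
    print(err)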