[custom ops] add default value support for device types (#129792)

Fixes #129371

I think the first case in Issue #129371 is already supported in the current code? Since it takes care of string default values. This PR adds support for device type default values.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129792
Approved by: https://github.com/zou3519
This commit is contained in:
Shangdi Yu 2024-07-02 23:31:29 +00:00 committed by PyTorch MergeBot
parent d7680a564b
commit aa0352ca38
2 changed files with 64 additions and 3 deletions

View File

@ -678,6 +678,29 @@ class TestCustomOp(CustomOpTestCaseBase):
"""(Tensor(a0!) x, Tensor(a1!)[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
)
def h(
x: Tensor,
a: Optional[int] = None,
b: float = 3.14,
c: bool = True,
d: int = 3,
e: str = "foo",
f: torch.dtype = torch.float,
g: torch.dtype = torch.float32,
h: torch.dtype = torch.int,
i: torch.device = torch.device("cpu:0"),
j: torch.device = "cpu",
) -> None:
pass
self.assertExpectedInline(
infer_schema(h),
(
"""(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
"""ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""
),
)
def test_infer_schema_unsupported(self):
with self.assertRaisesRegex(ValueError, "varargs"):
@ -2439,15 +2462,53 @@ class TestCustomOpAPI(TestCase):
f: torch.dtype = torch.float,
g: torch.dtype = torch.float32,
h: torch.dtype = torch.int,
i: torch.device = torch.device("cpu:0"),
j: torch.device = "cpu",
) -> Tensor:
defaults.extend([a, b, c, d, e, f, g, h])
defaults.extend([a, b, c, d, e, f, g, h, i, j])
return x.clone()
x = torch.randn(3)
f(x)
self.assertEqual(
defaults,
[None, 3.14, True, 3, "foo", torch.float, torch.float32, torch.int],
[
None,
3.14,
True,
3,
"foo",
torch.float,
torch.float32,
torch.int,
torch.device("cpu:0"),
"cpu",
],
)
default_values = [
arg.default_value
for arg in torch.ops._torch_testing.f.default._schema.arguments
]
# enum values taken from c10/core/ScalarType.h
type_enum = {
"float": 6,
"int": 3,
}
self.assertEqual(
default_values,
[
None,
None,
3.14,
True,
3,
"foo",
type_enum["float"],
type_enum["float"],
type_enum["int"],
torch.device("cpu:0"),
torch.device("cpu"),
],
)
def test_mutated_error(self):

View File

@ -100,7 +100,7 @@ def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
default_repr = None
if param.default is None or isinstance(param.default, (int, float, bool)):
default_repr = str(param.default)
elif isinstance(param.default, str):
elif isinstance(param.default, (str, torch.device)):
default_repr = f'"{param.default}"'
elif isinstance(param.default, torch.dtype):
dtype_repr = str(param.default)