mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[custom ops] add default value support for device types (#129792)
Fixes #129371 I think the first case in Issue #129371 is already supported in the current code? Since it takes care of string default values. This PR adds support for device type default values. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129792 Approved by: https://github.com/zou3519
This commit is contained in:
parent
d7680a564b
commit
aa0352ca38
|
|
@ -678,6 +678,29 @@ class TestCustomOp(CustomOpTestCaseBase):
|
|||
"""(Tensor(a0!) x, Tensor(a1!)[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
|
||||
)
|
||||
|
||||
def h(
|
||||
x: Tensor,
|
||||
a: Optional[int] = None,
|
||||
b: float = 3.14,
|
||||
c: bool = True,
|
||||
d: int = 3,
|
||||
e: str = "foo",
|
||||
f: torch.dtype = torch.float,
|
||||
g: torch.dtype = torch.float32,
|
||||
h: torch.dtype = torch.int,
|
||||
i: torch.device = torch.device("cpu:0"),
|
||||
j: torch.device = "cpu",
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
self.assertExpectedInline(
|
||||
infer_schema(h),
|
||||
(
|
||||
"""(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
|
||||
"""ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""
|
||||
),
|
||||
)
|
||||
|
||||
def test_infer_schema_unsupported(self):
|
||||
with self.assertRaisesRegex(ValueError, "varargs"):
|
||||
|
||||
|
|
@ -2439,15 +2462,53 @@ class TestCustomOpAPI(TestCase):
|
|||
f: torch.dtype = torch.float,
|
||||
g: torch.dtype = torch.float32,
|
||||
h: torch.dtype = torch.int,
|
||||
i: torch.device = torch.device("cpu:0"),
|
||||
j: torch.device = "cpu",
|
||||
) -> Tensor:
|
||||
defaults.extend([a, b, c, d, e, f, g, h])
|
||||
defaults.extend([a, b, c, d, e, f, g, h, i, j])
|
||||
return x.clone()
|
||||
|
||||
x = torch.randn(3)
|
||||
f(x)
|
||||
self.assertEqual(
|
||||
defaults,
|
||||
[None, 3.14, True, 3, "foo", torch.float, torch.float32, torch.int],
|
||||
[
|
||||
None,
|
||||
3.14,
|
||||
True,
|
||||
3,
|
||||
"foo",
|
||||
torch.float,
|
||||
torch.float32,
|
||||
torch.int,
|
||||
torch.device("cpu:0"),
|
||||
"cpu",
|
||||
],
|
||||
)
|
||||
default_values = [
|
||||
arg.default_value
|
||||
for arg in torch.ops._torch_testing.f.default._schema.arguments
|
||||
]
|
||||
# enum values taken from c10/core/ScalarType.h
|
||||
type_enum = {
|
||||
"float": 6,
|
||||
"int": 3,
|
||||
}
|
||||
self.assertEqual(
|
||||
default_values,
|
||||
[
|
||||
None,
|
||||
None,
|
||||
3.14,
|
||||
True,
|
||||
3,
|
||||
"foo",
|
||||
type_enum["float"],
|
||||
type_enum["float"],
|
||||
type_enum["int"],
|
||||
torch.device("cpu:0"),
|
||||
torch.device("cpu"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_mutated_error(self):
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
|
|||
default_repr = None
|
||||
if param.default is None or isinstance(param.default, (int, float, bool)):
|
||||
default_repr = str(param.default)
|
||||
elif isinstance(param.default, str):
|
||||
elif isinstance(param.default, (str, torch.device)):
|
||||
default_repr = f'"{param.default}"'
|
||||
elif isinstance(param.default, torch.dtype):
|
||||
dtype_repr = str(param.default)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user