mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[reland][custom ops] infer schema (#130079)"
This reverts commit bef085bdfa.
Reverted https://github.com/pytorch/pytorch/pull/130079 on behalf of https://github.com/izaitsevfb due to depends on #130064 which needs to be reverted ([comment](https://github.com/pytorch/pytorch/pull/130079#issuecomment-2221561483))
This commit is contained in:
parent
46c52661bc
commit
e14a0f45ed
|
|
@ -44,7 +44,6 @@ via PyTorch's C++ operator registration APIs).
|
|||
.. autofunction:: register_fake
|
||||
.. autofunction:: impl_abstract
|
||||
.. autofunction:: get_ctx
|
||||
.. autofunction:: infer_schema
|
||||
.. autofunction:: register_torch_dispatch
|
||||
|
||||
Low-level APIs
|
||||
|
|
|
|||
|
|
@ -710,12 +710,6 @@ class TestCustomOp(CustomOpTestCaseBase):
|
|||
),
|
||||
)
|
||||
|
||||
def foo_impl(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.sin()
|
||||
|
||||
schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
|
||||
self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")
|
||||
|
||||
def test_infer_schema_unsupported(self):
|
||||
with self.assertRaisesRegex(ValueError, "varargs"):
|
||||
|
||||
|
|
@ -3305,16 +3299,6 @@ Please use `add.register_fake` to add an fake impl.""",
|
|||
self.assertEqual(result.device, torch.device("cpu"))
|
||||
self.assertEqual(result, torch.ones(3))
|
||||
|
||||
def test_library_schema_infer(self):
|
||||
def foo_impl(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.sin()
|
||||
|
||||
schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
|
||||
self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")
|
||||
|
||||
schema = torch.library.infer_schema(foo_impl, mutates_args={})
|
||||
self.assertExpectedInline(schema, "(Tensor x) -> Tensor")
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_set_kernel_enabled(self):
|
||||
x = torch.ones(1)
|
||||
|
|
|
|||
|
|
@ -4,48 +4,22 @@ import typing
|
|||
from typing import List, Optional, Sequence, Union # noqa: F401
|
||||
|
||||
import torch # noqa: F401
|
||||
|
||||
from torch.utils._exposed_in import exposed_in
|
||||
from .. import device, dtype, Tensor, types
|
||||
|
||||
|
||||
@exposed_in("torch.library")
|
||||
def infer_schema(
|
||||
prototype_function: typing.Callable, mutates_args=(), op_name: Optional[str] = None
|
||||
) -> str:
|
||||
r"""Parses the schema of a given function with type hints. The schema is inferred from the
|
||||
function's type hints, and can be used to define a new operator.
|
||||
def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
|
||||
"""Given a function with type hints, parses a schema.
|
||||
|
||||
We make the following assumptions:
|
||||
|
||||
* None of the outputs alias any of the inputs or each other.
|
||||
* | String type annotations "device, dtype, Tensor, types" without library specification are
|
||||
| assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union"
|
||||
| without library specification are assumed to be typing.*.
|
||||
* | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown",
|
||||
| it assumes that all inputs to the operator are being mutates.
|
||||
We make some assumptions to make our lives easier that correspond to how people
|
||||
write custom ops in real life:
|
||||
- none of the outputs alias any of the inputs or each other.
|
||||
- only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown",
|
||||
it assumes that all inputs to the operator are being mutates.
|
||||
- string type annotations "device, dtype, Tensor, types" without library specification
|
||||
are assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union"
|
||||
without library specification are assumed to be typing.*.
|
||||
|
||||
Callers (e.g. the custom ops API) are responsible for checking these assumptions.
|
||||
|
||||
Args:
|
||||
prototype_function: The function from which to infer a schema for from its type annotations.
|
||||
op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the
|
||||
name is not included in the inferred schema. Note that the input schema to
|
||||
``torch.library.Library.define`` requires a operator name.
|
||||
mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function.
|
||||
|
||||
Returns:
|
||||
The inferred schema.
|
||||
|
||||
Example:
|
||||
>>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
|
||||
>>> return x.sin()
|
||||
>>>
|
||||
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
|
||||
foo(Tensor x) -> Tensor
|
||||
>>>
|
||||
>>> infer_schema(foo_impl, mutates_args={})
|
||||
(Tensor x) -> Tensor
|
||||
"""
|
||||
UNKNOWN_MUTATES = "unknown"
|
||||
sig = inspect.signature(prototype_function)
|
||||
|
|
@ -152,8 +126,6 @@ def infer_schema(
|
|||
if type(return_annotation) == str:
|
||||
return_annotation = convert_type_string(return_annotation)
|
||||
ret = parse_return(return_annotation, error_fn)
|
||||
if op_name is not None:
|
||||
return f"{op_name}({', '.join(params)}) -> {ret}"
|
||||
return f"({', '.join(params)}) -> {ret}"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from torch._library.custom_ops import (
|
|||
CustomOpDef,
|
||||
device_types_t,
|
||||
)
|
||||
from torch._library.infer_schema import infer_schema # noqa: F401
|
||||
from torch._ops import OpOverload
|
||||
|
||||
|
||||
|
|
@ -31,7 +30,6 @@ __all__ = [
|
|||
"register_torch_dispatch",
|
||||
"get_ctx",
|
||||
"custom_op",
|
||||
"infer_schema",
|
||||
]
|
||||
|
||||
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user