Revert "[reland][custom ops] infer schema (#130079)"

This reverts commit bef085bdfa.

Reverted https://github.com/pytorch/pytorch/pull/130079 on behalf of https://github.com/izaitsevfb due to depends on #130064 which needs to be reverted ([comment](https://github.com/pytorch/pytorch/pull/130079#issuecomment-2221561483))
This commit is contained in:
PyTorch MergeBot 2024-07-10 21:40:16 +00:00
parent 46c52661bc
commit e14a0f45ed
4 changed files with 10 additions and 57 deletions

View File

@ -44,7 +44,6 @@ via PyTorch's C++ operator registration APIs).
.. autofunction:: register_fake
.. autofunction:: impl_abstract
.. autofunction:: get_ctx
.. autofunction:: infer_schema
.. autofunction:: register_torch_dispatch
Low-level APIs

View File

@ -710,12 +710,6 @@ class TestCustomOp(CustomOpTestCaseBase):
),
)
def foo_impl(x: torch.Tensor) -> torch.Tensor:
return x.sin()
schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")
def test_infer_schema_unsupported(self):
with self.assertRaisesRegex(ValueError, "varargs"):
@ -3305,16 +3299,6 @@ Please use `add.register_fake` to add an fake impl.""",
self.assertEqual(result.device, torch.device("cpu"))
self.assertEqual(result, torch.ones(3))
def test_library_schema_infer(self):
def foo_impl(x: torch.Tensor) -> torch.Tensor:
return x.sin()
schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")
schema = torch.library.infer_schema(foo_impl, mutates_args={})
self.assertExpectedInline(schema, "(Tensor x) -> Tensor")
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_set_kernel_enabled(self):
x = torch.ones(1)

View File

@ -4,48 +4,22 @@ import typing
from typing import List, Optional, Sequence, Union # noqa: F401
import torch # noqa: F401
from torch.utils._exposed_in import exposed_in
from .. import device, dtype, Tensor, types
@exposed_in("torch.library")
def infer_schema(
prototype_function: typing.Callable, mutates_args=(), op_name: Optional[str] = None
) -> str:
r"""Parses the schema of a given function with type hints. The schema is inferred from the
function's type hints, and can be used to define a new operator.
def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
"""Given a function with type hints, parses a schema.
We make the following assumptions:
* None of the outputs alias any of the inputs or each other.
* | String type annotations "device, dtype, Tensor, types" without library specification are
| assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union"
| without library specification are assumed to be typing.*.
* | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown",
| it assumes that all inputs to the operator are being mutated.
We make some assumptions to make our lives easier that correspond to how people
write custom ops in real life:
- none of the outputs alias any of the inputs or each other.
- only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown",
it assumes that all inputs to the operator are being mutated.
- string type annotations "device, dtype, Tensor, types" without library specification
are assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union"
without library specification are assumed to be typing.*.
Callers (e.g. the custom ops API) are responsible for checking these assumptions.
Args:
prototype_function: The function from which to infer a schema, based on its type annotations.
op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the
name is not included in the inferred schema. Note that the input schema to
``torch.library.Library.define`` requires an operator name.
mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function.
Returns:
The inferred schema.
Example:
>>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
>>> return x.sin()
>>>
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
foo(Tensor x) -> Tensor
>>>
>>> infer_schema(foo_impl, mutates_args={})
(Tensor x) -> Tensor
"""
UNKNOWN_MUTATES = "unknown"
sig = inspect.signature(prototype_function)
@ -152,8 +126,6 @@ def infer_schema(
if type(return_annotation) == str:
return_annotation = convert_type_string(return_annotation)
ret = parse_return(return_annotation, error_fn)
if op_name is not None:
return f"{op_name}({', '.join(params)}) -> {ret}"
return f"({', '.join(params)}) -> {ret}"

View File

@ -17,7 +17,6 @@ from torch._library.custom_ops import (
CustomOpDef,
device_types_t,
)
from torch._library.infer_schema import infer_schema # noqa: F401
from torch._ops import OpOverload
@ -31,7 +30,6 @@ __all__ = [
"register_torch_dispatch",
"get_ctx",
"custom_op",
"infer_schema",
]
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered