no-op torch.library.custom_op APIs on torch.deploy (#139509)

We forgot this case in the previous PR. Fixes
https://github.com/pytorch/pytorch/issues/137536

Test Plan:
- better tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139509
Approved by: https://github.com/williamwen42
Author: rzou
Date: 2024-11-01 12:39:22 -07:00
Committed by: PyTorch MergeBot
Parent: 6dada2136a
Commit: 85c3c4132d
4 changed files with 39 additions and 17 deletions

test/test_custom_ops.py

@@ -490,6 +490,24 @@ def sin_override(x):
 m.impl("sin", sin_override, "CompositeImplicitAutograd")
 x = torch.randn(3)
 y = torch.sin(x)
+
+# should be a no-op
+@torch.library.custom_op("mylib::foobar", mutates_args={})
+def foobar(x: torch.Tensor) -> torch.Tensor:
+    return x.sin()
+
+# should be a no-op
+@foobar.register_fake
+def _(x):
+    return torch.empty_like(x)
+
+# should be a no-op
+m2.define("foobarbaz9996(Tensor x) -> Tensor")
+
+# should be a no-op
+@torch.library.register_fake("mylib4392::foobarbaz9996")
+def _(x):
+    return torch.empty_like(x)
 """
         script = script.strip()
         env = os.environ.copy()
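
For context, a minimal in-process sketch of what this script checks (the PR's actual test runs the script in a subprocess; the monkeypatch and the name sketch_op are illustrative, not from the PR). It relies on the guards reading torch._running_with_deploy() dynamically, which the hunks below show:

    import warnings
    from unittest.mock import patch

    import torch

    # Pretend we are running under torch::deploy (illustrative monkeypatch).
    with patch.object(torch, "_running_with_deploy", return_value=True):
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")

            # Should warn and skip dispatcher registration instead of raising.
            @torch.library.custom_op("mylib::sketch_op", mutates_args={})
            def sketch_op(x: torch.Tensor) -> torch.Tensor:
                return x.sin()

    assert any(issubclass(w.category, RuntimeWarning) for w in caught)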

torch/_library/custom_ops.py

@@ -553,6 +553,10 @@ class CustomOpDef:
         self._setup_context_fn = setup_context
 
     def _register_to_dispatcher(self) -> None:
+        if torch._running_with_deploy():
+            utils.warn_deploy(stacklevel=5)
+            return
+
         lib = self._lib
         schema_str = self._name + self._schema
         cpp_schema = _C.parse_schema(schema_str)
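
The guard above follows a simple degrade-to-no-op pattern: warn at the caller's frame, then return before touching the dispatcher. A standalone sketch of that pattern (_register_my_kernel is a hypothetical name, not part of this PR):

    import warnings

    import torch

    def _register_my_kernel(name: str) -> None:
        # Mirror _register_to_dispatcher: under torch::deploy, warn and no-op
        # rather than raising, so import-time registrations don't crash.
        if torch._running_with_deploy():
            warnings.warn(
                "Python torch.library APIs do nothing under torch::deploy (multipy).",
                RuntimeWarning,
                stacklevel=2,
            )
            return
        print(f"registering {name}")  # stand-in for the real dispatcher call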

torch/_library/utils.py

@@ -2,6 +2,7 @@
 import dataclasses
 import inspect
 import sys
+import warnings
 from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union
 
 import torch
@@ -10,6 +11,15 @@ from torch import _C, _utils_internal
 from torch._ops import OpOverload
 
 
+def warn_deploy(stacklevel=3):
+    warnings.warn(
+        "Python torch.library APIs do nothing under torch::deploy (multipy). "
+        "Please instead use C++ custom operator registration APIs.",
+        RuntimeWarning,
+        stacklevel=stacklevel,
+    )
+
+
 @dataclasses.dataclass
 class Kernel:
     """Models a (function, source location)"""

torch/library.py

@@ -5,7 +5,6 @@ import inspect
 import re
 import sys
 import traceback
-import warnings
 import weakref
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
 from typing_extensions import deprecated
@@ -54,15 +53,6 @@ def fallthrough_kernel():
     raise NotImplementedError("fallthrough_kernel() should never be called.")
 
 
-def _warn_deploy():
-    warnings.warn(
-        "Python torch.library APIs do nothing under torch::deploy (multipy). "
-        "Please instead use C++ custom operator registration APIs.",
-        RuntimeWarning,
-        stacklevel=3,
-    )
-
-
 class Library:
     """
     A class to create libraries that can be used to register new operators or
@@ -92,7 +82,7 @@ class Library:
                 " is a reserved namespace. Please try creating a library with another name.",
             )
         if torch._running_with_deploy():
-            _warn_deploy()
+            _library.utils.warn_deploy()
             return
 
         frame = traceback.extract_stack(limit=3)[0]
@@ -143,7 +133,7 @@ class Library:
            >>> my_lib.define("sum(Tensor self) -> Tensor")
         """
         if torch._running_with_deploy():
-            _warn_deploy()
+            _library.utils.warn_deploy()
             return
 
         # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
@@ -178,7 +168,7 @@ class Library:
     def _register_fake(self, op_name, fn, _stacklevel=1):
         r"""Registers the fake impl for an operator defined in the library."""
         if torch._running_with_deploy():
-            _warn_deploy()
+            _library.utils.warn_deploy()
             return
 
         source = torch._library.utils.get_source(_stacklevel + 1)
@@ -222,7 +212,7 @@ class Library:
             (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
         """
         if torch._running_with_deploy():
-            _warn_deploy()
+            _library.utils.warn_deploy()
             return
 
         qualname = f"{self.ns}::{op_name}"
@@ -243,7 +233,7 @@ class Library:
             >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
         """
         if torch._running_with_deploy():
-            _warn_deploy()
+            _library.utils.warn_deploy()
             return
 
         if dispatch_key == "":
@@ -300,7 +290,7 @@ class Library:
             >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
         """
         if torch._running_with_deploy():
-            _warn_deploy()
+            _library.utils.warn_deploy()
             return
 
         if not callable(fn):
@@ -384,7 +374,7 @@ class Library:
             >>> my_lib.fallback(fallback_kernel, "Autocast")
         """
         if torch._running_with_deploy():
-            _warn_deploy()
+            _library.utils.warn_deploy()
             return
 
         if dispatch_key == "":
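
Taken together, every public entry point in torch/library.py now degrades gracefully under torch::deploy. A hedged usage sketch (the namespace and op name echo the test above; the "FRAGMENT" kind is an assumption). In a regular Python process these calls register normally, while under deploy each one warns and returns early:

    import torch

    m2 = torch.library.Library("mylib4392", "FRAGMENT")
    m2.define("foobarbaz9996(Tensor x) -> Tensor")  # no-op under deploy

    @torch.library.register_fake("mylib4392::foobarbaz9996")  # no-op under deploy
    def _(x):
        return torch.empty_like(x)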