mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
no-op torch.library.custom_op APIs on torch.deploy (#139509)
We forgot this case in the previous PR. Fixes https://github.com/pytorch/pytorch/issues/137536 Test Plan: - better tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/139509 Approved by: https://github.com/williamwen42
This commit is contained in:
parent
6dada2136a
commit
85c3c4132d
|
|
@ -490,6 +490,24 @@ def sin_override(x):
|
|||
m.impl("sin", sin_override, "CompositeImplicitAutograd")
|
||||
x = torch.randn(3)
|
||||
y = torch.sin(x)
|
||||
|
||||
# should be a no-op
|
||||
@torch.library.custom_op("mylib::foobar", mutates_args={})
|
||||
def foobar(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.sin()
|
||||
|
||||
# should be a no-op
|
||||
@foobar.register_fake
|
||||
def _(x):
|
||||
return torch.empty_like(x)
|
||||
|
||||
# should be a no-op
|
||||
m2.define("foobarbaz9996(Tensor x) -> Tensor")
|
||||
|
||||
# should be a no-op
|
||||
@torch.library.register_fake("mylib4392::foobarbaz9996")
|
||||
def _(x):
|
||||
return torch.empty_like(x)
|
||||
"""
|
||||
script = script.strip()
|
||||
env = os.environ.copy()
|
||||
|
|
|
|||
|
|
@ -553,6 +553,10 @@ class CustomOpDef:
|
|||
self._setup_context_fn = setup_context
|
||||
|
||||
def _register_to_dispatcher(self) -> None:
|
||||
if torch._running_with_deploy():
|
||||
utils.warn_deploy(stacklevel=5)
|
||||
return
|
||||
|
||||
lib = self._lib
|
||||
schema_str = self._name + self._schema
|
||||
cpp_schema = _C.parse_schema(schema_str)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -10,6 +11,15 @@ from torch import _C, _utils_internal
|
|||
from torch._ops import OpOverload
|
||||
|
||||
|
||||
def warn_deploy(stacklevel=3):
|
||||
warnings.warn(
|
||||
"Python torch.library APIs do nothing under torch::deploy (multipy). "
|
||||
"Please instead use C++ custom operator registration APIs.",
|
||||
RuntimeWarning,
|
||||
stacklevel=stacklevel,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Kernel:
|
||||
"""Models a (function, source location)"""
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import inspect
|
|||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
from typing_extensions import deprecated
|
||||
|
|
@ -54,15 +53,6 @@ def fallthrough_kernel():
|
|||
raise NotImplementedError("fallthrough_kernel() should never be called.")
|
||||
|
||||
|
||||
def _warn_deploy():
|
||||
warnings.warn(
|
||||
"Python torch.library APIs do nothing under torch::deploy (multipy). "
|
||||
"Please instead use C++ custom operator registration APIs.",
|
||||
RuntimeWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
|
||||
class Library:
|
||||
"""
|
||||
A class to create libraries that can be used to register new operators or
|
||||
|
|
@ -92,7 +82,7 @@ class Library:
|
|||
" is a reserved namespace. Please try creating a library with another name.",
|
||||
)
|
||||
if torch._running_with_deploy():
|
||||
_warn_deploy()
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
frame = traceback.extract_stack(limit=3)[0]
|
||||
|
|
@ -143,7 +133,7 @@ class Library:
|
|||
>>> my_lib.define("sum(Tensor self) -> Tensor")
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
_warn_deploy()
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
# This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
|
||||
|
|
@ -178,7 +168,7 @@ class Library:
|
|||
def _register_fake(self, op_name, fn, _stacklevel=1):
|
||||
r"""Registers the fake impl for an operator defined in the library."""
|
||||
if torch._running_with_deploy():
|
||||
_warn_deploy()
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
source = torch._library.utils.get_source(_stacklevel + 1)
|
||||
|
|
@ -222,7 +212,7 @@ class Library:
|
|||
(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
_warn_deploy()
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
qualname = f"{self.ns}::{op_name}"
|
||||
|
|
@ -243,7 +233,7 @@ class Library:
|
|||
>>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
_warn_deploy()
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
if dispatch_key == "":
|
||||
|
|
@ -300,7 +290,7 @@ class Library:
|
|||
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
_warn_deploy()
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
if not callable(fn):
|
||||
|
|
@ -384,7 +374,7 @@ class Library:
|
|||
>>> my_lib.fallback(fallback_kernel, "Autocast")
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
_warn_deploy()
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
if dispatch_key == "":
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user