mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
In preparation for the next PR up in the stack, which will update "can_auto_functionalize" to support more operators than just the ones that return nothing. We cannot auto-generate FakeTensor kernels for operators that return something, but we can generate functionalization kernels for them. Test Plan: Existing tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/115134 Approved by: https://github.com/bdhirsh ghstack dependencies: #114955, #114956
99 lines
2.6 KiB
Python
99 lines
2.6 KiB
Python
import dataclasses
|
|
import inspect
|
|
import sys
|
|
from typing import Any, Callable, Tuple
|
|
|
|
import torch
|
|
|
|
|
|
@dataclasses.dataclass
class Kernel:
    """A function bundled with a string describing its source location.

    Calling a ``Kernel`` instance forwards every positional and keyword
    argument to ``func`` and returns whatever ``func`` returns.
    """

    # The wrapped callable.
    func: Callable
    # Text identifying the source location associated with ``func``.
    source: str

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped function with the given arguments."""
        target = self.func
        return target(*args, **kwargs)
|
|
|
|
|
|
class RegistrationHandle:
    """Handle that runs a stored callback when ``destroy()`` is called."""

    def __init__(self, on_destroy: Callable):
        # Called with no arguments each time destroy() is invoked.
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        """Run the callback that was supplied at construction time."""
        callback = self._on_destroy
        callback()
|
|
|
|
|
|
def get_source(stacklevel: int) -> str:
    """Get a string that represents the caller.

    Example: "/path/to/foo.py:42"

    Use stacklevel=1 to get the caller's source
    Use stacklevel=2 to get the caller's caller's source
    etc.

    Args:
        stacklevel: how many frames above this function to look.

    Returns:
        A string of the form "<filename>:<lineno>" for the selected frame.
    """
    # context=0: only the filename and line number are used, so skip loading
    # the surrounding source lines that getframeinfo fetches by default.
    frame = inspect.getframeinfo(sys._getframe(stacklevel), context=0)
    return f"{frame.filename}:{frame.lineno}"
|
|
|
|
|
|
def parse_namespace(qualname: str) -> Tuple[str, str]:
    """Split a qualified operator name into ``(namespace, name)``.

    ``qualname`` is split on "::" and must yield exactly two pieces,
    e.g. "aten::sin" gives ("aten", "sin").

    Raises:
        ValueError: if splitting on "::" does not produce exactly two pieces.
    """
    pieces = qualname.split("::")
    if len(pieces) == 2:
        namespace, name = pieces
        return namespace, name
    raise ValueError(
        "Expected `qualname` to be of the form "
        f'"namespace::name", but got {qualname}. '
        "The qualname passed to the torch.library APIs must consist "
        "of a namespace and a name, e.g. aten::sin"
    )
|
|
|
|
|
|
def lookup_op(qualname: str) -> torch._ops.OpOverload:
    """Look up an operator overload under ``torch.ops`` by qualified name.

    ``qualname`` has the form "namespace::name" or
    "namespace::name.overload". When no overload is given, the overload
    named "default" is used.

    Returns:
        The attribute ``torch.ops.<namespace>.<name>.<overload>``, i.e. a
        single overload fetched from the packet, not the packet itself.

    Raises:
        ValueError: if ``qualname`` is rejected by ``parse_namespace``, or
            if the name part contains more than one ".".
    """
    namespace, name = parse_namespace(qualname)
    if "." in name:
        name, overload = name.split(".")
    else:
        overload = "default"
    ns = getattr(torch.ops, namespace)
    packet = getattr(ns, name)
    # The overload is an attribute of the packet, so the result is an
    # OpOverload rather than an OpOverloadPacket.
    return getattr(packet, overload)
|
|
|
|
|
|
def is_builtin(op: torch._ops.OpOverload) -> bool:
    """Return whether ``op`` is in the "aten", "prim", or "prims" namespace.

    ``op`` must be a ``torch._ops.OpOverload``; this is checked with an assert.
    """
    assert isinstance(op, torch._ops.OpOverload)
    builtin_namespaces = {"aten", "prim", "prims"}
    return op.namespace in builtin_namespaces
|
|
|
|
|
|
def is_functional_schema(schema: Any) -> bool:
    """Check if the schema is functional.

    An operator is functional if:
    - it does not mutate any of its inputs
    - it does not return a view on any of its inputs
    - it has at least one return

    ``schema`` may be a schema string or a torchgen ``FunctionSchema``;
    a string is first parsed with ``FunctionSchema.parse``.
    """
    # Lazy import because not all PyTorch builds have torchgen
    from torchgen.model import FunctionSchema, SchemaKind

    assert isinstance(schema, (str, FunctionSchema))
    parsed = FunctionSchema.parse(schema) if isinstance(schema, str) else schema

    if parsed.kind() != SchemaKind.functional:
        return False

    returns = parsed.returns
    # A return with an annotation that is not a write is treated as a
    # non-mutating view, which disqualifies the schema.
    for ret in returns:
        if ret.annotation is not None and not ret.annotation.is_write:
            return False

    # A schema without any returns is not considered functional.
    return bool(returns)
|