pytorch/torch/library.py
Sherlock Huang eb99c1efce Prefer python meta function over c++ meta function (#87426)
This is a policy update for meta registration. **We now prefer the python meta implementation over the C++ meta function.** This is a flip of the previous policy, where we preferred the C++ meta function over the python meta function when both existed.

Here's the meta registration process:
1. register_meta and register_decomposition will place the python meta/decomp functions into the `global_decomp_table`.  However, they will NOT register them into dispatcher.
2. After global_decomp_table is populated, we will compile an `active_meta_table`. For a given op, we pick the most specific decomp function from `global_decomp_table` in the preference order of Meta > PostAutograd > PreAutograd.
3. We will unconditionally register all of them into the python dispatcher, and register them into the C++ dispatcher unless the op falls into one of the following 3 cases:
- 1. the op is a CompositeImplicitAutograd, and should rely on decomposed op's meta
- 2. the op is a view op, as the MetaTensor doesn't support aliased storage
- 3. the op is in the blocklist (due to UT failures, and we will burn down this list op by op)

Over the long run, we wish to implement all meta functions in python. With this PR, 321 op_overloads will have their cpp meta overridden by a python meta. There are still 400 op_overloads using cpp meta. The exact list can be found here: https://gist.github.com/SherlockNoMad/d20bb736178df8eebd3b054c8bb7cdc5

cc @ngimel @jansel @lezcano @fdrocha @mlazos @soumith @voznesenskym @yanboliang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87426
Approved by: https://github.com/ezyang, https://github.com/jansel
2022-10-25 16:49:02 +00:00

155 lines
6.8 KiB
Python

from ._ops import OpOverload
from typing import Set
import traceback
import torch
import os
__all__ = ["Library", "impl", "define"]

# Registry of every (namespace, operator, DispatchKey) combination that currently has a kernel
# registered from Python. Each entry is the string `namespace + "/" + op_name + "/" + dispatch_key`.
# Keeping this registry lets us reject a second library that tries to override exactly the same
# functionality, which would otherwise leave callers reaching a kernel they did not intend to call.
_impls: Set[str] = set()

# Namespaces in which a "DEF" library may not be created: `prim` is reserved by the TorchScript interpreter.
_reserved_namespaces = ["prim"]
class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can optionally pass in a dispatch keyname if they only want to register
    kernels corresponding to only one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".

    Args:
        ns: library name
        kind: "DEF" or "IMPL" (required; there is no default)
        dispatch_key: PyTorch dispatch key (default: "")

    Raises:
        RuntimeError: if the ``PYTORCH_DISABLE_LIBRARY`` environment variable is set to ``"1"``.
        ValueError: if ``kind`` is neither "DEF" nor "IMPL", or if a "DEF" library is requested
            for a reserved namespace (see ``_reserved_namespaces``).
    """

    def __init__(self, ns, kind, dispatch_key=""):
        if os.environ.get('PYTORCH_DISABLE_LIBRARY', "0") == "1":
            raise RuntimeError("Trying to use torch.library in an environment where it is disabled")
        if kind not in ("IMPL", "DEF"):
            # Format into a single message; passing the pieces as separate exception
            # arguments made the error render as a tuple.
            raise ValueError(f"Unsupported kind: {kind}")
        if ns in _reserved_namespaces and kind == "DEF":
            raise ValueError(f"{ns} is a reserved namespace. Please try creating a library with another name.")

        # The file/line of a frame two levels above __init__ is handed to the C++ library object.
        # NOTE(review): presumably used for diagnostics about where the library was created;
        # confirm limit=3 / index 0 selects the intended frame.
        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
        self.ns = ns
        # Keys this instance added to the module-level `_impls` registry; released in __del__.
        self._op_impls = set()
        self.kind = kind
        self.dispatch_key = dispatch_key

    def __repr__(self):
        return "Library(kind={}, ns={}, dispatch_key={})".format(self.kind, self.ns, self.dispatch_key)

    def define(self, schema, alias_analysis=""):
        r'''Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                inferred from the schema (default behavior; "" or "FROM_SCHEMA") or not ("CONSERVATIVE").
        Returns:
            name of the operator as inferred from the schema.
        Raises:
            RuntimeError: if ``alias_analysis`` is not one of "", "FROM_SCHEMA" or "CONSERVATIVE".

        Example::
            >>> my_lib = Library("foo", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        '''
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError("Invalid alias_analysis type {}".format(alias_analysis))
        return self.m.define(schema, alias_analysis)

    def impl(self, op_name, fn, dispatch_key=''):
        r'''Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.
        Raises:
            TypeError: if ``fn`` is not callable.
            RuntimeError: if ``op_name`` is neither a str nor an OpOverload; if a kernel for the same
                (namespace, operator, dispatch key) was already registered from Python; or if a Meta
                kernel is registered for an operator that has a CompositeImplicitAutograd kernel.

        Example::
            >>> # xdoctest: +SKIP
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        '''
        if not callable(fn):
            raise TypeError("Input function is required to be a callable but found type {}".format(type(fn)))
        if dispatch_key == '':
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != '':
                name = name + '.' + overload_name
        else:
            raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")

        # Registry keys drop any "ns::" qualifier from the name and use this library's namespace instead.
        unqualified_name = name.split("::")[-1]
        key = self.ns + "/" + unqualified_name + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
                               "'s behavior for {} dispatch key and {} namespace.".
                               format(unqualified_name, dispatch_key, self.ns))

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if '::' not in dispatcher_op_name:
                dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'
            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into.")

        self.m.impl(name, dispatch_key, fn)
        _impls.add(key)
        self._op_impls.add(key)

    def __del__(self):
        # __init__ may have raised before `_op_impls` or `m` were assigned, so look them up defensively
        # instead of letting __del__ fail with AttributeError on a partially-constructed object.
        # Releasing the registry keys allows the same kernels to be registered again later.
        for key in getattr(self, '_op_impls', ()):
            _impls.discard(key)
        if hasattr(self, 'm'):
            del self.m
# Decorator form of `Library.impl`: registers the decorated function on `lib` as the kernel for `name`
# (for `dispatch_key`, or the library's own key when left as "") and hands the function back unchanged.
# Note: this decorator API should remain consistent with `Library.impl` API
def impl(lib, name, dispatch_key=""):
    def register(kernel):
        # Registration happens eagerly, at decoration time.
        lib.impl(name, kernel, dispatch_key)
        return kernel

    return register
def define(lib, schema, alias_analysis="", dispatch_key=""):
    """Decorator that defines a new operator on ``lib`` and registers the decorated function as its kernel.

    Args:
        lib: the library to define the operator in.
        schema: function schema of the new operator, forwarded to ``lib.define``.
        alias_analysis: alias analysis kind, forwarded to ``lib.define`` (default: "").
        dispatch_key: dispatch key to register the kernel for, forwarded to ``lib.impl``. The default ""
            makes ``Library.impl`` fall back to the dispatch key ``lib`` was created with, so existing
            callers see no change.

    Returns:
        A decorator that performs the registration and returns the decorated function unchanged.
    """
    def wrap(f):
        # `lib.define` returns the operator name, which is what `lib.impl` expects as its first argument.
        name = lib.define(schema, alias_analysis)
        lib.impl(name, f, dispatch_key)
        return f
    return wrap