pytorch/tools/codegen/gen.py
Brian Hirsh 0032fa7725 Add a Functionalization pass in core (#64432)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64432

Original PR description + feedback here: https://github.com/pytorch/pytorch/pull/63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.

**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)

**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, all needed by functorch, fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of its aliases. This is the bulk of the work; once it's done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`.
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization we can make specifically for functorch (and strided backends), which only care about mutation removal and not alias removal.
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * You won't see any {view}_copy calls in the current pass's output, because the <replace views with copies> and <replace copies with views> steps have been fused together. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen. (There's a small sketch of the three transformations right after this list.)
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub collapses by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.
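
To make the three stages concrete, here's a hand-written sketch of how a tiny program gets rewritten at each stage. This is pseudocode written as Python comments; `view_copy`, `replace_`, and `view_inverse` are placeholder spellings for illustration, not the pass's actual output:

```
# Original program: `b` aliases `a`, so the mutation to `b` must show up in `a`.
#     b = a.view(-1)
#     b.add_(1)
#
# (a) alias removal: views become {view}_copy ops, and mutations made to an
#     alias are replayed onto the base through the view's inverse ("scatter").
#     b = a.view_copy(-1)
#     b.add_(1)
#     a.replace_(view_inverse(a, b))
#
# (b) mutation removal: every mutation x.f_(...) becomes x.replace_(x.f(...)),
#     which is safe now that there are no aliases left.
#     b = a.view_copy(-1)
#     b = b.add(1)
#     a = view_inverse(a, b)
#
# (c) reapplying views (the functorch-specific optimization): {view}_copy -> {view}.
#     b = a.view(-1)
#     b = b.add(1)
#     a = view_inverse(a, b)
```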

**Main changes from the original PR**

(1) I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and a `reverse` lambda that know how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`).
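
As a rough Python analogue of the lambda-based design (hypothetical names; the real `ViewMeta` is a C++ struct, and the real lambdas are generated from each op's schema):

```
from dataclasses import dataclass
from typing import Callable

import torch

@dataclass
class ViewMeta:
    # Replays the view op on a base tensor.
    forward: Callable[[torch.Tensor], torch.Tensor]
    # Scatters a (possibly mutated) view back into the base tensor.
    reverse: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]

# What the codegen conceptually emits for transpose(dim0, dim1):
def transpose_view_meta(dim0: int, dim1: int) -> ViewMeta:
    return ViewMeta(
        forward=lambda base: base.transpose(dim0, dim1),
        # transpose is its own inverse: transposing the mutated view back
        # recovers a tensor with the base's layout.
        reverse=lambda base, mutated_view: mutated_view.transpose(dim0, dim1),
    )
```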

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a more complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are going to require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To handle this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the `at::native::{view}` op always redispatch straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy-handed, but I'm not sure of a better way to do it.

(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage).
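
A tiny Python model of the intended aliasing semantics (illustrative only; the real logic lives in `FunctionalStorageImpl` in C++):

```
class FunctionalStorage:
    # Stand-in for FunctionalStorageImpl: object identity is the aliasing relation.
    def is_alias_of(self, other: "FunctionalStorage") -> bool:
        return self is other

class FunctionalTensor:
    def __init__(self, storage=None):
        self.storage = storage if storage is not None else FunctionalStorage()

    def view(self) -> "FunctionalTensor":
        # A view shares its base's storage object, so aliasing is observable.
        return FunctionalTensor(storage=self.storage)

a = FunctionalTensor()
b = a.view()
assert b.storage.is_alias_of(a.storage)
assert not FunctionalTensor().storage.is_alias_of(a.storage)
```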

One thing I'm not sure about: should `FunctionalTensorWrapper` set `storage_access_should_throw_` (a) always, (b) never, or (c) only if its wrapped tensor has it set?

Right now I don't set it, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from Python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to Python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later; for now I created a new, "complete" list of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly, with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.
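
For flavor, the declaration-only codegen amounts to something like this (a simplified, hypothetical sketch; the real signatures are derived from each op's schema):

```
def gen_view_inverse_declaration(op: str, extra_args: str = '') -> str:
    # Only the declaration is emitted; the definition is hand-written in
    # FunctionalInverses.cpp, so a newly-added view op fails to build
    # until someone implements its {view}_inverse.
    args = f'const at::Tensor& base, const at::Tensor& mutated_view{extra_args}'
    return f'static at::Tensor {op}_inverse({args});'

# e.g. gen_view_inverse_declaration('transpose', ', int64_t dim0, int64_t dim1')
```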

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but with some noticeable differences: we basically want an `as_strided_embed` for autograd and an `as_strided_scatter` for functionalization. We will also probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections; otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* the autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the values of `base` at those positions (see the sketch after this list). It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes

A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.
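
To illustrate the difference in fill semantics, here's a hand-written sketch of the two `select` inverses (hypothetical helper names, not the generated code):

```
import torch

def select_backward_autograd(base, mutated_view, dim, index):
    # autograd-style: positions outside the selected slice get zeros.
    out = torch.zeros_like(base)
    out.select(dim, index).copy_(mutated_view)
    return out

def select_inverse_functionalization(base, mutated_view, dim, index):
    # functionalization-style: positions outside the selected slice keep
    # the values they already have in `base`.
    out = base.clone()
    out.select(dim, index).copy_(mutated_view)
    return out
```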

**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split this pass up into 3 pieces; it's currently fused, and doesn't do the right thing for Vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that Vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes use user dispatch keys to save dispatch key space, since they'll only be used by functorch anyway.
* Do a deeper dive on perf for the Vulkan/XLA use cases. There are several areas to improve perf, with varying levels of effort required. The simplest one, which I'll probably do regardless, is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for XLA will also be useful to test the view operator coverage.

**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31942093

Pulled By: bdhirsh

fbshipit-source-id: b95598dae35dd1842fa8b1d8d1448332f3afaadf
2021-10-28 10:51:17 -07:00


import os
from typing import List, Dict, Optional, Tuple, Set, Any, Union, Sequence, TypeVar
from typing_extensions import Literal
import yaml
from collections import OrderedDict, defaultdict, namedtuple
import argparse
import pathlib
import json
from dataclasses import dataclass
from tools.codegen.model import (Argument, DispatchKey, FunctionSchema,
                                 Location, NativeFunction,
                                 NativeFunctionsGroup, OperatorName,
                                 BackendIndex, BackendMetadata,
                                 OptionalType, SchemaKind, SelfArgument,
                                 TensorOptionsArguments, Type, Variant,
                                 is_cuda_dispatch_key,
                                 is_generic_dispatch_key,
                                 Tag, BaseOperatorName)
from tools.codegen.api.types import (Binding, CppSignature, CppSignatureGroup,
                                     DispatcherSignature, NativeSignature)
from tools.codegen.api import cpp
import tools.codegen.api.dispatcher as dispatcher
import tools.codegen.api.native as native
import tools.codegen.api.meta as meta
import tools.codegen.api.structured as structured
from tools.codegen.api.translate import translate
from tools.codegen.selective_build.selector import SelectiveBuilder
from tools.codegen.utils import (
    Target, concatMap, context, mapMaybe, YamlDumper, YamlLoader, FileManager, assert_never
)
from tools.codegen.context import (method_with_native_function,
                                   native_function_manager,
                                   with_native_function_and_indices,
                                   with_native_function)
import tools.codegen.dest as dest
from tools.codegen.gen_functionalization_type import (
    gen_functionalization_definition,
    gen_functionalization_registration,
    gen_functionalization_view_inverse_declaration
)
T = TypeVar('T')
# Welcome to the ATen code generator v2! The ATen code generator is
# responsible for parsing native_functions.yaml and then generating
# various generated files (e.g., TypeDefault.cpp) based on the operators
# defined in this file. This means that the code generator knows how to
# parse function schema, and then translate this into various C++ types
# and boilerplate code.
#
# Some things to know about this file when you modify it:
#
# - This file has STRICT mypy typechecking. Typecheck it with
#   `mypy --config mypy-strict.ini` in the root source directory
#
# - Most of the heavy lifting lives in external modules:
#   - 'model' has the data model for native_functions.yaml. The classes
#     in that file represent what you see when you look at
#     a native_functions.yaml
#   - 'api' has conversions for how to translate JIT schema into
#     the various C++ APIs that the codegen interacts with. There
#     are in fact THREE different C++ APIs: the public C++ API,
#     the dispatcher API, and the legacy dispatcher API. See each
#     of these respective files for more information
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# HELPER FUNCTIONS
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # Add 1 so line numbering starts at 1
        mapping['__line__'] = node.start_mark.line + 1
        return mapping

_GLOBAL_PARSE_NATIVE_YAML_CACHE = {}
# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
ParsedYaml = namedtuple('ParsedYaml', ['native_functions', 'backend_indices'])
def parse_native_yaml(path: str) -> ParsedYaml:
    global _GLOBAL_PARSE_NATIVE_YAML_CACHE
    if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
        with open(path, 'r') as f:
            es = yaml.load(f, Loader=LineLoader)
        assert isinstance(es, list)
        rs: List[NativeFunction] = []
        bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
        for e in es:
            assert isinstance(e.get('__line__'), int), e
            loc = Location(path, e['__line__'])
            funcs = e.get('func')
            with context(lambda: f'in {loc}:\n {funcs}'):
                func, m = NativeFunction.from_yaml(e, loc)
                rs.append(func)
                BackendIndex.grow_index(bs, m)
        error_check_native_functions(rs)
        # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
        indices: Dict[DispatchKey, BackendIndex] = defaultdict(lambda: BackendIndex(
            dispatch_key=DispatchKey.Undefined, use_out_as_primary=True, external=False, index={}))
        for k, v in bs.items():
            # All structured in-tree operators are implemented in terms of their out operator.
            indices[k] = BackendIndex(dispatch_key=k, use_out_as_primary=True, external=False, index=v)
        _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = ParsedYaml(rs, indices)
    return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
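
# A hypothetical usage sketch (the path assumes an in-tree checkout):
#   parsed = parse_native_yaml('aten/src/ATen/native/native_functions.yaml')
#   native_functions, backend_indices = parsed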
# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    func_map: Dict[OperatorName, NativeFunction] = {}
    base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)
    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map[f.structured_delegate]
            assert delegate_func.structured, \
                f"{f.func.name} is marked as a structured_delegate pointing to " \
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. " \
                f"Consider adding 'structured=True' to the delegated operator"
        if f.tag is not None and f.tag is Tag.inplace_view:
            base_name = f.func.name.name
            overload_name = f.func.name.overload_name
            assert base_name.inplace, \
                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming " \
                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
            out_of_place_base_name = BaseOperatorName(base_name.base, False, base_name.dunder_method)
            assert len(base_func_map[out_of_place_base_name]) > 0, \
                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding " \
                f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one. "
def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal """
    s = s.replace('\\', '\\\\')
    s = s.replace('"', '\\"')
    s = s.replace('\a', '\\a')
    s = s.replace('\b', '\\b')
    s = s.replace('\f', '\\f')
    s = s.replace('\n', '\\n')
    s = s.replace('\v', '\\v')
    s = s.replace('\t', '\\t')
    return f'"{s}"'
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# C++ CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# Most functions in this section are curried: they consist of a function
# that takes some parameters (e.g., what is to be generated) which itself
# returns a function that actually maps NativeFunction to the code
# to be generated. This pattern makes it convenient to use map, concatMap
# and similar functional combinators.
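#
# For example (illustrative only), a generator like ComputeFunction below is
# constructed once with its configuration, and the resulting callable is then
# mapped over every NativeFunction:
#   gen = ComputeFunction(static_dispatch_backend_index=None)
#   definitions = list(mapMaybe(gen, native_functions))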
def static_dispatch_keys(backend: Optional[BackendIndex]) -> List[DispatchKey]:
    if backend is None:
        return []
    else:
        return [
            backend.dispatch_key,
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeExplicitAutograd
        ]

def static_dispatch_extra_headers(backend: Optional[BackendIndex], skip_tensor_include: bool = False) -> str:
    if skip_tensor_include:
        # See Note [Avoiding Include Cycles In Static Dispatch]
        maybe_inl = '_inl'
    else:
        maybe_inl = ''
    return '\n'.join([
        f'#include <ATen/{dispatch_key}Functions{maybe_inl}.h>' for dispatch_key in static_dispatch_keys(backend)])
def static_dispatch(
        f: NativeFunction, cpp_sig: CppSignature,
        *, method: bool, backend_index: Optional[BackendIndex]
) -> Optional[str]:
    if backend_index is None or f.manual_kernel_registration:
        return None
    target_sig = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False).signature
    name = target_sig.name()
    exprs = translate(cpp_sig.arguments(), target_sig.arguments(), method=method)
    exprs_str = ', '.join(a.expr for a in exprs)

    if f.structured_delegate is not None:
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs_str});'

    if backend_index.has_kernel(f):
        return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs_str});'
    elif f.has_composite_explicit_autograd_kernel:
        return f'return at::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs_str});'
    elif f.has_composite_implicit_autograd_kernel:
        return f'return at::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs_str});'

    return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backend_index.dispatch_key}.");'
# Generates RegisterSchema.cpp. Depending on the selector, either
# all schemas are registered, or only some are (in the case of
# selective build)
@dataclass(frozen=True)
class RegisterSchema:
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if not self.selector.is_native_function_selected(f):
            return None
        return f'm.def({cpp_string(str(f.func))});\n'
# Generates Operators.h and Operators.cpp.
# These provide macros that, given an operator and overload name, allow users
# to access an "un-overloaded" function version of the operator. This
# is useful for extension writers who (1) want to decltype the operator
# and (2) don't want to worry about method-only operators.
@dataclass(frozen=True)
class ComputeOperators:
    target: Union[
        Literal[Target.DECLARATION],
        Literal[Target.DEFINITION]
    ]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        sig = DispatcherSignature.from_schema(f.func)
        name = f.func.name.unambiguous_name()
        call_method_name = 'call'
        redispatch_method_name = 'redispatch'

        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch APIs are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            #     This is helpful for pytorch extenders who would like to decltype() an aten operator
            #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            #     This is more of an implementation detail to avoid #include cycles,
            #     since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            #     This applies to stuff like __dispatch__is_complex(), and add_outf().
            #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
            #     They're implemented as wrappers in Functions.h that call into the actual operators
            #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
            return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
  static {sig.defn(name=call_method_name, is_redispatching_fn=False)};
  static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};
}};"""
        elif self.target is Target.DEFINITION:
            defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})

// aten::{f.func}
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow({name}::name, {name}::overload_name)
      .typed<{name}::schema>();
}}
"""
            for is_redispatching_fn in [False, True]:
                if is_redispatching_fn:
                    dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.name for a in sig.arguments()])
                    dispatcher_call = 'redispatch'
                    method_name = f'{name}::{redispatch_method_name}'
                else:
                    dispatcher_exprs_str = ', '.join([a.name for a in sig.arguments()])
                    dispatcher_call = 'call'
                    method_name = f'{name}::{call_method_name}'
                defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    static auto op = create_{name}_typed_handle();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
            return defns
        else:
            assert_never(self.target)
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    static_dispatch_backend_index: Optional[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if Variant.function not in f.variants:
            return None
        sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            # See Note [The ATen Operators API]
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ', '.join([e.expr for e in exprs])

            static_dispatch_block = static_dispatch(f, sig, method=False, backend_index=self.static_dispatch_backend_index)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
TORCH_API inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
            else:
                return f"""
// aten::{f.func}
TORCH_API inline {sig.decl()} {{
    {static_dispatch_block}
}}
"""
        result = generate_defn(False)
        if sig_group.faithful_signature is not None:
            result += generate_defn(True)

        return result
# Generates TensorBody.h. This file provides the object-oriented (method-based)
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    target: Union[
        Literal[Target.DECLARATION],
        Literal[Target.DEFINITION]
    ]
    static_dispatch_backend_index: Optional[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if Variant.method not in f.variants:
            return None
        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        sig_group = CppSignatureGroup.from_native_function(f, method=True, fallback_binding=f.manual_cpp_binding)

        if self.target is Target.DECLARATION:
            result = f"{sig_group.signature.decl()} const;\n"
            if sig_group.faithful_signature is not None:
                result += f"{sig_group.faithful_signature.decl()} const;\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
            exprs_str = ', '.join([e.expr for e in exprs])

            static_dispatch_block = static_dispatch(f, sig, method=True, backend_index=self.static_dispatch_backend_index)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
            else:
                return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    {static_dispatch_block}
}}
"""
        result = generate_defn(faithful=False)
        if sig_group.faithful_signature is not None:
            result += generate_defn(faithful=True)

        return result
# Generates RedispatchFunctions.h.
# This is similar to the C++ API defined in Functions.h, but provides access
# to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        # We unconditionally generate function variants of the redispatch API.
        # This is mainly because we can namespace functions separately, but not methods,
        sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in exprs])

            return f"""
// aten::{f.func}
TORCH_API inline {sig.decl(is_redispatching_fn=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
}}
"""
        result = generate_defn(False)
        if sig_group.faithful_signature is not None:
            result += generate_defn(True)

        return result
# Generates ATenOpList.cpp, a runtime accessible list of all aten
# operators.
# TODO: This was historically used to help some JIT interop code
# figure out whether or not to treat aten namespace'd operators
# one way or another, we should reevaluate if this is actually needed.
@with_native_function
def compute_aten_op(f: NativeFunction) -> str:
    return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},'
# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
    if not g.structured:
        return None
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ', '.join(a.decl() for a in args)
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        precomputed = g.out.precomputed if g.structured else None

        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_elements = [elem for replace_list in precomputed.replace.values() for elem in replace_list]
            precomputed_template_parameters = [elem.name.upper() for elem in precomputed_elements]
            precomputed_template_params_str = ", ".join(f"bool {param} = false" for param in precomputed_template_parameters)
            precompute_template_decl = f"template <{precomputed_template_params_str}>"

            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]
            precomputed_elements_decl = ";\n".join(
                f"{elem.cpp_type(strip_ref=True)} {elem.name}" for elem in precomputed_elements_with_cpp_types
            )

            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # Generate the signature. The return type will be the same
                # as the type of `this` but with the template parameter
                # corresponding to the element set by this method set to true.
                # The assert generated below will ensure that this template
                # parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i] + ["true"] + precomputed_template_parameters[i + 1:]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(strip_ref=True)
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"

                # Generate an assert which checks that the
                # template parameter corresponding to the precomputed
                # element that is set by this method is false on the
                # class corresponding to the object that `this` points to.
                # This ensures that each element can be set only once.
                assert_msg = f"\"{precomputed_elements[i].name} already set\""
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"

                # Generate the new object construction block. All state
                # except the element that this method sets is copied from the
                # object that `this` points to. The value for the element that
                # the method sets is taken from a method parameter.
                construction_stmts = []
                construction_stmts.append(f"{return_ty} ret;")
                for j, elem in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{elem.name} = value;")
                    else:
                        construction_stmts.append(f"ret.{elem.name} = this->{elem.name};")
                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)

                setter_methods.append(f"""
{signature} {{
{assert_stmt}
{construction_block}
}}
""")
            setter_methods_decl = "\n".join(setter_methods)

            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(["true"] * len(precomputed_template_parameters))
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
{precompute_template_decl}
struct TORCH_API precompute_out {{
{setter_methods_decl}
{precomputed_elements_decl};
}};"""
        else:
            meta_return_typedef = ""
            precomputed_decl = ""

        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
{precomputed_decl}
{meta_return_typedef}
{meta_return} meta({args_str});
}};
"""
# Generates RegisterBackendSelect.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
# be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    target: Union[
        Literal[Target.DEFINITION],
        Literal[Target.REGISTRATION]
    ]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if str(f.func.name.name).endswith('_like') or str(f.func.name.name).startswith('new_'):
            return None
        name = native.name(f.func)
        native_sig = NativeSignature(f.func)

        if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_sig.arguments()):
            return None

        if not self.selector.is_native_function_selected(f):
            return None

        native_tensor_args = [
            a for a in native_sig.arguments()
            if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
        ]
        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        sig: Union[NativeSignature, DispatcherSignature]
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently
            # The first case could probably be improved though- it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
            if native_tensor_args:
                tensor_args = ', '.join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                compute_dk = f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
      .typed<{dispatcher_sig.type()}>();
  {compute_dk}
  return op.redispatch(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        else:
            assert_never(self.target)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# YAML CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def format_yaml(data: object) -> str:
    # Ignore alias in Dumper
    YamlDumper.ignore_aliases = lambda self, data: True  # type: ignore[assignment]

    # Support serializing OrderedDict
    def dict_representer(dumper: Any, data: Any) -> Any:
        return dumper.represent_dict(data.items())
    YamlDumper.add_representer(OrderedDict, dict_representer)  # type: ignore[no-untyped-call]
    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
    # width=1e9 turns off optional line breaks and improves
    # the portability of the outputted yaml.
    return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return]
# For some reason, some defaults we write to YAML are written as native
# YAML objects, rather than doing them uniformly as strings. This
# function detects those cases and converts them into native Python
# objects.
def pythonify_default(s: str) -> object:
    if s == 'true':
        return True
    elif s == 'false':
        return False

    try:
        return int(s)
    except ValueError:
        try:
            return float(s)
        except ValueError:
            return s
# What is a dynamic type? Over time, the semantic meaning of
# dynamic type has degraded to meaninglessness (in the old days,
# it captured dtype-ness of types, but that has gone away with
# the removal of TH). These days, it's mostly the same thing as
# the C++ API argument type, except that Tensor and Tensor?
# arguments simply present as Tensor.
#
# TODO: Get rid of dynamic_type, after getting tools/autograd
# to use the new codegen framework
def dynamic_type(t: Type) -> str:
    if isinstance(t, OptionalType):
        return dynamic_type(t.elem)
    # Note we don't use t.is_tensor_like() here because it would
    # also include Tensor[]
    if str(t) == 'Tensor':
        return 'at::Tensor'
    return cpp.argumenttype_type(t, mutable=False, binds='__placeholder__').cpp_type()

def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
    # This is written out explicitly to ensure that Tensor and
    # namespace are put into the list in the right order
    method_of = ['Type']
    if Variant.method in variants:
        method_of.append('Tensor')
    if Variant.function in variants:
        method_of.append('namespace')
    return method_of
def compute_returns_yaml(f: NativeFunction) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs. It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'. So when we process
    # arguments, we need a way to get at the corresponding return. At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: Dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
    returns = []
    for i, (r, name) in enumerate(zip(f.func.returns, names)):
        ret = {
            'dynamic_type': dynamic_type(r.type),
            'name': name,
            'type': cpp.return_type(r).cpp_type(),
        }
        if r.name:
            # See Note [name and field_name]
            ret['field_name'] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.arguments.out[i].name] = r.name
        returns.append(ret)

    return returns, name_to_field_name
# arguments in yaml roughly correspond to the public C++ API
def compute_cpp_argument_yaml(cpp_a: Binding, *, schema_order: bool, kwarg_only_set: Set[str],
                              out_arg_set: Set[str], name_to_field_name: Dict[str, str]) -> object:
    if isinstance(cpp_a.argument, TensorOptionsArguments):
        arg: Dict[str, object] = {
            'annotation': None,
            'dynamic_type': 'at::TensorOptions',
            'is_nullable': False,
            'name': cpp_a.name,
            'type': cpp_a.type,
            'kwarg_only': True,
        }
        if cpp_a.default is not None:
            arg['default'] = cpp_a.default
        return arg
    elif isinstance(cpp_a.argument, SelfArgument):
        raise AssertionError()
    elif isinstance(cpp_a.argument, Argument):
        return compute_argument_yaml(
            cpp_a.argument, schema_order=schema_order,
            kwarg_only_set=kwarg_only_set, out_arg_set=out_arg_set, name_to_field_name=name_to_field_name)

def compute_argument_yaml(a: Argument, *, schema_order: bool, kwarg_only_set: Set[str],
                          out_arg_set: Set[str], name_to_field_name: Dict[str, str]) -> object:
    arg: Dict[str, object] = {
        'annotation': str(a.annotation) if a.annotation else None,
        'dynamic_type': dynamic_type(a.type),
        'is_nullable': a.type.is_nullable(),
        'name': a.name,
        'type': cpp.argument_type(a, binds="__placeholder__").cpp_type(),
    }
    if a.default is not None:
        arg['default'] = pythonify_default(cpp.default_expr(a.default, a.type))
    if a.name in kwarg_only_set:
        arg['kwarg_only'] = True
    if a.name in out_arg_set:
        arg['output'] = True
        arg['allocate'] = True
        # See Note [name and field_name]
        if a.name in name_to_field_name:
            arg['field_name'] = name_to_field_name[a.name]
    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>)
    l = a.type.is_list_like()
    if l is not None and l.size is not None and str(l.elem) != 'bool':
        arg['size'] = l.size
    return arg
@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    returns, name_to_field_name = compute_returns_yaml(f)

    # These sets are used to conveniently test if an argument is a
    # kwarg-only or out argument
    kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only)
    out_arg_set = set(a.name for a in f.func.arguments.out)

    sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)
    cpp_args = sig_group.signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            cpp_a, schema_order=False,
            kwarg_only_set=kwarg_only_set, out_arg_set=out_arg_set, name_to_field_name=name_to_field_name)
        for cpp_a in cpp_args
    ]

    schema_order_jit_arguments = list(f.func.schema_order_arguments())

    schema_order_arguments = [
        compute_argument_yaml(
            a, schema_order=True,
            kwarg_only_set=kwarg_only_set, out_arg_set=out_arg_set, name_to_field_name=name_to_field_name)
        for a in schema_order_jit_arguments
    ]

    cpp_schema_order_types = [
        # NB: method here doesn't matter
        r.type for a in schema_order_jit_arguments
        for r in cpp.argument(
            a, method=False, cpp_no_default_args=set(), faithful=False, has_tensor_options=False)
    ]

    cpp_returns = cpp.returns_type(f.func.returns).cpp_type()
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"

    is_factory_method = any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args) \
        and Variant.method not in f.variants

    return OrderedDict([
        ('name', cpp.name(f.func)),
        ('operator_name', str(f.func.name.name)),
        ('overload_name', str(f.func.name.overload_name)),
        ('manual_kernel_registration', f.manual_kernel_registration),
        ('category_override', f.category_override if f.category_override is not None else ''),
        ('schema_string', f'aten::{f.func}'),
        ('arguments', arguments),
        ('schema_order_cpp_signature', schema_order_cpp_signature),
        ('schema_order_arguments', schema_order_arguments),
        ('method_of', compute_method_of_yaml(f.variants)),
        ('mode', 'native'),
        ('python_module', '' if f.python_module is None else f.python_module),
        ('returns', returns),
        ('inplace', f.func.name.name.inplace),
        ('is_factory_method', is_factory_method),
        ('abstract', f.is_abstract),
        ('device_guard', f.device_guard),
        ('with_gil', False),
        ('deprecated', False),
        ('has_math_kernel', f.has_composite_implicit_autograd_kernel),
    ])
# See Note [Auto generated composite kernels]
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
    return (f.structured or f.structured_delegate is not None) and \
        (f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace)

@with_native_function_and_indices
def compute_registration_declarations(f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(a.no_default().decl_registration_declarations() for a in args)
    comment_data: Dict[str, str] = {
        'schema': f'aten::{f.func}',
        # TODO: What exactly is the semantics of the 'dispatch' field?
        'dispatch': str({k for k, v in backend_indices.items() if v.has_kernel(f)} != {DispatchKey.CompositeImplicitAutograd}),
        'default': str(f.has_composite_kernel or has_autogenerated_composite_kernel(f))
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
"""
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# RUN IT ALL
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def get_custom_build_selector(
        provided_op_registration_allowlist: Optional[List[str]],
        op_selection_yaml_path: Optional[str]) -> SelectiveBuilder:
    assert not (
        provided_op_registration_allowlist is not None and
        op_selection_yaml_path is not None), (
        "Both provided_op_registration_allowlist and " +
        "op_selection_yaml_path can NOT be provided at the " +
        "same time.")

    op_registration_allowlist: Optional[Set[str]] = None
    if provided_op_registration_allowlist is not None:
        op_registration_allowlist = set(provided_op_registration_allowlist)

    if op_registration_allowlist is not None:
        selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
            op_registration_allowlist,
            True,
            False,
        )
    elif op_selection_yaml_path is not None:
        selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
    else:
        selector = SelectiveBuilder.get_nop_selector()

    return selector

def pre_group_native_functions(
        native_functions: Sequence[NativeFunction]) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]:
    pre_grouped_native_functions: Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]] = defaultdict(dict)
    for f in native_functions:
        d = pre_grouped_native_functions[f.func.signature()]
        assert f.func.kind() not in d
        d[f.func.kind()] = f
    return pre_grouped_native_functions

def get_grouped_native_functions(
        native_functions: Sequence[NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
    def flatten_pre_group(d: Dict[SchemaKind, NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
        r = NativeFunctionsGroup.from_dict(d)
        if r is None:
            return list(d.values())
        else:
            return [r]

    # TODO: how come ValuesView isn't a Sequence lol
    pre_grouped_native_functions = pre_group_native_functions(native_functions)
    return list(concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())))
def main() -> None:
    parser = argparse.ArgumentParser(description='Generate ATen source files')
    parser.add_argument(
        '-s',
        '--source-path',
        help='path to source directory for ATen',
        default='aten/src/ATen')
    parser.add_argument(
        '-o',
        '--output-dependencies',
        help='output a list of dependencies into the given file and exit')
    parser.add_argument(
        '-d', '--install_dir', help='output directory',
        default='build/aten/src/ATen')
    parser.add_argument(
        '--rocm',
        action='store_true',
        help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly')
    # TODO: --op_registration_whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        '--op_registration_whitelist',
        nargs='*',
        help='filter op registrations by the whitelist (if set); '
             'each item is `namespace`::`operator name` without overload name; '
             'e.g.: aten::empty aten::conv2d ...')
    parser.add_argument(
        '--op_selection_yaml_path',
        help='Provide a path to the operator selection (for custom build) YAML '
             'that contains the information about the set of selected operators '
             'and their categories (training, ...). Each operator is either a '
             'full operator name with overload or just a bare operator name. '
             'The operator names also contain the namespace prefix (e.g. aten::)')
    parser.add_argument(
        '--backend_whitelist',
        nargs='*',
        help='filter dispatch backend by the whitelist (if set), '
             'e.g.: CPU CUDA QuantizedCPU ...')
    parser.add_argument(
        '--static_dispatch_backend',
        help='generate static dispatch code for the specific backend (if set)')
    parser.add_argument(
        '--force_schema_registration',
        action='store_true',
        help='force it to generate schema-only registrations for all ops, including '
             'those that are not listed on --op_registration_whitelist')
    options = parser.parse_args()
    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, 'native/native_functions.yaml')
    parsed_yaml = parse_native_yaml(native_yaml_path)
    native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices

    grouped_native_functions = get_grouped_native_functions(native_functions)
    structured_native_functions = [g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)]

    template_dir = os.path.join(options.source_path, "templates")

    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes. If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f'{options.install_dir}/core'
    pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=options.output_dependencies)

    core_fm = make_file_manager(core_install_dir)
    cpu_fm = make_file_manager(options.install_dir)
    cuda_fm = make_file_manager(options.install_dir)

    extra_cuda_headers = '''\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>'''
    if options.rocm:
        extra_cuda_headers = '''\
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>'''
    dispatch_keys = [
        DispatchKey.CPU,
        DispatchKey.SparseCPU,
        DispatchKey.SparseCsrCPU,
        DispatchKey.MkldnnCPU,
        DispatchKey.CUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.QuantizedCPU,
        DispatchKey.QuantizedCUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeExplicitAutograd,
        # Meta is a magic key: it is automatically generated for structured
        # kernels
        DispatchKey.Meta,
    ]
    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
    # for them; this is the set
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.Meta,
    }
    if options.backend_whitelist:
        dispatch_keys = [k for k in dispatch_keys if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist]

    static_dispatch_idx: Optional[BackendIndex] = None
    if options.static_dispatch_backend:
        static_dispatch_idx = backend_indices[DispatchKey.parse(options.static_dispatch_backend)]
    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm

        fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
            'extra_cuda_headers': extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else '',
            'external_backend_headers': '',
            'namespaced_headers': f'#include <ATen/{dispatch_key}Functions.h>' if dispatch_key in functions_keys else '',
            'DispatchKey': dispatch_key,
            'dispatch_namespace': dispatch_key.lower(),
            'dispatch_helpers': dest.gen_registration_helpers(backend_indices[dispatch_key]),
            'dispatch_namespaced_definitions': list(concatMap(
                dest.RegisterDispatchKey(
                    backend_indices[dispatch_key],
                    Target.NAMESPACED_DEFINITION,
                    selector,
                    rocm=options.rocm,
                    cpp_namespace='at::native',
                    class_method_name=None),
                grouped_native_functions
            )),
            'dispatch_anonymous_definitions': list(concatMap(
                dest.RegisterDispatchKey(
                    backend_indices[dispatch_key],
                    Target.ANONYMOUS_DEFINITION,
                    selector,
                    rocm=options.rocm,
                    cpp_namespace='at::native',
                    class_method_name=None),
                grouped_native_functions
            )),
            'dispatch_registrations': list(concatMap(
                dest.RegisterDispatchKey(
                    backend_indices[dispatch_key],
                    Target.REGISTRATION,
                    selector,
                    rocm=options.rocm,
                    cpp_namespace='at::native',
                    class_method_name=None),
                grouped_native_functions
            )),
        })

        if dispatch_key in functions_keys:
            if dispatch_key in static_dispatch_keys(static_dispatch_idx):
                # See Note [Avoiding Include Cycles In Static Dispatch]
                inl_headers = ''
            else:
                inl_headers = f'#include <ATen/{dispatch_key}Functions_inl.h>'

            fm.write_with_template(f'{dispatch_key}Functions.h', 'DispatchKeyFunctions.h', lambda: {
                'dispatch_key': str(dispatch_key),
                'inline_headers_for_nonstatic_build': inl_headers,
            })
            fm.write_with_template(f'{dispatch_key}Functions_inl.h', 'DispatchKeyFunctions_inl.h', lambda: {
                'dispatch_namespace': dispatch_key.lower(),
                'dispatch_namespaced_declarations': list(concatMap(
                    dest.RegisterDispatchKey(
                        backend_indices[dispatch_key],
                        Target.NAMESPACED_DECLARATION,
                        selector,
                        rocm=options.rocm,
                        cpp_namespace='at::native',
                        class_method_name=None),
                    grouped_native_functions
                )),
            })

        del fm
    # BackendSelect is generated specially
    cpu_fm.write('RegisterBackendSelect.cpp', lambda: {
        'backend_select_method_definitions':
            list(mapMaybe(ComputeBackendSelect(Target.DEFINITION, selector), native_functions)),
        'backend_select_function_registrations':
            list(mapMaybe(ComputeBackendSelect(Target.REGISTRATION, selector), native_functions)),
    })

    cpu_fm.write('NativeMetaFunctions.h', lambda: {
        'declarations': list(mapMaybe(compute_meta_function_declaration, structured_native_functions)),
    })

    schema_selector = selector
    if options.force_schema_registration:
        schema_selector = SelectiveBuilder.get_nop_selector()
    cpu_fm.write('RegisterSchema.cpp', lambda: {
        'schema_registrations': list(mapMaybe(RegisterSchema(schema_selector), native_functions)),
    })

    def key_func(fn: NativeFunction) -> str:
        return fn.func.name.unambiguous_name()

    def key_func_grouped(g: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        if isinstance(g, NativeFunction):
            f = g
        else:
            f = g.functional
        return key_func(f)

    cpu_fm.write_sharded(
        'Operators.cpp',
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            'definitions': [ComputeOperators(Target.DEFINITION)(fn)]},
        num_shards=5,
        sharded_keys={'definitions'}
    )
    cpu_fm.write('Operators.h', lambda: {
        'declarations': list(mapMaybe(ComputeOperators(
            Target.DECLARATION), native_functions)),
    })
    cpu_fm.write('Functions.h', lambda: {
        'static_dispatch_extra_headers': static_dispatch_extra_headers(static_dispatch_idx),
        'function_definitions': list(mapMaybe(ComputeFunction(
            static_dispatch_backend_index=static_dispatch_idx), native_functions)),
    })
    cpu_fm.write('Functions.cpp', lambda: {})
    core_fm.write('TensorBody.h', lambda: {
        'static_dispatch_extra_headers': static_dispatch_extra_headers(static_dispatch_idx, skip_tensor_include=True),
        'tensor_method_declarations': list(mapMaybe(ComputeTensorMethod(
            target=Target.DECLARATION, static_dispatch_backend_index=static_dispatch_idx), native_functions)),
        'tensor_method_definitions': list(mapMaybe(ComputeTensorMethod(
            target=Target.DEFINITION, static_dispatch_backend_index=static_dispatch_idx), native_functions)),
    })
    core_fm.write('TensorMethods.cpp', lambda: {})
    cpu_fm.write('RedispatchFunctions.h', lambda: {
        'function_redispatch_definitions': list(mapMaybe(ComputeRedispatchFunction(), native_functions)),
    })
    core_fm.write('ATenOpList.cpp', lambda: {
        'aten_ops': list(mapMaybe(compute_aten_op, native_functions)),
    })
    cpu_fm.write('NativeFunctions.h', lambda: {
        'native_function_declarations': list(concatMap(
            # Convert to a set first to remove duplicate kernel names.
            # Backends are allowed to repeat kernel names; only generate the declaration once!
            lambda f: list(OrderedDict.fromkeys(concatMap(
                lambda backend_idx:
                    dest.compute_native_function_declaration(f, backend_idx),
                backend_indices.values()))),
            grouped_native_functions)),
    })

    cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]))
    cpu_fm.write('RegistrationDeclarations.h', lambda: {
        'registration_declarations': [compute_registration_declarations(f, backend_indices) for f in native_functions],
    })
    # We need to easily map from [inplace_op_name] -> [functional_op] for the functionalization pass,
    # so here I generate a mapping from every operator name to its corresponding functional NativeFunction (if it exists).
    pre_grouped_d: Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]] = pre_group_native_functions(native_functions)
    to_functional_op: Dict[OperatorName, Optional[NativeFunction]] = {
        k: v for d in [
            {f.func.name: pre_grouped_d[func][SchemaKind.functional]
             if SchemaKind.functional in pre_grouped_d[func].keys() else None
             for f in pre_grouped_d[func].values()}
            for func in pre_grouped_d.keys()]
        for k, v in d.items()
    }
    cpu_fm.write_sharded(
        'RegisterFunctionalization.cpp',
        grouped_native_functions,
        key_fn=key_func_grouped,
        env_callable=lambda g: {
            'func_definitions': list(mapMaybe(lambda f: gen_functionalization_definition(
                selector, f, to_functional_op[f.func.name]),
                [g] if isinstance(g, NativeFunction) else g.functions())),
            'func_registrations': list(mapMaybe(lambda f: gen_functionalization_registration(
                selector, f, backend_indices[DispatchKey.CompositeImplicitAutograd]),
                [g] if isinstance(g, NativeFunction) else g.functions()))
        },
        num_shards=4,
        sharded_keys={'func_definitions', 'func_registrations'}
    )
    cpu_fm.write('FunctionalInverses.h', lambda: {
        'view_inverse_declarations': list(mapMaybe(gen_functionalization_view_inverse_declaration, native_functions))
    })

    if options.output_dependencies:
        cpu_fm.write_outputs(options.output_dependencies)
        core_fm.write_outputs(f"{options.output_dependencies}-core")
        cuda_fm.write_outputs(f"{options.output_dependencies}-cuda")

if __name__ == '__main__':
    main()