pytorch/tools/codegen/gen_backend_stubs.py
Brian Hirsh 9354a68e7d [codegen] split out backend-specific information from NativeFunction in the model (#57361)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57361

Data model change in the codegen, which splits backend-specific information out of `NativeFunction`

### Overview
Currently in the codegen, native_functions.yaml has backend-specific information about each operator that is encoded directly into the data model, in the `NativeFunction` object. That's reasonable, since the native_functions.yaml is the source of truth for information about an operator, and the data model encodes that information into types.

Now that external backends can use the codegen, though, that information is technically incomplete or inaccurate. In another PR, I tried patching the information on the `NativeFunction` object with the additional external information, by updating the `dispatch` entry to contain the external backend kernel name and dispatch key.

Instead, this PR tries to split out that information. The `NativeFunction` class contains all information about an operator from native_functions.yaml that's backend-independent and is known never to change regardless of what extra information backends provide. We also build up a backend "index", which is basically a mapping from [backend] -> [backend-specific-metadata]. Reading in an external backend yaml just involves updating that index with the new backend.
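A minimal sketch of that index, assuming plain strings for backend and operator names (the codegen itself keys on `DispatchKey` and `OperatorName`). `Metadata` and `register_external_backend` are hypothetical names used only for illustration. Registering an external backend only adds a new top-level entry, mirroring the `assert backend_key not in backend_indices` check in `parse_backend_yaml`; the entries built from native_functions.yaml are left alone:

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class Metadata:
    """Hypothetical per-(operator, backend) record; here just the kernel name."""
    kernel: str


# backend name -> (operator name -> backend-specific metadata)
BackendIndices = Dict[str, Dict[str, Metadata]]


def register_external_backend(indices: BackendIndices, backend: str,
                              kernels: Dict[str, str]) -> None:
    """Add one new backend entry in place; never touch an existing backend's entry."""
    assert backend not in indices, f"{backend} already has an index"
    indices[backend] = {op: Metadata(kernel=k) for op, k in kernels.items()}


# Index built from native_functions.yaml (illustrative values only).
indices: BackendIndices = {"CPU": {"add.Tensor": Metadata(kernel="add")}}
cpu_before = indices["CPU"]

# Reading an external backend's yaml just adds another entry.
register_external_backend(indices, "XLA", {"add.Tensor": "add"})
```

After the call, `indices["CPU"]` is still the very same dict it was before, which is the property the split is after: running the codegen with a new backend never changes what an existing backend (or `NativeFunction`) contains.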

There were a few places where `NativeFunction` used the dispatch table directly, which I encoded as properties directly on the `NativeFunction` object (e.g. `is_abstract`). They were mostly about whether or not the operator has a composite kernel, which isn't something that's going to change for any external backend.

This has a few advantages:
- We can more easily re-use the existing logic in `native_functions.py` and `register_dispatch_key.py` for both native and external backends, since they both involve a NativeFunction + a particular backend index
- The data in the data model will be the same regardless of how the codegen is run. Running the codegen with a new external backend doesn't change the data inside of NativeFunction or an existing backend index. It just adds a new index for that backend.
- There are several codegen areas that don't care about backend-specific information: mostly the tracing and autograd codegen. We can reason about the codegen there more easily, knowing that backend-specific info is entirely uninvolved.

An alternative to this split would be to augment the NativeFunction objects with external backend information at the time that we create them. So the external codegen could read both native_functions.yaml and the external backend's yaml at the same time, and construct a `NativeFunction` object with a full dispatch table (including the XLA entry), and the correct setting of structured (taking into account both yamls). One disadvantage to this approach is that NativeFunction objects now contain different stuff depending on how you ran the codegen, and you have to make sure that any changes to the codegen can properly handle all the different variants.

### Data Model Changes
Removed 3 classes, which were used by the external codegen:
- ExternalBackendFunction
- ExternalBackendFunctionsGroup
- ExternalBackendMetadata

And added two new ones:
- BackendIndex
- BackendMetadata

`BackendIndex` contains any info that's specific to that backend, plus a mapping from operator names to backend specific metadata about the operator. One example of backend-specific info that's not operator-dependent is the fact that XLA prefers to implement functional kernels instead of out kernels (and so when they eventually mark an op as structured, they're going to mark the functional op and not the out op).

`BackendMetadata` contains info specific to an (operator, backend) pair. Right now, that's just (a) the name of the kernel, and (b) whether or not that operator is structured.
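A simplified sketch of the two classes. The field names (`dispatch_key`, `use_out_as_primary`, `external`, `index`, `kernel`, `structured`) and the `get_kernel` / `has_kernel` helpers are the ones used by `gen_backend_stubs.py` below, but this is not the real `tools/codegen/model.py` definition: here the helpers take an operator-name string instead of a `NativeFunction` or `NativeFunctionsGroup`, and `dispatch_key` is a plain string.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass(frozen=True)
class BackendMetadata:
    kernel: str        # name of the kernel for this (operator, backend) pair
    structured: bool   # whether this backend implements the op as structured


@dataclass
class BackendIndex:
    dispatch_key: str          # a plain string here; the codegen uses a DispatchKey
    use_out_as_primary: bool   # backend-wide choice, independent of any single operator
    external: bool             # True for out-of-tree backends such as XLA
    index: Dict[str, BackendMetadata] = field(default_factory=dict)

    def get_kernel(self, op_name: str) -> Optional[BackendMetadata]:
        # None means this backend has no kernel for the operator.
        return self.index.get(op_name)

    def has_kernel(self, op_name: str) -> bool:
        return op_name in self.index


xla = BackendIndex(
    dispatch_key="XLA",
    use_out_as_primary=False,  # XLA prefers functional kernels over out= kernels
    external=True,
    index={"add.Tensor": BackendMetadata(kernel="add", structured=False)},
)
```

The split keeps `use_out_as_primary` on the index (it describes the backend as a whole) and `structured` on the metadata (it describes one operator on one backend).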

### Questions
I wanted to get this PR up earlier so I could get feedback, but there are a few things I want to call out:

**Dealing with `structured`.**
This PR separates out the notion of `structured` into two bits of information:
- Does [operator] have a meta() function. This is backend-agnostic, and is represented by the `structured` property on `NativeFunction`, same as before. This is used, e.g., to decide what signatures to add to `MetaFunctions.h`.
- Does [operator, backend] have an impl() function. This is backend dependent; even though technically all in-tree backends are forced to write impl() functions for an operator when we port the op to structured in native_functions.yaml, out-of-tree backends can decide to opt in independently. This is represented as a property on `BackendMetadata`. This is used in most other cases, e.g. in `RegisterDispatchKey` when we're deciding whether or not to generate a structured or unstructured wrapper.

I also baked `is_structured_dispatch_key` directly into each BackendIndex. So for operators marked "structured" in native_functions.yaml, their corresponding CPU/CUDA BackendIndex entries will be marked structured, and all others (except for potentially external backends) will not.
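A rough sketch of how the two bits relate for one in-tree operator, assuming (as described above) that only the CPU and CUDA entries inherit the operator-level `structured` flag. `per_backend_structured`, `STRUCTURED_DISPATCH_KEYS` and the `my_op_out` kernel are hypothetical names for illustration; the real codegen builds full `BackendIndex` objects instead of tuples.

```python
from typing import Dict, Tuple

# Assumption for this sketch: the in-tree keys whose entries inherit "structured".
STRUCTURED_DISPATCH_KEYS = {"CPU", "CUDA"}


def per_backend_structured(
    op_structured: bool, dispatch: Dict[str, str]
) -> Dict[str, Tuple[str, bool]]:
    """Map each dispatch key of one in-tree op to (kernel name, structured for that key).

    op_structured is the backend-agnostic bit ("does the op have a meta() function?");
    the per-key bit answers "does (op, backend) have an impl() function?".
    """
    return {
        key: (kernel, op_structured and key in STRUCTURED_DISPATCH_KEYS)
        for key, kernel in dispatch.items()
    }


# Illustrative dispatch table for a hypothetical structured op.
table = per_backend_structured(
    True, {"CPU": "my_op_out", "CUDA": "my_op_out", "SparseCPU": "my_op_sparse_out"})
```

Here `table["CPU"]` and `table["CUDA"]` come out structured while `table["SparseCPU"]` does not, and an operator that isn't structured at all yields `False` for every key. External backends are not covered by this rule; they set the per-backend bit themselves.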

I ended up trying to deal with `structured` in this change since it's technically backend dependent (XLA can opt kernels into structured separately from in-tree ops), but that may have been too ambitious: it's technically not relevant until we actually add support for structured external kernels. If it's not clear that this is the right path for dealing with structured and we want to push that off, I'm fine with backing out the bits of this PR that make `structured` backend-dependent. I don't see anything *too* controversial related to structured in the change, but I tried to call out the relevant areas in the comments.

**Localizing the fact that external backends follow Dispatcher convention.**
Another thing that's sort of backend specific that I didn't totally address in this PR is the fact that in-tree backends follow the Native API while external backends follow the Dispatcher API. I painted over that in `native_functions.py` by adding a helper, `kernel_signature`, that takes in a native function and gives you the "correct" signature for the specified backend: NativeSignature for in-tree backends, and DispatcherSignature for out-of-tree backends. In order to make that fully usable though, we'll need `NativeSignature` and `DispatcherSignature` to have matching interfaces. I didn't bother with that in this PR, which is why `gen_external_aten_fallbacks.py` still has a bunch of direct references to the dispatcher API. I'm thinking of adding it in a later PR, but wanted to see if anyone has other opinions.
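A hypothetical sketch of what "matching interfaces" could look like: a `typing.Protocol` that both signature classes satisfy, so that a caller of `kernel_signature` never needs to know which API a backend follows. The classes below are toy stand-ins, not the codegen's `NativeSignature` / `DispatcherSignature`, and the `name()` / `decl()` methods and the simplified `kernel_signature(external, ...)` parameters are assumptions. Both toy classes render the same declaration purely to keep the sketch short; the real Native and Dispatcher APIs are not always identical, which is exactly why a shared interface is needed.

```python
from dataclasses import dataclass
from typing import Protocol, Tuple


class KernelSignature(Protocol):
    """The shared interface callers would rely on (hypothetical)."""
    def name(self) -> str: ...
    def decl(self) -> str: ...


@dataclass(frozen=True)
class NativeSignature:
    kernel: str                  # in-tree: kernel name from the op's dispatch entry
    args: Tuple[str, ...]

    def name(self) -> str:
        return self.kernel

    def decl(self) -> str:
        return f"at::Tensor {self.name()}({', '.join(self.args)});"


@dataclass(frozen=True)
class DispatcherSignature:
    op_name: str                 # external: name derived from the operator itself
    args: Tuple[str, ...]

    def name(self) -> str:
        return self.op_name

    def decl(self) -> str:
        return f"at::Tensor {self.name()}({', '.join(self.args)});"


def kernel_signature(external: bool, kernel_name: str,
                     args: Tuple[str, ...]) -> KernelSignature:
    # Out-of-tree backends follow the Dispatcher API; in-tree backends the Native API.
    if external:
        return DispatcherSignature(kernel_name, args)
    return NativeSignature(kernel_name, args)


sig = kernel_signature(True, "add", ("const at::Tensor & self", "const at::Tensor & other"))
```

Code such as `gen_external_aten_fallbacks.py` could then call `sig.decl()` without any direct reference to the dispatcher API.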

Maybe `is_external()` shouldn't even be a property on the BackendIndex, and anything the codegen does that requires asking for that information should just be better abstracted away.

**Thoughts on the `BackendIndex` / `BackendMetadata` breakdown.**
One thing that's annoying right now is that to query for various pieces of metadata, you call helper functions like `backend_index.structured(f)`, which queries that particular backend and tells you if that specific `NativeFunctionsGroup` is structured for that backend. It has to return an `Optional[bool]` though, since you have to handle the case where that operator doesn't have a kernel for that backend at all. So users of those helpers end up with a bunch of optionals that they need to unpack, even if they know at some point that the result isn't None. I think it would be easier instead to just store the NativeFunction object as a field directly on the BackendMetadata. Curious if there are any other opinions on a better way to model it though.
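A toy comparison of the two styles, assuming stripped-down hypothetical classes (`NativeFunction` here holds only a name, and `BackendMetadata.func` is the proposed new field, not an existing one). With the current style every caller unpacks an `Optional[bool]`; with the proposed style a caller iterates over the metadata it already has and reads plain fields.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass(frozen=True)
class NativeFunction:          # hypothetical stand-in: just the operator name
    name: str


@dataclass(frozen=True)
class BackendMetadata:
    kernel: str
    structured: bool
    func: NativeFunction       # proposed: the metadata knows which function it describes


@dataclass
class BackendIndex:
    index: Dict[str, BackendMetadata]

    # Current style: the helper returns Optional because the op may have no kernel here.
    def structured(self, f: NativeFunction) -> Optional[bool]:
        m = self.index.get(f.name)
        return None if m is None else m.structured


def structured_ops_current(idx: BackendIndex, funcs: List[NativeFunction]) -> List[str]:
    out = []
    for f in funcs:
        s = idx.structured(f)
        if s is None:          # no kernel on this backend; every caller must handle it
            continue
        if s:
            out.append(f.name)
    return out


def structured_ops_proposed(idx: BackendIndex) -> List[str]:
    # Iterate over the metadata directly: no Optional, and m.func is always present.
    return [m.func.name for m in idx.index.values() if m.structured]
```

Both functions return the same operator names; the second one simply never sees an operator the backend doesn't implement.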

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D28474362

Pulled By: bdhirsh

fbshipit-source-id: 41a00821acf172467d764cb41e771e096542f661
2021-05-17 12:25:35 -07:00


import pathlib
import argparse
import os
import yaml
from collections import namedtuple
from typing import List, Dict, Union, Sequence, Optional
from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml
from tools.codegen.model import (BackendIndex, BackendMetadata, DispatchKey,
                                 NativeFunction, NativeFunctionsGroup, OperatorName)
from tools.codegen.selective_build.selector import SelectiveBuilder
from tools.codegen.utils import Target, concatMap, context
import tools.codegen.dest as dest
import tools.codegen.api.dispatcher as dispatcher

try:
    # use faster C loader if available
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[misc]


# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping)
ParsedExternalYaml = namedtuple('ParsedExternalYaml', ['backend_key', 'autograd_key', 'cpp_namespace', 'backend_indices'])


def parse_backend_yaml(
        backend_yaml_path: str,
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
        backend_indices: Dict[DispatchKey, BackendIndex]
) -> ParsedExternalYaml:

    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
    }

    with open(backend_yaml_path, 'r') as f:
        yaml_values = yaml.load(f, Loader=Loader)
    assert isinstance(yaml_values, dict)

    valid_keys = ['backend', 'cpp_namespace', 'supported', 'autograd']

    backend = yaml_values.pop('backend', None)
    assert backend is not None, 'You must provide a value for "backend"'

    cpp_namespace = yaml_values.pop('cpp_namespace', None)
    assert cpp_namespace is not None, 'You must provide a value for "cpp_namespace"'

    supported = yaml_values.pop('supported', [])
    if supported is None:
        supported = []  # Allow an empty list of supported ops
    assert isinstance(supported, list), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'

    supported_autograd = yaml_values.pop('autograd', [])
    assert isinstance(supported_autograd, list), f'expected "autograd" to be a list, but got: {supported_autograd}'

    assert len(yaml_values.keys()) == 0, \
        f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \
Only the following keys are supported: {", ".join(valid_keys)}'

    def create_backend_index(backend_ops: List[str], dispatch_key: DispatchKey) -> BackendIndex:
        metadata: Dict[OperatorName, BackendMetadata] = {}
        for op in backend_ops:
            op_name = OperatorName.parse(op)
            assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
            # See Note [External Backends Follow Dispatcher API]
            kernel_name = dispatcher.name(native_functions_map[op_name].func)
            # TODO: allow structured external backends later.
            m = BackendMetadata(kernel=kernel_name, structured=False)
            metadata[op_name] = m
        # TODO: currently hardcoding the fact that XLA implements out/inplace in terms of functional ops,
        # this should eventually be toggleable per-backend.
        return BackendIndex(dispatch_key=dispatch_key, use_out_as_primary=False, external=True, index=metadata)

    backend_key: Optional[DispatchKey] = None
    if len(supported) > 0:
        with context(f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'):
            backend_key = DispatchKey.parse(backend)

        backend_idx = create_backend_index(supported, backend_key)
        assert backend_key not in backend_indices
        backend_indices[backend_key] = backend_idx

    autograd_key: Optional[DispatchKey] = None
    if len(supported_autograd) > 0:
        with context(f'The "autograd" key was specified, which indicates that you would like to override \
the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.'):
            autograd_key = DispatchKey.parse(f'Autograd{backend}')

        autograd_idx = create_backend_index(supported_autograd, autograd_key)
        assert autograd_key not in backend_indices
        backend_indices[autograd_key] = autograd_idx

    for g in grouped_native_functions:
        if isinstance(g, NativeFunction):
            forward_kernels = [] if backend_key is None else \
                [m for m in [backend_indices[backend_key].get_kernel(g)] if m is not None]
            backward_kernels = [] if autograd_key is None else \
                [m for m in [backend_indices[autograd_key].get_kernel(g)] if m is not None]
        else:
            forward_kernels = [] if backend_key is None else [m for m in [
                backend_indices[backend_key].get_kernel(f) for f in g.functions()]
                if m is not None]
            backward_kernels = [] if autograd_key is None else [m for m in [
                backend_indices[autograd_key].get_kernel(f) for f in g.functions()]
                if m is not None]

        forward_kernels = [f for f in forward_kernels if f is not None]
        backward_kernels = [f for f in backward_kernels if f is not None]
        assert len(forward_kernels) == 0 or len(backward_kernels) == 0, \
            f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \
autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! \
{forward_kernels[0].kernel} is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".'

    return ParsedExternalYaml(backend_key, autograd_key, cpp_namespace, backend_indices)


def main() -> None:
    parser = argparse.ArgumentParser(description='Generate backend stub files')
    parser.add_argument(
        '-s',
        '--source_yaml',
        help='path to source yaml file containing operator external definitions')
    parser.add_argument(
        '-o', '--output_dir', help='output directory')
    # Note: with type=bool, any non-empty value (including "False") parses as True.
    parser.add_argument(
        '--dry_run', type=bool, default=False, help='run the codegen without writing any output files')
    options = parser.parse_args()

    run(options.source_yaml, options.output_dir, options.dry_run)


def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:

    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=dry_run)

    fm = make_file_manager(output_dir)

    native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
    parsed_yaml = parse_native_yaml(native_yaml_path)
    native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
    grouped_native_functions = get_grouped_native_functions(native_functions)
    parsed_backend_yaml = parse_backend_yaml(source_yaml, grouped_native_functions, backend_indices)
    backend_key = parsed_backend_yaml.backend_key
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    backend_indices = parsed_backend_yaml.backend_indices

    selector = SelectiveBuilder.get_nop_selector()

    # TODO: handle cases when yaml contains zero ops properly in a later PR.
    if backend_key is not None and autograd_key is not None:
        backend_dispatch_key: DispatchKey = backend_key
        autograd_dispatch_key: DispatchKey = autograd_key
        generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'
        fm.write('aten_xla_type.h', lambda: {
            'generated_comment': generated_comment,
            'cpp_namespace': cpp_namespace,
            # Convert to a set first to remove duplicate kernel names.
            # Backends are allowed to repeat kernel names; only generate the declaration once!
            'dispatch_xla_declarations': list(set(concatMap(
                lambda f: dest.compute_native_function_declaration(f, backend_indices[backend_dispatch_key]),
                grouped_native_functions
            ))) + list(set(concatMap(
                lambda f: dest.compute_native_function_declaration(f, backend_indices[autograd_dispatch_key]),
                grouped_native_functions
            ))),
        })

        fm.write('aten_xla_type_default.h', lambda: {
            'generated_comment': generated_comment,
            'cpp_namespace': cpp_namespace,
            'dispatch_aten_fallback_declarations': list(concatMap(
                dest.GenExternalAtenFallback(Target.NAMESPACED_DECLARATION, backend_indices[backend_dispatch_key]),
                [g for g in grouped_native_functions if not backend_indices[autograd_dispatch_key].has_kernel(g)]
            )) + list(concatMap(
                dest.GenExternalAtenFallback(Target.NAMESPACED_DECLARATION, backend_indices[autograd_dispatch_key]),
                [g for g in grouped_native_functions if backend_indices[autograd_dispatch_key].has_kernel(g)]
            )),
        })

        fm.write('aten_xla_type_default.cpp', lambda: {
            'generated_comment': generated_comment,
            'cpp_namespace': cpp_namespace,
            # TODO: after cpu fallbacks are moved to a boxed kernel,
            # merge registrations / definitions into RegisterDispatchKey
            'dispatch_aten_fallback_definitions': list(concatMap(
                dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION, backend_indices[backend_dispatch_key]),
                [g for g in grouped_native_functions if not backend_indices[autograd_dispatch_key].has_kernel(g)]
            )) + list(concatMap(
                dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION, backend_indices[autograd_dispatch_key]),
                [g for g in grouped_native_functions if backend_indices[autograd_dispatch_key].has_kernel(g)]
            )),
            'dispatch_registrations': list(concatMap(
                dest.GenExternalAtenFallback(Target.REGISTRATION, backend_indices[backend_dispatch_key]),
                [g for g in grouped_native_functions if not backend_indices[autograd_dispatch_key].has_kernel(g)]
            )),
            'dispatch_autograd_registrations': list(concatMap(
                dest.GenExternalAtenFallback(Target.REGISTRATION, backend_indices[autograd_dispatch_key]),
                [g for g in grouped_native_functions if backend_indices[autograd_dispatch_key].has_kernel(g)]
            )),
        })


if __name__ == '__main__':
    main()
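The final loop in `parse_backend_yaml` rejects any operator group whose variants are split between the `supported` and `autograd` lists of the backend yaml. A standalone sketch of that same rule, assuming plain operator-name strings and a hypothetical `groups` list of variant names in place of `NativeFunctionsGroup` objects (`check_no_mix_and_match` is an illustrative name, not part of the codegen):

```python
from typing import List, Sequence


def check_no_mix_and_match(groups: Sequence[Sequence[str]],
                           supported: List[str], autograd: List[str]) -> None:
    """Raise AssertionError if any group has variants in both lists."""
    fwd, bwd = set(supported), set(autograd)
    for variants in groups:
        forward = [op for op in variants if op in fwd]
        backward = [op for op in variants if op in bwd]
        # Same condition as the codegen: at most one of the two lists may be non-empty.
        assert not forward or not backward, (
            f'{forward[0]} is listed under "supported", '
            f'but {backward[0]} is listed under "autograd".')


# Functional, in-place and out= variants of one op, plus an unrelated op.
groups = [["add.Tensor", "add_.Tensor", "add.out"], ["relu"]]
check_no_mix_and_match(groups, supported=["add.Tensor", "add.out"], autograd=["relu"])
```

Moving only `add.out` to `autograd` while `add.Tensor` stays under `supported` would trip the assertion, because the group would then be registered to both the backend key and its autograd key.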