mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64432

Original PR description + feedback here: https://github.com/pytorch/pytorch/pull/63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.

**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)

**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass:
  * (a) alias removal, which replaces `{view}` calls with `{view}_copy` calls and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of its aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`.
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), which only care about mutation removal and not alias removal.
* XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
* There are currently no `{view}_copy` operators, because the pass fuses the <replace views with copies> and <replace copies with views> steps together. Later, we'll want to actually implement `{view}_copy` variants of each view operator, probably with codegen.
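To make (a) and (b) concrete, here's a deliberately tiny, hypothetical sketch using plain Python lists instead of real tensors (names like `ToyTensor`, `replace_`, and `second_half` are illustrative only, not the actual APIs). Like the pass's `ViewMeta`, the view carries a `forward`/`reverse` lambda pair, and a mutation on a view is rewritten into its functional form and then replayed back onto the base:

```python
class ToyTensor:
    def __init__(self, data):
        self.data = list(data)

    def add(self, scalar):
        # functional (out-of-place) add
        return ToyTensor(x + scalar for x in self.data)

    def replace_(self, new):
        # swap in a new value without touching any alias
        self.data = list(new.data)

def second_half(base):
    # a toy "view" op: `forward` replays the view from a base,
    # `reverse` scatters a mutated view back into a copy of the base
    mid = len(base.data) // 2
    forward = lambda b: ToyTensor(b.data[mid:])
    def reverse(b, mutated_view):
        out = ToyTensor(b.data)
        out.data[mid:] = mutated_view.data
        return out
    return forward(base), forward, reverse

base = ToyTensor([0, 0, 0, 0])
view, forward, reverse = second_half(base)
view.replace_(view.add(1))          # functionalized form of view.add_(1)
base.replace_(reverse(base, view))  # re-apply the mutation to the base
print(base.data)                    # [0, 0, 1, 1]
```

The real pass does the same thing per-view-op via codegen'd lambdas, as shown in the example codegen output at the bottom of this description.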
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub collapses by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.

**Main changes from the original PR**

(1) I use lambdas instead of a giant enum to handle all of the different views. This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda that know how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`).

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`. This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are going to require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information. It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`.
I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping. To handle this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the `at::native::{view}` op always redispatch straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy-handed, but I'm not sure of a better way to do it.

(4) `FunctionalTensorWrapper` objects accurately report aliasing information. There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage).

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, or (c) only if its wrapped tensor has it set? Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?
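As a toy illustration of the aliasing contract in (4) (hypothetical names, not the real `FunctionalStorageImpl` implementation): wrappers that are views of one another share a single storage object, so `is_alias_of` can be answered by identity:

```python
class ToyStorage:
    # a stand-in for FunctionalStorageImpl: aliasing is just identity
    def is_alias_of(self, other):
        return self is other

class ToyWrapper:
    # a stand-in for FunctionalTensorWrapper
    def __init__(self, storage=None):
        self.storage = storage if storage is not None else ToyStorage()

    def view_of(self):
        # a view shares its base's storage object
        return ToyWrapper(self.storage)

a = ToyWrapper()
b = a.view_of()
c = ToyWrapper()
print(a.storage.is_alias_of(b.storage))  # True: a and b are views of one another
print(a.storage.is_alias_of(c.storage))  # False: c has its own storage
```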
(5) Better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free. I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation. There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢). From some light testing, it looks like they work correctly, with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch), I emit a warning when this happens, and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`. These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op. The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).
I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but with some noticeable differences: we basically want an `as_strided_embed` for autograd and an `as_strided_scatter` for functionalization. We will also probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses `view().copy_()` calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* the autograd implementations of those backwards functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions.
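A toy sketch of that zero-fill vs. base-fill difference, using `select` on a plain Python list (function names here are illustrative, not the real kernels): the autograd-style backward zero-fills the non-selected positions, while the functionalization-style inverse is a "select_scatter" that keeps `base`'s values:

```python
def select_backward(grad_view, base_len, idx):
    # autograd-style: zeros everywhere except the selected slot
    out = [0] * base_len
    out[idx] = grad_view
    return out

def select_inverse(base, mutated_view, idx):
    # functionalization-style: keep base's values, scatter in the update
    out = list(base)
    out[idx] = mutated_view
    return out

print(select_backward(5, 4, 2))            # [0, 0, 5, 0]
print(select_inverse([1, 2, 3, 4], 5, 2))  # [1, 2, 5, 4]
```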
It looks like this currently applies to 6 total ops (since we can ignore composites):
* select
* slice
* diagonal
* as_strided
* split
* split_with_sizes

A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.

**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys.
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes use user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.
**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {
  auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
  ::std::vector<at::Tensor> out;
  {
    at::AutoDispatchBelowFunctionalize guard;
    auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
    out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
    // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
    // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
  }
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
      return base.split(split_size, dim)[mutated_view_idx];
    },
    [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
    }
  );
  at::functionalization::impl::set_view_meta(out, self, view_meta);
  at::AutoDispatchDirectlyToNative native_guard;
  ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
  at::functionalization::impl::set_strides(out, reference_tensor_output);
  return out;
}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  at::functionalization::impl::sync(self);
  at::functionalization::impl::sync(other);
  auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
  auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
  at::Tensor tmp_output;
  {
    at::AutoDispatchBelowFunctionalize guard;
    // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
    tmp_output = at::redispatch::add(ks & c10::after_func_keyset, self_, other_, alpha);
  }
  self.replace_(tmp_output);
  at::functionalization::impl::maybe_add_update(self);
  return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
      return base.transpose(dim0, dim1);
    },
    [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
    }
  );
  at::functionalization::impl::mutate_view_meta(self, view_meta);
  // See Note [Propagating strides in the functionalization pass]
  // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
  // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
  // Its only job is to directly compute the output size/stride/storage_offset metadata.
  at::AutoDispatchDirectlyToNative native_guard;
  at::native::transpose_(self, dim0, dim1);
  return self;
}
```

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31942093

Pulled By: bdhirsh

fbshipit-source-id: b95598dae35dd1842fa8b1d8d1448332f3afaadf
935 lines
36 KiB
Python
# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn, torch._C._fft, torch._C._linalg or torch._C._special objects.
#

# Code tries to stick to the following rules:
#
# - templates should be colocated with the functions that use them.
#   no templates are currently shared between functions, but if that
#   happens, maybe put the template with the first one
#
# - don't use environment dictionaries when calling template.substitute().
#   pass named arguments directly for everything, otherwise it's much too
#   hard to track what's actually being used and by who
#
# - colocate any new hacks/adjustments with existing ones of the same kind.
#   ideally in a data structure rather than code if possible. See e.g.
#   SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
#
# - similarly, conversions from one format to another should ideally happen
#   all at once in a single place.
#
# - no nontrivial nested functions. couple-liners are ok but please no more.
#   especially avoid functions that read/write outer variables defined far away.
#
# - raise RuntimeError instead of asserting, and put as much
#   information as is available into the message. I.e. no need to
#   plumb in new params whose only purpose is to fill out an error
#   message, but use what's there
#

from collections import defaultdict
import itertools
import re
import yaml

from .gen_trace_type import should_trace

from tools.codegen.code_template import CodeTemplate
from tools.codegen.api import cpp
from tools.codegen.api.types import CppSignatureGroup
from tools.codegen.api.python import (PythonArgument, PythonSignature,
                                      PythonSignatureDeprecated,
                                      PythonSignatureGroup,
                                      PythonSignatureNativeFunctionPair,
                                      arg_parser_output_exprs,
                                      argument_type_str, cpp_dispatch_exprs,
                                      cpp_dispatch_target,
                                      dispatch_lambda_args,
                                      dispatch_lambda_exprs,
                                      dispatch_lambda_return_str,
                                      has_tensor_options,
                                      namedtuple_fieldnames, signature)
from tools.codegen.gen import cpp_string, parse_native_yaml
from tools.codegen.context import with_native_function
from tools.codegen.model import (Argument, BaseOperatorName, NativeFunction,
                                 Type, Variant)
from tools.codegen.utils import split_name_params, YamlLoader, FileManager

from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable

#
# declarations blocklist
# We skip codegen for these functions, for various reasons.
# Future PRs will categorize this list and eliminate or hoist
# them out of eager-only codegen.
# See https://github.com/pytorch/pytorch/issues/30788
#

# These functions require manual Python bindings or are not exposed to Python
_SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'is_sparse_csr', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_?sparse_csr_tensor.*',
    '_arange.*', '_range.*', 'linspace.*', 'logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*', '_sparse_dense_add_out',
    'index', 'unique_dim_consecutive',
    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    '_th_.*', '_thnn_.*',
    'arange.*', 'range.*', '_solve.*', '_inverse.*',
    'full(_out)?',
    '_cholesky.*', '_triangular_solve.*', '_qr.*', '_symeig.*', '_svd.*',
    'slice', 'randint(_out)?',
    'item', '_local_scalar_dense', 'to',
    '_to_copy',
    'copy_sparse_to_sparse_', 'copy_',
    'numpy_T', 'matrix_H', 'mT', 'mH',  # these need to be attributes in Python, not functions
    'nonzero(_(out|numpy))?',
    'set_data',
    '.*_overrideable',  # overrideable functions for backend extension
    'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retains_grad', 'set_',
    '_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
    'fake_quantize_per_channel_affine_cachemask',
    '_reshape_alias',
    'replace_',  # only used by the functionalization pass, doesn't need to be exposed to python
]

SKIP_PYTHON_BINDINGS = list(map(lambda pattern: re.compile(rf'^{pattern}$'), _SKIP_PYTHON_BINDINGS))

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor',
    'add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)',
    'sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor',
    'sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)',
    'mul.Scalar(Tensor self, Scalar other) -> Tensor',
    'mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)',
    'div.Scalar(Tensor self, Scalar other) -> Tensor',
    'div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)',
]

@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
    name = cpp.name(f.func)
    for skip_regex in SKIP_PYTHON_BINDINGS:
        if skip_regex.match(name):
            return False

    signature = str(f.func)
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == signature:
            return False

    return True

def get_pycname(name: BaseOperatorName) -> str:
    return f'THPVariable_{name}'

def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool:
    return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0

def is_py_variable_method(f: NativeFunction) -> bool:
    return f.python_module is None and Variant.method in f.variants

def is_py_torch_function(f: NativeFunction) -> bool:
    return f.python_module is None and Variant.function in f.variants

def is_py_nn_function(f: NativeFunction) -> bool:
    return f.python_module == 'nn'

def is_py_fft_function(f: NativeFunction) -> bool:
    return f.python_module == 'fft'

def is_py_linalg_function(f: NativeFunction) -> bool:
    return f.python_module == 'linalg'

def is_py_special_function(f: NativeFunction) -> bool:
    return f.python_module == 'special'

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Main Function
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    # torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp',
        method=False, num_shards=3)

    create_python_bindings(
        fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_special_function, 'torch.special', 'python_special_functions.cpp', method=False)

def group_filter_overloads(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool]
) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
    grouped: Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        if pred(pair.function):
            grouped[pair.function.func.name.name].append(pair)
    return grouped

def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Generates Python bindings to ATen functions"""
    py_methods: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=lambda x: str(x)):
        overloads = grouped[name]
        py_methods.append(method_impl(name, module, overloads, method=method))
        py_method_defs.append(method_def(name, module, overloads, method=method))
        py_forwards.extend(forward_decls(name, overloads, method=method))

    fm.write_with_template(filename, filename, lambda: {
        'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
        'py_forwards': py_forwards,
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
    })

def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
    num_shards: int
) -> None:
    """Generates Python bindings to ATen functions"""
    grouped = group_filter_overloads(pairs, pred)

    def key_func(kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]) -> str:
        return str(kv[0])

    def env_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> Dict[str, List[str]]:
        return {
            'py_forwards': list(forward_decls(kv[0], kv[1], method=method)),
            'py_methods': [method_impl(kv[0], module, kv[1], method=method)],
            'py_method_defs': [method_def(kv[0], module, kv[1], method=method)],
        }

    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={
            'generated_comment':
                '@' + f'generated from {fm.template_dir}/{filename}',
        },
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys={'py_forwards', 'py_methods', 'py_method_defs'}
    )

def load_signatures(
    native_functions: List[NativeFunction],
    deprecated_yaml_path: str,
    *,
    method: bool,
    skip_deprecated: bool = False,
    pyi: bool = False,
) -> Sequence[PythonSignatureNativeFunctionPair]:

    @with_native_function
    def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
        return PythonSignatureNativeFunctionPair(
            signature=signature(f, method=method, pyi=pyi),
            function=f,
        )

    pairs = list(map(gen_signature_pairs, native_functions))
    deprecated = load_deprecated_signatures(pairs, deprecated_yaml_path, method=method, pyi=pyi)
    return pairs if skip_deprecated else pairs + deprecated

def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    # The deprecated.yaml doesn't have complete type information, so we need to
    # find and leverage the original ATen signature (to which it delegates
    # the call) to generate the full python signature.
    # We join the deprecated and the original signatures using type-only form.

    # native function -> type-only signature
    @with_native_function
    def signature_original(f: NativeFunction) -> str:
        # remove inplace suffix but keep outplace suffix
        opname = str(f.func.name.name.base)
        if f.func.is_out_fn():
            opname += '_out'
        if f.func.name.name.inplace and pyi:
            opname += '_'
        args = CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
        # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml.
        types = ', '.join(argument_type_str(a.argument.type)
                          for a in args if isinstance(a.argument, Argument))
        return f'{opname}({types})'

    # deprecated -> type-only native signature (according to the call order)
    def signature_deprecated(opname: str, params: List[str], call_args: List[str]) -> str:
        # create a mapping of parameter name to parameter type
        types: Dict[str, str] = {}
        for param in params:
            if param == '*':
                continue
            type, name = param.split(' ')
            types[name] = type
        # if the name in the call is not in the parameter list, assume it's
        # a literal Scalar
        rearranged_types = ', '.join(types.get(arg, 'Scalar') for arg in call_args)
        return f'{opname}({rearranged_types})'

    # group the original ATen signatures by type-only signature
    grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[signature_original(pair.function)].append(pair)

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []

    with open(deprecated_yaml_path, 'r') as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    for deprecated in deprecated_defs:
        _, params = split_name_params(deprecated['name'])
        aten_name, call_args = split_name_params(deprecated['aten'])

        for pair in grouped[signature_deprecated(aten_name, params, call_args)]:
            # It uses the types from the original ATen declaration, but the
            # ordering and parameter names from the deprecated overload. Any
            # default parameter values from the original ATen declaration are
            # ignored.
            # Deprecated signature might reorder input_args and input_kwargs,
            # but never changes output_args nor TensorOptions (if any?),
            # so here we only look into these two types of args.
            python_sig = pair.signature
            src_args: Dict[str, PythonArgument] = {a.name: PythonArgument(
                name=a.name,
                type=a.type,
                default=None,
                default_init=None,
            ) for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs)}

            args: List[str] = []
            input_args: List[PythonArgument] = []
            input_kwargs: List[PythonArgument] = []

            kwarg_only = False
            for param in params:
                if param == '*':
                    kwarg_only = True
                    continue
                _, param_name = param.split(' ')
                args.append(param_name)

                if param_name not in src_args:
                    # output argument
                    continue

                if not kwarg_only:
                    if not method or param_name != 'self':
                        input_args.append(src_args[param_name])
                else:
                    input_kwargs.append(src_args[param_name])

            results.append(PythonSignatureNativeFunctionPair(
                signature=PythonSignatureDeprecated(
                    name=python_sig.name,
                    input_args=tuple(input_args),
                    input_kwargs=tuple(input_kwargs),
                    output_args=python_sig.output_args,
                    tensor_options_args=python_sig.tensor_options_args,
                    method=python_sig.method,
                    deprecated_args_names=tuple(args),
                    deprecated_args_exprs=tuple(call_args),
                    returns=python_sig.returns,
                ),
                function=pair.function,
            ))

    return results

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         Named Tuple Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

@with_native_function
def gen_namedtuple_typename_key(f: NativeFunction) -> str:
    name = cpp.name(f.func)
    fieldnames = namedtuple_fieldnames(f.func.returns)
    return '_'.join([name] + fieldnames)

def emit_namedtuple_typedefs(
|
|
overloads: Sequence[PythonSignatureNativeFunctionPair]
|
|
) -> Tuple[List[str], Dict[str, str]]:
|
|
"""
|
|
Generate block of named tuple type def inits, and add typeref snippets
|
|
to declarations that use them
|
|
"""
|
|
flddefnames: Dict[str, str] = {} # map from unique field name lists to field def name
|
|
flddefs: List[str] = [] # field def declarations
|
|
typenames: Dict[str, str] = {} # map from unique name + field name lists to typedef name
|
|
typedefs: List[str] = [] # typedef declarations and init code
|
|
|
|
for overload in overloads:
|
|
fieldnames = namedtuple_fieldnames(overload.function.func.returns)
|
|
if not fieldnames:
|
|
continue
|
|
|
|
fn_key = '_'.join(fieldnames)
|
|
fieldsname = flddefnames.get(fn_key)
|
|
if fieldsname is None:
|
|
fieldsname = f'NamedTuple_fields{"" if not flddefs else len(flddefs)}'
|
|
flddefnames[fn_key] = fieldsname
|
|
fields = ', '.join(f'{{"{fn}", ""}}' for fn in fieldnames)
|
|
flddefs.append(f"""\
|
|
static PyStructSequence_Field {fieldsname}[] = {{ {fields}, {{nullptr}} }};
|
|
""")
|
|
|
|
name = cpp.name(overload.function.func) # use @with_native_function?
|
|
tn_key = gen_namedtuple_typename_key(overload.function)
|
|
typename = typenames.get(tn_key)
|
|
if typename is None:
|
|
typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
|
|
typenames[tn_key] = typename
|
|
typedefs.append(f"""\
|
|
static PyTypeObject {typename};
|
|
static bool {typename}_initialized = false;
|
|
if (!{typename}_initialized) {{
|
|
{typename}_initialized = true;
|
|
static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, {fieldsname}, {len(fieldnames)} }};
|
|
PyStructSequence_InitType(&{typename}, &desc);
|
|
{typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
|
|
}}
|
|
""")
|
|
|
|
return flddefs + typedefs, typenames
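# As a rough sketch of the emitted C++ (first entry, so the numeric suffixes render
# as empty strings; 'max' with fields ['values', 'indices'] is used for illustration):
#
#   static PyStructSequence_Field NamedTuple_fields[] = { {"values", ""}, {"indices", ""}, {nullptr} };
#
#   static PyTypeObject NamedTuple;
#   static bool NamedTuple_initialized = false;
#   if (!NamedTuple_initialized) {
#     NamedTuple_initialized = true;
#     static PyStructSequence_Desc desc = { "torch.return_types.max", nullptr, NamedTuple_fields, 2 };
#     PyStructSequence_InitType(&NamedTuple, &desc);
#     NamedTuple.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
#   }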

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Method Impl Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# python binding for all overloads of a particular function/method
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(r"""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}

""")
# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures differ only in output params
# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch})
PY_VARIABLE_CASE = CodeTemplate("""\
    case ${overload_index}: {
      ${body}
    }
""")
# python binding for single-overload function/method
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

""")

# python binding for a method with no args, shortcuts parsing
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  ${method_header}
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

""")

def method_impl(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool
) -> str:
    """
    Generate a python binding for all overloads of an op.
    """
    pycname = get_pycname(name)
    noarg = is_noarg(overloads)
    namedtuple_inits, namedtuple_typenames = emit_namedtuple_typedefs(overloads)

    method_header = ['HANDLE_TH_ERRORS']
    method_header += namedtuple_inits
    method_header += [
        "const Tensor& self = THPVariable_Unpack(self_);"
    ] if method else []

    method_footer = ([] if noarg else ['Py_RETURN_NONE;']) + ['END_HANDLE_TH_ERRORS']

    traceable = 'true' if all(should_trace(o.function) for o in overloads) else 'false'

    grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads)
    is_singleton = len(grouped_overloads) == 1
    signatures: List[str] = []
    dispatch: List[str] = []
    for overload_index, overload in enumerate(grouped_overloads):
        signature = overload.signature.signature_str()
        signatures.append(f'{cpp_string(str(signature))},')
        dispatch_body = emit_dispatch_case(overload, namedtuple_typenames)
        dispatch.append(
            PY_VARIABLE_CASE.substitute(overload_index=overload_index, body=dispatch_body)
            if not is_singleton else dispatch_body)

    if noarg:
        template = PY_VARIABLE_METHOD_NOARGS
    elif is_singleton:
        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        template = PY_VARIABLE_METHOD_VARARGS

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max(map(lambda o: o.signature.arguments_count(), overloads)),
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=gen_has_torch_function_check(
            name=name,
            module=module,
            noarg=noarg,
            method=method,
        ),
        dispatch=dispatch,
        method_footer=method_footer,
        self_="self_" if method else "nullptr",
    )

def gen_has_torch_function_check(
    name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
) -> str:
    if noarg:
        if method:
            return f"""\
if(check_has_torch_function(self_)) {{
  return handle_torch_function(self_, "{name}");
}}
"""
        else:
            return ''

    self_ = "self_" if method else "nullptr"
    namespace = {
        "torch": "THPVariableFunctionsModule",
        "torch.nn": "THPNNVariableFunctionsModule",
        "torch.fft": "THPFFTVariableFunctionsModule",
        "torch.linalg": "THPLinalgVariableFunctionsModule",
        "torch.special": "THPSpecialVariableFunctionsModule",
    }[module] if module else "THPVariableClass"

    return f"""\
if(_r.has_torch_function()) {{
  return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module or "torch.Tensor"}");
}}
"""

# handler for output/no-output overload pair
PY_VARIABLE_OUT = CodeTemplate("""\
if (_r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

def emit_dispatch_case(
    overload: PythonSignatureGroup,
    namedtuple_typenames: Dict[str, str],
) -> str:
    """
    Emit dispatch code for a single parsed signature. This corresponds to either
    a single native function, or a pair that differ only in output params. In the
    latter case, a single python signature is used for both and dispatching
    switches on the presence/absence of passed output args.
    """
    if overload.outplace is not None:
        # dispatch output and no-output variants, branch on _r.isNone(<out_idx>)
        return PY_VARIABLE_OUT.substitute(
            out_idx=overload.signature.output_idx(),
            call_dispatch=emit_single_dispatch(
                overload.signature, overload.base, namedtuple_typenames),
            call_dispatch_out=emit_single_dispatch(
                overload.signature, overload.outplace, namedtuple_typenames),
        )
    else:
        # no-output version only
        return emit_single_dispatch(
            overload.signature, overload.base, namedtuple_typenames)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Forward Declarations Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

def forward_decls(
    name: BaseOperatorName,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool
) -> Tuple[str, ...]:
    if method:
        return ()

    pycname = get_pycname(name)
    if is_noarg(overloads):
        return (f"""\
static PyObject * {pycname}(PyObject* self_, PyObject* args);
""",)
    else:
        return (f"""\
static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
""",)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Method Def (Binding Table Entry) Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

def method_def(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool
) -> str:
    """
    Generate method def entry.
    """
    pycname = get_pycname(name)

    if is_noarg(overloads):
        pyfunc_cast = ''
        flags = 'METH_NOARGS' if method else 'METH_VARARGS | METH_KEYWORDS'
    else:
        pyfunc_cast = 'castPyCFunctionWithKeywords'
        flags = 'METH_VARARGS | METH_KEYWORDS'

    if module == "torch":
        flags += ' | METH_STATIC'

    if name.dunder_method:
        # PyMethodDef entry for binary op, throws not implemented error
        return f"""\
{{"{name}", {pyfunc_cast}(TypeError_to_NotImplemented_<{pycname}>), {flags}, NULL}},"""
    else:
        # PyMethodDef entry
        return f"""\
{{"{name}", {pyfunc_cast}({pycname}), {flags}, NULL}},"""

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Overload Sorting and Grouping
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

def group_overloads(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Sequence[PythonSignatureGroup]:
    bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
    outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}

    # first group by signature ignoring out arguments
    for overload in overloads:
        sig = overload.signature.signature_str(skip_outputs=True)
        if overload.function.func.is_out_fn():
            if sig in outplaces:
                raise RuntimeError(
                    f'Found duplicated function definition:\n- {overload.function.func}.\n'
                    f'Existing definition:\n- {outplaces[sig].function.func}.'
                )
            outplaces[sig] = overload
        else:
            if sig in bases:
                raise RuntimeError(
                    f'Found duplicated function definition:\n- {overload.function.func}.\n'
                    f'Existing definition:\n- {bases[sig].function.func}.'
                )
            bases[sig] = overload

    for sig, out in outplaces.items():
        if sig not in bases:
            candidates: List[str] = []
            for overload in overloads:
                if str(overload.function.func.name.name) == str(out.function.func.name.name) \
                        and not overload.function.func.is_out_fn() \
                        and not overload.signature.deprecated:
                    candidates.append(overload.signature.signature_str(skip_outputs=True))
            out_sig = out.signature.signature_str()
            raise RuntimeError(
                f'While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. '
                f'We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema '
                'correctly in native_functions.yaml. We discovered the following candidate(s): \n'
                + '\n'.join(f'- {candidate}' for candidate in candidates))

    grouped: List[PythonSignatureGroup] = []
    for sig, base in bases.items():
        outplace = outplaces.get(sig)
        grouped.append(PythonSignatureGroup(
            # prefer the signature with optional out=... arguments because it's the
            # superset that can be used to parse input for both base and outplace.
            signature=outplace.signature if outplace is not None else base.signature,
            base=base.function,
            outplace=outplace.function if outplace is not None else None,
        ))

    return sort_overloads(grouped)
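# For example (schematic schemas): the pair
#   add(Tensor input, Tensor other, *, Scalar alpha=1)
#   add.out(Tensor input, Tensor other, *, Scalar alpha=1, Tensor(a!) out)
# shares one signature string once outputs are skipped, so it becomes a single
# PythonSignatureGroup whose parsing signature is the out variant's - the
# superset with the optional out=... argument.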

# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary because the same call can be accepted
# by more than one overload, and the order in which signatures are tried determines
# which one wins.
#
# See Note[Order of overloads matters]
#
# A few examples of ambiguous python signature pairs.
#
# All parameters have the same type, except that one takes Tensor where the other
# takes Scalar. A numeric PyObject can be cast into a Tensor, and a zero-dim Tensor
# object can be accepted as a Scalar type parameter (see python_arg_parser.cpp).
# Therefore, the same input arguments might be accepted by either python signature.
# We want to always parse the one taking Tensor first.
#
#   bitwise_and(Tensor input, Tensor other, *, Tensor out=None)
#   bitwise_and(Tensor input, Scalar other, *, Tensor out=None)
#
# If they have different numbers of parameters then they are not ambiguous - but
# a difference in the output param can be ignored, as it's optional.
#
#   multiply(Tensor input, Tensor other, *, Tensor out=None)
#   multiply(Tensor input, Scalar other)
#
# Both positional args and keyword-only args are considered together.
#
#   subtract(Tensor other, *, Scalar alpha=1)
#   subtract(Scalar other, Scalar alpha=1)
#
# A few ambiguous cases which it does NOT handle yet.
#
# If there is any difference in other parameters besides the Tensor/Scalar
# difference, then they are not considered ambiguous by this method anymore.
# However, the difference could be too trivial to disambiguate.
#
#   foo(Tensor input, Scalar other, Scalar bar)
#   foo(Tensor input, Tensor other, double bar)
#
# If they take different numbers of parameters then they are not considered
# ambiguous anymore, even if the difference is only in optional kwargs.
#
#   foo(Scalar other, Scalar alpha=1)
#   foo(Tensor other, *, Scalar alpha=1, Scalar beta=1)
#
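# As a concrete sketch of the resulting order (using the bitwise_and pair above):
# is_arg_smaller(Scalar, Tensor) holds, so the Scalar-taking signature is "smaller"
# in the partial order, and sort_overloads below emits the Tensor-taking overload
# first - exactly the parse order we want.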

def sort_overloads(
    grouped_overloads: Sequence[PythonSignatureGroup]
) -> Sequence[PythonSignatureGroup]:

    def is_arg_smaller(t1: Type, t2: Type) -> bool:
        return (str(t1) == 'Scalar' and str(t2) == 'Tensor' or
                'Dimname' in str(t1) and 'Dimname' not in str(t2) or
                # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been
                # discussed why it is important to prioritize int/int? over int[]
                str(t1) == 'int[]' and (str(t2) == 'int' or str(t2) == 'int?') or
                # TensorList currently throws an error during argument parsing, that's why it needs to be
                # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
                str(t1) == 'Tensor[]' and str(t2).find("[]") != -1)

    def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
        """Returns True if s1 < s2 in the partial order."""
        args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True)
        if len(args1) != len(args2):
            return False
        # TODO: should use some canonical form instead of 'str(arg.type)' - see comments
        # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which
        # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'.
        equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2))
        smaller_or_equal = all(str(arg1.type) == str(arg2.type)
                               or is_arg_smaller(arg1.type, arg2.type)
                               for arg1, arg2 in zip(args1, args2))
        return smaller_or_equal and not equal

    # First sort by signature
    grouped_overloads = sorted(grouped_overloads, key=lambda x: x.signature.signature_str())

    # Construct the relation graph
    larger_than: Dict[int, Set[int]] = defaultdict(set)
    for i1, overload1 in enumerate(grouped_overloads):
        for i2, overload2 in enumerate(grouped_overloads):
            if is_smaller(overload1.signature, overload2.signature):
                larger_than[i1].add(i2)

    if not larger_than:
        return list(grouped_overloads)

    # Use a topological sort to sort overloads according to the partial order.
    N = len(grouped_overloads)
    sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N)))

    for idx in range(N):
        # The size of sorted_ids will grow to N eventually.
        i = sorted_ids[idx]
        for j in sorted(larger_than.keys()):
            larger = larger_than[j]
            larger.discard(i)
            if not larger:
                del larger_than[j]
                sorted_ids.append(j)

    return list(map(lambda x: grouped_overloads[x], sorted_ids))
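# A minimal trace of the loop above (hypothetical indices): with N = 3 and
# larger_than = {1: {0}, 2: {0, 1}}, sorted_ids starts as [0]; visiting 0 empties
# larger_than[1] and appends 1, then visiting 1 empties larger_than[2] and appends
# 2, yielding [0, 1, 2] - each signature is emitted after everything "larger" than it.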

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Codegen API Integration
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

def emit_single_dispatch(
    ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str]
) -> str:
    """
    Emit dispatch code for a single native function.
    """
    @with_native_function
    def go(f: NativeFunction) -> str:
        # header comments
        deprecated = '[deprecated] ' if ps.deprecated else ''
        schema_comment = f'// {deprecated}aten::{f.func}'

        # dispatch lambda signature
        name = cpp.name(f.func)
        lambda_formals = ', '.join(map(lambda a: f"{a.type_str} {a.name}",
                                       dispatch_lambda_args(ps, f)))
        lambda_return = dispatch_lambda_return_str(f)

        # dispatch lambda body
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ', '.join(cpp_dispatch_exprs(f, python_signature=ps))

        # from arg parser outputs to dispatch lambda arguments
        parser_outputs = arg_parser_output_exprs(ps, f)
        lambda_arg_exprs = dispatch_lambda_exprs(ps, f)
        inits = '\n'.join(lambda_arg_exprs.inits)
        lambda_args = ', '.join(lambda_arg_exprs.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        # solution for enabling the 'requires_grad' argument for tensor methods
        # new_full, new_empty, and new_zeros. A much better but more difficult to
        # implement solution involves refactoring according to Ed's description here:
        # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        need_set_requires_grad = ps.tensor_options_args and (not has_tensor_options(f) or (
            ps.method and ('requires_grad' in parser_outputs)))
        set_requires_grad = f'.set_requires_grad({parser_outputs["requires_grad"].expr})' \
            if need_set_requires_grad else ''

        if lambda_return == 'void':
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  pybind11::gil_scoped_release no_gil;
  {dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
Py_RETURN_NONE;
"""
        else:
            typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f))
            namedtuple_typeref = f'&{typename}, ' if typename is not None else ''
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  pybind11::gil_scoped_release no_gil;
  return {dispatch_callee}({dispatch_args});
}};
return wrap({namedtuple_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
"""

    return go(f)
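# Schematically (argument handling elided; 'abs' is used only for illustration),
# the non-void branch above produces C++ of the form:
#
#   // aten::abs(Tensor self) -> Tensor
#   auto dispatch_abs = [](const Tensor & self) -> Tensor {
#     pybind11::gil_scoped_release no_gil;
#     return self.abs();
#   };
#   return wrap(dispatch_abs(_r.tensor(0)));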