import re

from dataclasses import dataclass
from typing import List, Dict, Optional, Iterator, Tuple, Set, NoReturn
from enum import Enum
import itertools

# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    raise AssertionError("Unhandled type: {}".format(type(x).__name__))
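
# An illustrative sketch of the exhaustiveness trick (hypothetical
# function, not part of the codegen): in the else branch mypy has
# narrowed v to an impossible type, so the assert_never call
# type-checks only if every enum member is handled above it.
#
#   def describe(v: Variant) -> str:   # Variant is defined below
#       if v is Variant.function:
#           return 'function'
#       elif v is Variant.method:
#           return 'method'
#       else:
#           assert_never(v)  # mypy errors here if a case were missed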

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           DATA MODEL
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Some general principles for our data model.
#
# - Stop using C++ data types as the internal data representation
#   format.  Instead, the internal data structures are centered
#   around JIT schema representation.  This avoids a big problem
#   with the old codegen where we read in all the types from
#   native_functions.yaml and then immediately had to retranslate
#   them into C++ types.
#
# - More semantic data representation.  Instead of representing
#   everything as dicts and strings, we define dataclasses for
#   every interesting entity the code generation has to deal with.
#   These dataclasses have strong semantic invariants: for example,
#   we generally require them to roundtrip losslessly into the
#   form they were parsed from.  These structures are immutable
#   and you're expected to populate information once during
#   construction.

# Represent a source location; used for better error reporting
@dataclass(frozen=True)
class Location:
    file: str
    line: int

    def __str__(self) -> str:
        return "{}:{}".format(self.file, self.line)

# Valid values of the 'variants' field in native_functions.yaml
Variant = Enum('Variant', ('function', 'method'))

class UseC10Dispatcher(Enum):
    full = 0
    with_codegenerated_unboxing_wrapper = 1
    hacky_wrapper_for_legacy_signatures = 2

    def dispatcher_uses_new_style(self) -> bool:
        return self in [UseC10Dispatcher.full, UseC10Dispatcher.hacky_wrapper_for_legacy_signatures]

# The basic input to the code generation is native_functions.yaml.
# The name "native", BTW, comes from the distinction between native
# functions and legacy TH functions.  The legacy TH functions are gone,
# but the "native" descriptor has stuck.
#
# NativeFunction models a single entry in native_functions.yaml.  Its
# fields roughly correspond to what you would see in the YAML itself,
# but after canonicalization and parsing has occurred.
#
# You can see some of the overall design patterns for how we set up
# dataclasses in this class, but we will defer a complete discussion
# of this to FunctionSchema.
@dataclass(frozen=True)
class NativeFunction:
    # The function schema of the operator in question.  This schema
    # has been parsed; see FunctionSchema for more about its structure.
    # (This type is quoted as we are forward referencing a type
    # defined later in the file.  I opted for this ordering of the
    # classes for expository clarity.)
    func: 'FunctionSchema'

    # Corresponds to the 'use_c10_dispatcher' field.  The default
    # is 'with_codegenerated_unboxing_wrapper'
    use_c10_dispatcher: UseC10Dispatcher

    # Whether or not to omit automatic generation of a DeviceGuard
    device_guard: bool

    # What python module to put the function in
    python_module: Optional[str]

    # TODO: figure out what this does
    category_override: Optional[str]

    # If no variants are specified in native_functions.yaml, this is
    # assumed to be {'function'}.
    variants: Set[Variant]

    # Whether or not we should skip generating registrations for
    # this kernel.  This is a bit of a double-edged sword, as manual
    # registrations don't participate in codegen-based selective build!
    manual_kernel_registration: bool

    # A mapping of dispatch keys to names of functions implementing
    # them.  In native_functions.yaml, the dispatch entry is optional;
    # in that case, it is equivalent to having written:
    #
    #   dispatch:
    #       Math: $operator_name
    #
    # TODO: str key could be replaced with more explicit enum
    dispatch: Dict[str, str]

    # The location in the YAML file where this native function entry was
    # defined.  This is for conveniently reporting error messages!
    loc: 'Location'

    # NB: The benefit of defining a dataclass is that we automatically get
    # a constructor defined for all the fields we specify.  No need
    # to explicitly write it out.

    @staticmethod
    def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction':
        """
        Parse a NativeFunction from a dictionary as directly parsed
        from native_functions.yaml
        """
        e = ei.copy()

        funcs = e.pop('func')
        assert isinstance(funcs, str), f'not a str: {funcs}'
        func = FunctionSchema.parse(funcs)

        use_c10_dispatcher_s = e.pop('use_c10_dispatcher', None)
        if use_c10_dispatcher_s is None:
            use_c10_dispatcher = UseC10Dispatcher.with_codegenerated_unboxing_wrapper
        elif use_c10_dispatcher_s == 'full':
            use_c10_dispatcher = UseC10Dispatcher.full
        elif use_c10_dispatcher_s == 'hacky_wrapper_for_legacy_signatures':
            use_c10_dispatcher = UseC10Dispatcher.hacky_wrapper_for_legacy_signatures
        else:
            raise AssertionError(
                f'use_c10_dispatcher must be unset, full, or '
                f'hacky_wrapper_for_legacy_signatures, got {use_c10_dispatcher_s}')

        variants_s = e.pop('variants', 'function')
        assert isinstance(variants_s, str)
        variants: Set[Variant] = set()
        for v in variants_s.split(', '):
            if v == 'function':
                variants.add(Variant.function)
            elif v == 'method':
                variants.add(Variant.method)
            else:
                raise AssertionError(f'illegal variant {v}')

        manual_kernel_registration = e.pop('manual_kernel_registration', False)
        assert isinstance(manual_kernel_registration, bool), f'not a bool: {manual_kernel_registration}'

        device_guard = e.pop('device_guard', True)
        assert isinstance(device_guard, bool), f'not a bool: {device_guard}'

        python_module = e.pop('python_module', None)
        assert python_module is None or isinstance(python_module, str), f'not a str: {python_module}'

        category_override = e.pop('category_override', None)
        assert category_override is None or isinstance(category_override, str), f'not a str: {category_override}'

        raw_dispatch = e.pop('dispatch', None)
        assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
        dispatch: Dict[str, str] = {}
        if raw_dispatch is not None:
            for ks, v in raw_dispatch.items():
                if ks == '__line__':
                    continue  # not worth tracking line numbers for dispatch entries
                assert isinstance(ks, str), e
                assert isinstance(v, str), e
                for k in ks.split(","):
                    dispatch[k.strip()] = v
        else:
            from tools.codegen.api import cpp
            dispatch['Math'] = cpp.name(func)

        assert not ('DefaultBackend' in dispatch and 'Math' in dispatch), \
            "cannot specify both DefaultBackend and Math on a single kernel; each " \
            "strictly subsumes the other.  If you wanted to provide an explicit autograd " \
            "implementation, specify DefaultBackend; otherwise specify Math only"

        e.pop('__line__')
        assert not e, f"leftover entries: {e}"

        return NativeFunction(
            func=func,
            use_c10_dispatcher=use_c10_dispatcher,
            variants=variants,
            manual_kernel_registration=manual_kernel_registration,
            python_module=python_module,
            category_override=category_override,
            dispatch=dispatch,
            device_guard=device_guard,
            loc=loc,
        )
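
    # An illustrative sketch (hypothetical YAML entry, in the real
    # native_functions.yaml format) of what from_yaml computes:
    #
    #   - func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    #     use_c10_dispatcher: full
    #     variants: function, method
    #
    # parses to a NativeFunction with use_c10_dispatcher=UseC10Dispatcher.full,
    # variants={Variant.function, Variant.method}, and, since no dispatch
    # entry was given, dispatch={'Math': 'add'} (assuming cpp.name of the
    # functional schema is the base name 'add').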

    # __post_init__ functions in dataclasses can be used to do extra
    # validation after construction.
    #
    # Notice that we don't do any type validation here.  In fact, we
    # rely exclusively on mypy to check if you've done types correctly!
    # Validation is for nontrivial invariants that cannot be (conveniently)
    # encoded in the type system.
    def __post_init__(self) -> None:
        if self.func.out_arguments:
            assert self.variants == {Variant.function}, "Native functions with out arguments MUST " \
                "be declared with only function variant; e.g., variants: function; " \
                "otherwise you will tickle a Python argument binding bug " \
                "(which usually manifests itself as the result variable being undefined)."

SchemaKind = Enum('SchemaKind', ('functional', 'inplace', 'out'))

# Represents a bundle of native functions that are semantically related.
@dataclass(frozen=True)
class NativeFunctionGroup:
    functional: Optional[NativeFunction]
    inplace: Optional[NativeFunction]
    out: Optional[NativeFunction]

    def __post_init__(self) -> None:
        test_sig: Optional[FunctionSchema] = None
        for f in self.functions():
            if test_sig is None:
                test_sig = f.func.signature()
            else:
                if test_sig != f.func.signature():
                    raise AssertionError(
                        "NativeFunctionGroup constructed from two NativeFunctions "
                        f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
                    )

    def signature(self) -> 'FunctionSchema':
        if self.out is not None:
            return self.out.func.signature()
        elif self.functional is not None:
            return self.functional.func.signature()
        elif self.inplace is not None:
            return self.inplace.func.signature()
        else:
            raise AssertionError("invalid NativeFunctionGroup has no NativeFunctions")

    def functions(self) -> Iterator[NativeFunction]:
        if self.out is not None:
            yield self.out
        if self.functional is not None:
            yield self.functional
        if self.inplace is not None:
            yield self.inplace

    @staticmethod
    def from_dict(d: Dict[SchemaKind, NativeFunction]) -> 'NativeFunctionGroup':
        functional = d.get(SchemaKind.functional)
        inplace = d.get(SchemaKind.inplace)
        out = d.get(SchemaKind.out)
        return NativeFunctionGroup(
            functional=functional,
            inplace=inplace,
            out=out,
        )
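
# For example (a sketch using real operator names): the entries for
# add.Tensor (functional), add_.Tensor (inplace), and add.out (out)
# could form one NativeFunctionGroup, since all three reduce to the
# same FunctionSchema.signature().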

# The function schema is undoubtedly the most important data structure
# in all of the codegen, as it defines the type signature for operators,
# and most of the code generation we do is type directed (e.g., look at
# the types, decide what to do.  Think about how we code generate
# C++ function stubs!)
#
# We will also see in this class the general structure for how we model
# data in this code generation.  A few notable properties to point out
# ahead of time:
#
#   - These dataclasses are a *lossless* representation of the strings
#     they are parsed from.  In fact, we assert that given the
#     information stored in the dataclass, we can exactly reconstruct
#     the string we parsed from (and assert this inside the parse
#     definition).  There are a few reasons for this:
#
#       - If you find that it is difficult to reconstruct the string
#         given a dataclass, that is a clue that your data
#         representation is wrong.
#
#       - It helps ensure that all relevant information is present
#         in the dataclass, so that downstream users aren't tempted
#         to reparse the original string to get some information
#         that was omitted.
#
#       - It forces you to represent the data in-memory in the same way
#         it is recorded textually, which makes the dataclasses easier
#         to understand for someone who is familiar with the
#         textual format.  (As a tradeoff, it means you have to model
#         the syntax, even when it is inconvenient.  But maybe that means
#         the syntax is bad!)  If you don't understand the internal
#         representation, go look at the printing code to see how
#         it maps onto the surface syntax!
#
#       - It makes it easy to test the parsing code, as parsing code
#         that is inconsistent with the string code will fail early
#         and loudly.  (As a tradeoff, it makes the parsing code a bit
#         brittle; in particular, with trivial whitespace changes you
#         are likely to trigger an assert error.)
#
#     In general, try to make the __str__ code as simple as possible
#     (even at the cost of more complex parsing logic.)  Additionally,
#     try to minimize redundancy in data representation.  (Precomputed
#     fields are OK though: they are defined as a simple function on
#     the canonical representation in question.)
#
#   - These dataclasses are all frozen; once constructed their
#     values never change.  This makes it easy to tell where any
#     given data came from: just look to the constructor.  As a
#     tradeoff, you can't easily "decorate" a schema with extra
#     information from a post-facto analysis.  We impose this
#     restriction to make these structures more understandable.
#
@dataclass(frozen=True)
class FunctionSchema:
    # The name of the operator this function schema describes.
    name: 'OperatorName'

    arguments: Tuple['Argument', ...]
    kwarg_only_arguments: Tuple['Argument', ...]  # but not including out args
    # Unlike in the previous codegen, we have factored out 'out' arguments
    # in the canonical representation, removing them from kwarg
    # arguments.  This choice is justified by numerous downstream
    # transformations which treat out arguments specially; additionally,
    # you can see that canonicity is not violated!
    out_arguments: Tuple['Argument', ...]  # these are also kwarg-only

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple['Return', ...]

    def schema_order_arguments(self) -> Iterator['Argument']:
        return itertools.chain(self.arguments, self.kwarg_only_arguments, self.out_arguments)

    @staticmethod
    def parse(func: str) -> 'FunctionSchema':
        # We should probably get a proper parser here
        assert ' -> ' in func, "function schema missing return type (spaces are mandatory)"
        func_decl, return_decl = [x.strip() for x in func.split(' -> ')]
        ops, args = func_decl.split('(', 1)
        assert args[-1] == ")", "Expecting closing )"
        args = args[:-1]
        name = OperatorName.parse(ops)
        arguments, kwarg_only_arguments, out_arguments = parse_arguments(args)
        returns = parse_returns(return_decl)
        r = FunctionSchema(
            name=name,
            arguments=arguments,
            kwarg_only_arguments=kwarg_only_arguments,
            out_arguments=out_arguments,
            returns=returns
        )
        assert str(r) == func, f'{str(r)} != {func}'
        return r
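
    # An illustrative roundtrip (a sketch, using a real schema string):
    #
    #   >>> s = 'add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)'
    #   >>> str(FunctionSchema.parse(s)) == s
    #   True
    #
    # Here 'out' lands in out_arguments (it is kwarg-only and mutable),
    # while 'alpha' lands in kwarg_only_arguments.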

    def __post_init__(self) -> None:
        for arg, ret in zip(self.out_arguments, self.returns):
            assert arg.annotation == ret.annotation, \
                "Out arguments must have a matching return Tensor; furthermore, " \
                "the i-th out argument needs to correspond to the i-th return"
        # Invariant: we expect out arguments to appear as keyword arguments in the schema.
        # This means that all mutable returns should be aliased to a keyword argument
        # (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
        # See Note [is_out_fn]
        out_and_self = list(self.out_arguments) + [arg for arg in self.arguments if arg.name == "self"]
        mutable_returns = [ret for ret in self.returns if ret.annotation is not None and ret.annotation.is_write]
        for ret in mutable_returns:
            assert any([ret.annotation == arg.annotation for arg in out_and_self]), \
                "All mutable returns must be aliased either to a keyword argument, or to \"self\". " \
                "Did you forget to mark an out argument as keyword-only?"
        if self.out_arguments:
            assert len(self.out_arguments) == len(self.returns), \
                "Must return as many arguments as there are out arguments"
        if self.name.name.inplace:
            # TODO: fixme
            if str(self.name) not in [
                    '_amp_foreach_non_finite_check_and_unscale_',
                    '_foreach_add_.ScalarList',
                    '_foreach_sub_.ScalarList',
                    '_foreach_mul_.ScalarList',
                    '_foreach_div_.ScalarList',
                    '_foreach_add_.Scalar',
                    '_foreach_sub_.Scalar',
                    '_foreach_mul_.Scalar',
                    '_foreach_div_.Scalar',
                    '_foreach_add_.List',
                    '_foreach_sub_.List',
                    '_foreach_mul_.List',
                    '_foreach_div_.List',
                    '_foreach_exp_',
                    '_foreach_sqrt_',
                    '_foreach_addcmul_.Scalar',
                    '_foreach_addcdiv_.Scalar',
                    '_foreach_addcmul_.ScalarList',
                    '_foreach_addcdiv_.ScalarList']:
                assert len(self.returns) == 1

    def is_out_fn(self) -> bool:
        # Note [is_out_fn]
        #
        # out functions are the variants which take an explicit out= argument
        # to populate into.  We need to know if a schema corresponds to an
        # out function for several reasons:
        #
        #   - They codegen differently in C++ API
        #       - codegen to at::add_out rather than at::add
        #       - out argument is moved to front of C++ argument list
        #
        # out functions are DEFINED to be any function with a keyword-only
        # argument that is mutable.  In principle, this could lead to a
        # false positive if you define a function that mutates a
        # kwarg only argument, but this isn't the "true" output of this
        # function.  A more robust definition that would work in this
        # case would also look at:
        #
        #   - The output types.  Out functions take in the arguments
        #     they mutate and then return them again; this is sort
        #     of "definitionally" what makes something an out function.
        #     Historically, we DO check this for consistency.
        #   - Correspondence with pure variant.  An out function
        #     should have a signature equivalent to its pure variant,
        #     but just with extra kwargs for the output elements.  This
        #     is difficult to actually check for and historically
        #     we only do this check in tools/
        return bool(self.out_arguments)

    def kind(self) -> SchemaKind:
        """
        What kind of schema is this?  A functional schema is one
        that returns a newly allocated output; an inplace schema
        modifies the self argument inplace; an out schema writes
        the result into an explicitly provided out argument.
        """
        is_inplace = self.name.name.inplace
        is_out = bool(self.out_arguments)
        assert not (is_inplace and is_out)
        if is_inplace:
            return SchemaKind.inplace
        elif is_out:
            return SchemaKind.out
        else:
            return SchemaKind.functional
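
    # For example (a sketch): add.Tensor is functional; add_.Tensor is
    # inplace (trailing underscore in the base name); add.out is out
    # (it has a mutable keyword-only 'out' argument).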

    # WARNING: This method is not currently tested in any meaningful way
    def signature(self) -> 'FunctionSchema':
        """
        Certain schemas are 'related', in that they are simply
        inplace/out/functional versions of the same function.  This method
        factors these schemas into the "core" functional signature which
        is equal across all versions.

        Here is the normalization applied to the schema to convert
        it to a signature:
        - The overload name is stripped (name is retained, since
          it expresses semantic content about what the function does)
        - Inplace is set False
        - Out arguments are stripped
        - Mutability annotations are stripped (this is sound
          because you cannot overload on mutability annotation)

        This function is based off of get_signature in
        tools.autograd.load_derivatives
        """

        # dataclasses.replace could be used here, but it is less
        # type safe so for now I've opted to type everything out
        def strip_arg_annotation(a: Argument) -> Argument:
            return Argument(
                name=a.name,
                type=a.type,
                default=a.default,  # hmmm
                annotation=None,
            )

        def strip_ret_annotation(r: Return) -> Return:
            return Return(
                name=r.name,
                type=r.type,
                annotation=None,
            )

        return FunctionSchema(
            name=OperatorName(
                name=BaseOperatorName(
                    base=self.name.name.base,
                    inplace=False,
                    dunder_method=self.name.name.dunder_method,
                ),
                overload_name="",  # stripped
            ),
            arguments=tuple(map(strip_arg_annotation, self.arguments)),
            kwarg_only_arguments=tuple(map(strip_arg_annotation, self.kwarg_only_arguments)),
            out_arguments=(),  # stripped
            returns=tuple(map(strip_ret_annotation, self.returns)),
        )
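
    # An illustrative example (a sketch, using a real schema):
    #
    #   >>> s = FunctionSchema.parse(
    #   ...     'add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)')
    #   >>> str(s.signature())
    #   'add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor'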

    def __str__(self) -> str:
        all_arguments: List[str] = []
        all_arguments.extend(map(str, self.arguments))
        if self.kwarg_only_arguments or self.out_arguments:
            all_arguments.append('*')
        all_arguments.extend(map(str, self.kwarg_only_arguments))
        all_arguments.extend(map(str, self.out_arguments))
        all_arguments_str = ', '.join(all_arguments)
        if len(self.returns) == 1:
            returns = str(self.returns[0])  # omit parentheses
        else:
            returns = '(' + ', '.join(map(str, self.returns)) + ')'
        return f'{self.name}({all_arguments_str}) -> {returns}'

# Here is the rest of the data model, described more briefly.

# Simplified version for what actually shows up in built-ins.
# Look at alias_info.h for expanded syntax.  If you need the structure,
# you also need to make this structure recursive so it can be lined
# up with the type components too.  For primitives this isn't really
# necessary.
@dataclass(frozen=True)
class Annotation:
    # Typically only has one element.  Not actually a set so
    # we can conveniently assume it is canonically ordered
    alias_set: Tuple[str, ...]
    is_write: bool

    @staticmethod
    def parse(ann: str) -> 'Annotation':
        m = re.match(r'^([a-z])(!?)$', ann)
        assert m is not None, f'unrecognized alias annotation {ann}'
        alias_set = (m.group(1),)
        is_write = m.group(2) == '!'
        r = Annotation(alias_set=alias_set, is_write=is_write)
        assert str(r) == ann, f'{r} != {ann}'
        return r

    def __str__(self) -> str:
        alias_set = '|'.join(self.alias_set)
        is_write = '!' if self.is_write else ''
        return f'{alias_set}{is_write}'
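
# For example (a sketch): Annotation.parse('a!') yields
# Annotation(alias_set=('a',), is_write=True), and str() of it
# roundtrips back to 'a!'.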

# The base class for the type system.  This is also loosely modeled
# off of jit_type.h, but we've simplified the hierarchy to focus
# in on the aspects of the type system that matter for code generation
# (for example, there's no SingleElementType subclass anymore).
# You never actually construct a Type; usually it's going to be one
# of the subclasses.  If Python had ADTs this would be one!
@dataclass(frozen=True)
class Type:
    @staticmethod
    def parse(t: str) -> 'Type':
        r = Type._parse(t)
        assert str(r) == t, f'{r} != {t}'
        return r

    @staticmethod
    def _parse(t: str) -> 'Type':
        m = re.match(r'^(.+)\?$', t)
        if m is not None:
            return OptionalType(Type.parse(m.group(1)))
        m = re.match(r'^(.+)\[([0-9]+)?\]$', t)
        if m is not None:
            size = int(m.group(2)) if m.group(2) is not None else None
            return ListType(elem=Type.parse(m.group(1)), size=size)
        try:
            return BaseType(BaseTy[t])
        except KeyError:
            raise RuntimeError(f"unrecognized type {t}")

    def __str__(self) -> str:
        raise NotImplementedError

    # WARNING: These concepts are not very well-defined.  For example,
    # is "int?" nullable?  How about "int?[]"?  They are defined
    # so we can conveniently generate legacy Declarations.yaml, but
    # really we should probably just remove these at some point.

    def is_tensor_like(self) -> bool:
        raise NotImplementedError

    def is_nullable(self) -> bool:
        raise NotImplementedError

    def is_list_like(self) -> Optional['ListType']:
        raise NotImplementedError
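
# For example (a sketch): Type.parse('Tensor') is BaseType(BaseTy.Tensor),
# Type.parse('Tensor?') is OptionalType(BaseType(BaseTy.Tensor)), and
# Type.parse('int[2]') is ListType(elem=BaseType(BaseTy.int), size=2);
# each prints back to the string it was parsed from.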

# Base types are simple, atomic types with no further structure
BaseTy = Enum('BaseTy', (
    'Generator',
    'ScalarType',
    'Tensor',
    'int',
    'Dimname',
    'float',
    'str',
    'bool',
    'Layout',
    'Device',
    'Scalar',
    'MemoryFormat',
    'QScheme',
    'Storage',
    'Stream',
    'ConstQuantizerPtr',  # TODO: rename
))

@dataclass(frozen=True)
class BaseType(Type):
    name: BaseTy

    def __str__(self) -> str:
        return f'{self.name.name}'

    def is_tensor_like(self) -> bool:
        return self.name == BaseTy.Tensor

    def is_nullable(self) -> bool:
        return False

    def is_list_like(self) -> Optional['ListType']:
        return None

# Optional types wrap an element type; a value of this type may also
# validly be None
@dataclass(frozen=True)
class OptionalType(Type):
    elem: Type

    def __str__(self) -> str:
        return f'{self.elem}?'

    def is_tensor_like(self) -> bool:
        return self.elem.is_tensor_like()

    def is_nullable(self) -> bool:
        return True

    def is_list_like(self) -> Optional['ListType']:
        return self.elem.is_list_like()

# List types specify that we may have multiples of an element.  We
# also support explicit sizes on list types, but these have
# some nontrivial semantics!  (However, for C++ API purposes, explicit
# sizes are mostly erased from the type system.)
#
# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g.,
# int[] elaborates differently than bool[3]!
@dataclass(frozen=True)
class ListType(Type):
    elem: Type
    size: Optional[int]

    def __str__(self) -> str:
        size = f'{self.size}' if self.size else ''
        return f'{self.elem}[{size}]'

    def is_tensor_like(self) -> bool:
        return self.elem.is_tensor_like()

    def is_nullable(self) -> bool:
        return self.elem.is_nullable()

    def is_list_like(self) -> Optional['ListType']:
        return self

@dataclass(frozen=True)
class Argument:
    # NB: I didn't put kwarg_only as a boolean field here, unlike
    # c10::Argument, so that printing works correctly

    name: str
    type: Type
    default: Optional[str]

    # The semantics of the annotation field are a little strange.
    #
    # Alias annotations parametrize Tensors (since Tensors are the only things
    # that can alias.)  This motivates why I write Tensor(a!)? (and not, for
    # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor,
    # which may be optional (i.e., the alias annotation should bind first to
    # Tensor, before the optional postfix annotation).
    #
    # However, despite being a property of Tensor, we (and c10::Argument)
    # store the annotation at the top level of the Argument, rather than
    # inside the embedded Tensor type.  In the C++ version of this
    # class, we then go to great lengths to mimic the type
    # structure in the annotation structure so we can correlate
    # annotations with types.
    #
    # Now, it turns out, in all applications in code generation, the
    # structure of annotated types is very simple.  So we just hard
    # code it here.  But if we ever do get anything more complex, this
    # model will have to change!
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> 'Argument':
        name: str
        default: Optional[str]
        type_and_annot, name_and_default = arg.rsplit(' ', 1)
        if '=' in name_and_default:
            name, default = name_and_default.split('=')
        else:
            name = name_and_default
            default = None
        # TODO: deduplicate annotation matching with Return
        match = re.match(r'Tensor\((.+)\)(.*)', type_and_annot)
        annotation: Optional[Annotation]
        if match:
            # If you update this, make sure the __str__ still works too
            assert match.group(2) in ['', '?', '[]'], 'unrecognized alias analysis form with Tensor'
            type_s = 'Tensor' + match.group(2)
            annotation = Annotation.parse(match.group(1))
        else:
            type_s = type_and_annot
            annotation = None
        type = Type.parse(type_s)
        r = Argument(
            name=name,
            type=type,
            default=default,
            annotation=annotation,
        )
        assert str(r) == arg, f'{str(r)} != {arg}'
        return r

    @property
    def is_write(self) -> bool:
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        type = f'{self.type}'
        if self.annotation:
            assert type in ['Tensor', 'Tensor?', 'Tensor[]']
            type = type.replace('Tensor', f'Tensor({self.annotation})')
        if self.name is None:
            return type
        else:
            mb_default = ''
            if self.default:
                mb_default = f'={self.default}'
            return f"{type} {self.name}{mb_default}"


@dataclass(frozen=True)
class Return:
    name: Optional[str]
    type: Type
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> 'Return':
        name: Optional[str]
        if ' ' in arg:
            type_and_annot, name = arg.rsplit(' ', 1)
        else:
            type_and_annot = arg
            name = None
        match = re.match(r'Tensor\((.+)\)(.*)', type_and_annot)
        annotation: Optional[Annotation]
        if match:
            # If you update this, make sure the __str__ still works too
            assert match.group(2) in ['', '?', '[]'], 'unrecognized alias analysis form with Tensor'
            type_s = 'Tensor' + match.group(2)
            annotation = Annotation.parse(match.group(1))
        else:
            type_s = type_and_annot
            annotation = None
        type = Type.parse(type_s)
        r = Return(
            name=name,
            type=type,
            annotation=annotation,
        )
        assert str(r) == arg, f'{str(r)} != {arg}'
        return r

    @property
    def is_write(self) -> bool:
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        type = f'{self.type}'
        if self.annotation:
            assert type in ['Tensor', 'Tensor?', 'Tensor[]']
            type = type.replace('Tensor', f'Tensor({self.annotation})')
        if self.name is None:
            return type
        else:
            return f"{type} {self.name}"


# Names that validly are __iXXX__ indicating inplace operations.
# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
# NB: PyTorch hasn't actually implemented all of these
AUGMENTED_ASSIGNMENT_NAMES = ['add', 'sub', 'mul', 'div', 'mod', 'pow', 'lshift', 'rshift', 'and', 'xor', 'or']

# A BaseOperatorName is what we think of the operator name, without
# the overload name.  Unusually, we don't represent this as just a
# string; instead, we directly represent a few important semantic
# bits of information we derive from the string: namely whether
# or not it's inplace (add_) and whether or not it's a double-underscore
# method (__add__)
@dataclass(frozen=True)
class BaseOperatorName:
    base: str
    inplace: bool
    dunder_method: bool

    @staticmethod
    def parse(op: str) -> 'BaseOperatorName':
        assert op != ''
        assert not op.endswith('_out'), \
            "_out suffix is reserved and not permitted for operator names; " \
            "did you mean to specify an out overload name instead?"
        m = re.match(r'^__([^_]+)__$', op)
        if m is not None:
            dunder_method = True
            base = m.group(1)
            if any(base == f'i{n}' for n in AUGMENTED_ASSIGNMENT_NAMES):
                inplace = True
                base = base[1:]
            else:
                inplace = False
                # temporary, this is not intrinsically true but
                # has been historically true for dunder methods
                # we support (but, if we ever got, say, __int__, this would
                # be wrong!)
                assert base[0] != 'i'
        else:
            dunder_method = False
            base = op
            if base[-1] == '_':
                inplace = True
                base = base[:-1]
            else:
                inplace = False
        r = BaseOperatorName(base=base, inplace=inplace, dunder_method=dunder_method)
        assert str(r) == op, f'{str(r)} != {op}'
        return r

    def __str__(self) -> str:
        if self.dunder_method:
            i = 'i' if self.inplace else ''
            return f'__{i}{self.base}__'
        else:
            i = '_' if self.inplace else ''
            return f'{self.base}{i}'
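
# For example (a sketch): BaseOperatorName.parse('add_') gives
# base='add', inplace=True, dunder_method=False, while
# BaseOperatorName.parse('__iadd__') gives base='add', inplace=True,
# dunder_method=True; both print back to their input strings.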

# Operator name is the base operator name along with the (typically not
# user visible) overload string.
@dataclass(frozen=True)
class OperatorName:
    name: BaseOperatorName
    overload_name: str

    @staticmethod
    def parse(op_name: str) -> 'OperatorName':
        if '.' in op_name:
            name, overload_name = op_name.split('.', 1)
        else:
            name = op_name
            overload_name = ''
        r = OperatorName(
            name=BaseOperatorName.parse(name),
            overload_name=overload_name
        )
        assert str(r) == op_name, f'{str(r)} != {op_name}'
        return r

    def __str__(self) -> str:
        if self.overload_name:
            return f"{self.name}.{self.overload_name}"
        else:
            return f"{self.name}"

# Helper functions for parsing argument lists (both inputs and returns)

def parse_returns(return_decl: str) -> Tuple[Return, ...]:
    """
    Input: '()'
    Output: ()
    """
    if return_decl == '()':
        return ()
    if return_decl[0] == '(' and return_decl[-1] == ')':
        return_decl = return_decl[1:-1]
    return tuple(Return.parse(arg) for arg in return_decl.split(', '))

def parse_arguments(args: str) -> Tuple[Tuple[Argument, ...], Tuple[Argument, ...], Tuple[Argument, ...]]:
    """
    Input: 'int x, int y, int z'
    Output: positional args, kwarg-only args, out args
    """
    arguments: List[Argument] = []
    kwarg_only_arguments: List[Argument] = []
    out_arguments: List[Argument] = []
    arguments_acc = arguments

    # TODO: Use a real parser here; this will get bamboozled
    # by signatures that contain things like std::array<bool, 2> (note the space)
    for arg in args.split(', '):
        if not arg:
            continue
        if arg == '*':
            assert arguments_acc is arguments, "invalid syntax: kwarg-only specifier * can only occur once"
            arguments_acc = kwarg_only_arguments
            continue
        parg = Argument.parse(arg)
        # Currently, we rely directly on the invariant that there are NO
        # kwarg-only mutating arguments.  If you want to relax this,
        # we will need a more semantic way of matching that takes
        # into account return arguments.  In that case, you will have
        # to manage out_arguments computation a level up, in
        # FunctionSchema.  See Note [is_out_fn]
        if parg.annotation is not None and parg.annotation.is_write:
            if arguments_acc is arguments:
                pass  # do nothing
            elif arguments_acc is kwarg_only_arguments:
                arguments_acc = out_arguments
        else:
            assert arguments_acc is not out_arguments
        arguments_acc.append(parg)

    return tuple(arguments), tuple(kwarg_only_arguments), tuple(out_arguments)
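
# For example (a sketch): parsing the argument list of add.out,
# 'Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out',
# yields positional args (self, other), kwarg-only args (alpha,),
# and out args (out,), since 'out' is a mutable argument appearing
# after the '*' marker.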