pytorch/tools/codegen/api/autograd.py
Edward Yang 3efd5d8f01 Introduce tools.codegen.api.translate (#49122)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49122

cpparguments_exprs has induced a lot of head scratching in many recent PRs over how to structure the code in a good way.  This PR replaces the old algorithm with an entirely new one inspired by logic programming.  The net result is shorter and cleaner, and should be more robust to future changes.

This PR is a bit of a whopper.  Here is the order to review it.

- tools/codegen/api/types.py
  - Deleted CppArgument, CppArgumentPackIface (and subclasses), CppExpr, DispatcherExpr, DispatcherArgument, NativeExpr, NativeArgument, MetaArgument. All things previously called XArgument are now Binding. All things previously called XExpr are now Expr. I deleted the `__str__` implementation on Binding and fixed all call sites not to use it. On Binding, I renamed `str_no_default` and `str_default` to `defn` and `decl` for better symmetry with the corresponding signature concepts, although I'm open to naming them back to their original versions.
  - Obviously, things are less type safe without the class distinctions. So I introduce a new ADT called CType. CType represents the *semantic C++ type* of a binding: it is both the C++ type (e.g., `const Tensor&`) as well as the argument name that specifies what the binding denotes (e.g., `other`). Every binding now records its CType. The key observation here is that you don't actually care if a given expression is from the cpp or dispatcher or native API; what you care about is having enough information to know what the expression means, so you can use it appropriately. CType has this information. For the most part, ArgNames are just the string names of the arguments as you see them in JIT schema, but there is one case (`possibly_redundant_memory_format`) where we encode a little extra information. Unlike the plain strings we previously used to represent C++ types, CTypes have a little bit of structure around optional and references, because the translation code needs to work around these concepts.
  - I took the opportunity to kill all of the private fields like `_arguments` and `_returns_type` (since the argument types don't make sense anymore). Everything is computed for you on the fly. If this is a perf problem in codegen we can start using `cached_property` decorator.
  - All of the heavy lifting in CppSignature.argument_packs has been moved to the cpp module. We'll head over there next. Similarly, all of the exprs methods now call translate, the new functionality which we haven't gotten to yet.
- tools/codegen/api/cpp.py
  - We refactor all of the type computation functions to return CType instead of str. Because CTypes need to know the denotation, there is a new `binds: ArgName` argument to most functions that provides the denotation, so we can slot it in. (An alternative would have been to construct CTypes without denotations and then fill them in post-facto, but I didn't do it this way. One downside is there are some places where I need a CType without denotation, so I fill these in with `__placeholder__` whenever this happens).
  - `argument` and `arguments` are now extremely simple. There is no more Pack business, just produce one or more Bindings. The one thing of note is that when both a `memory_format` and `options` are in scope, we label the memory format as `possibly_redundant_memory_format`. This will be used in translation.
- tools/codegen/api/dispatcher.py and tools/codegen/api/native.py - same deal as cpp.py. One thing is that `cpparguments_exprs` is deleted; that is in the translator
- tools/codegen/api/translate.py - the translator! It uses a very simple backwards deduction engine to work out how to fill in the arguments of functions. There are comments in the file that explain how it works.
- Everything else: just some small call site tweaks for places where I changed the API.
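
To make the backwards deduction idea concrete, here is a minimal toy sketch. It is not the code in tools/codegen/api/translate.py: the flattened `CType` (a plain C++ type string plus an argument name), the `translate` signature, the type strings, and the two rewrite rules are all assumptions chosen for illustration. The point it tries to show is that each goal is solved by looking it up in the bindings in scope and, failing that, by reducing it to simpler goals.

```python
from dataclasses import dataclass
from typing import Dict, List


# Hypothetical, simplified stand-in for CType: a C++ type string plus the
# argument name that says what the value denotes.
@dataclass(frozen=True)
class CType:
    cpp_type: str
    name: str


@dataclass(frozen=True)
class Binding:
    ctype: CType


@dataclass(frozen=True)
class Expr:
    expr: str
    ctype: CType


def translate(bindings: List[Binding], goals: List[CType]) -> List[Expr]:
    # Everything in scope is indexed by its semantic type.
    ctx: Dict[CType, str] = {b.ctype: b.ctype.name for b in bindings}

    def solve(goal: CType) -> str:
        # Rule 1: the goal is directly in scope.
        if goal in ctx:
            return ctx[goal]
        # Rule 2 (illustrative): a const reference can be satisfied by a
        # mutable reference that denotes the same argument.
        if goal.cpp_type == "const Tensor &":
            return solve(CType("Tensor &", goal.name))
        # Rule 3 (illustrative): build `options` from its scattered parts,
        # each of which becomes a sub-goal.
        if goal == CType("TensorOptions", "options"):
            dtype = solve(CType("optional<ScalarType>", "dtype"))
            layout = solve(CType("optional<Layout>", "layout"))
            return f"TensorOptions().dtype({dtype}).layout({layout})"
        raise ValueError(f"cannot satisfy goal {goal}")

    return [Expr(solve(g), g) for g in goals]


in_scope = [
    Binding(CType("Tensor &", "self")),
    Binding(CType("optional<ScalarType>", "dtype")),
    Binding(CType("optional<Layout>", "layout")),
]
wanted = [CType("const Tensor &", "self"), CType("TensorOptions", "options")]
exprs = translate(in_scope, wanted)
```

Because the lookup key is the semantic type rather than the API a binding came from, the same function can in principle serve cpp, dispatcher, and native call sites; a goal that no rule can reduce fails loudly instead of producing a wrong expression.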

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: ljk53

Differential Revision: D25455887

Pulled By: ezyang

fbshipit-source-id: 90dc58d420d4cc49281aa8647987c69f3ed42fa6
2020-12-16 16:18:40 -08:00


from dataclasses import dataclass
from typing import Optional, Sequence, List, Tuple

from tools.codegen.api.types import *
from tools.codegen.model import *

# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # Name of the saved attribute.
    # A suffix is appended if it's a derived property, e.g.: `other_scalar_type`
    name: str

    # The cpp type string.
    # TODO: change from raw string to model.Type
    type: str

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str

# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (legit C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: Tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: Tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: Tuple[SavedAttribute, ...]

# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunctions having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function.
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: Optional[str]

    # The derivative formulae for this function.
    derivatives: Sequence[Derivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: Optional[List[bool]]
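
As a rough illustration of how these records fit together, the sketch below re-declares `SavedAttribute` and `Derivative` locally so that it runs without the `tools.codegen` package, then builds the two derivatives of a `mul`-like function and computes the union that `all_saved_inputs` is described as holding. The `union_saved_inputs` helper, the `Tensor` and `ScalarType` type strings, and the choice to deduplicate by name in first-seen order are assumptions for the example, not the codegen's actual logic.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple


# Local copies of the dataclasses above, so the example is self-contained.
@dataclass(frozen=True)
class SavedAttribute:
    name: str
    type: str
    expr: str


@dataclass(frozen=True)
class Derivative:
    formula: str
    var_names: Tuple[str, ...]
    saved_inputs: Tuple[SavedAttribute, ...]
    saved_outputs: Tuple[SavedAttribute, ...]


def union_saved_inputs(derivatives: Sequence[Derivative]) -> List[SavedAttribute]:
    # Hypothetical helper: union of saved inputs across all formulas,
    # deduplicated by attribute name and kept in first-seen order.
    seen: Dict[str, SavedAttribute] = {}
    for derivative in derivatives:
        for attr in derivative.saved_inputs:
            seen.setdefault(attr.name, attr)
    return list(seen.values())


# Whole tensors are saved under their own name ...
self_tensor = SavedAttribute(name="self", type="Tensor", expr="self")
other_tensor = SavedAttribute(name="other", type="Tensor", expr="other")
# ... while derived properties get a suffixed name and a save-time expression.
self_scalar_type = SavedAttribute(
    name="self_scalar_type", type="ScalarType", expr="self.scalar_type()"
)
other_scalar_type = SavedAttribute(
    name="other_scalar_type", type="ScalarType", expr="other.scalar_type()"
)

# The formulas refer to saved attribute names, not to the raw expressions.
d_self = Derivative(
    formula="mul_tensor_backward(grad, other, self_scalar_type)",
    var_names=("self",),
    saved_inputs=(other_tensor, self_scalar_type),
    saved_outputs=(),
)
d_other = Derivative(
    formula="mul_tensor_backward(grad, self, other_scalar_type)",
    var_names=("other",),
    saved_inputs=(self_tensor, other_scalar_type),
    saved_outputs=(),
)

all_saved_inputs = union_saved_inputs([d_self, d_other])
```

Saving `other.scalar_type()` as `other_scalar_type` rather than the whole `other` tensor means the formula for `other` only needs `self` and a scalar type at backward time, which is the motivation given in the comment on `SavedAttribute`.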