mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55334 The goal of this PR is to clean up some of the autograd codegen to compare C++ types using `CType` objects instead of raw strings. My last PR in the stack made that string comparison a little more fragile, since the raw C++ strings needed to be namespace-aware. I confirmed byte-for-byte no codegen changes vs. the last PR (which added namespaces to the codegen) by running `diff -qr ../pytorch-common_test/torch/csrc/autograd/generated/ ../pytorch-callgrind_test_after2/torch/csrc/autograd/generated/` and `diff -qr ../pytorch-common_test/build/aten/src/ATen/ ../pytorch-callgrind_test_after2/build/aten/src/ATen/` Note that a better end-state for the autograd codegen would be to do all of its type pattern matching directly off of JIT types, instead of off of CType’s (which are really just generated from JIT types, incorporating C++ specific semantics). That looks like it’ll require a pretty substantial change though, so I’m not doing it in this PR. As part of this change (and after talking with ezyang), I split off the `CType` data class into a separate `NamedCType` class, which holds a name and a `CType`. This way, `CType` only knows about actual C++ types, making it easier to compare CType’s to each other in the codegen when we only care about the type. The core change is in `types.py`, but it required a bunch of downstream changes to update all of the places where we create `CType`s to create `NamedCType`s instead. The main change in the autograd codegen was that I updated `SavedAttribute` to store a `NamedCType`. The other autograd changes all pretty much came from that change. Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D27708347 Pulled By: bdhirsh fbshipit-source-id: 3e07c80569c7b229c638f389e76e319bff6315f9
191 lines
8.4 KiB
Python
191 lines
8.4 KiB
Python
from typing import Dict, Sequence, List, NoReturn, Union
|
|
from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
|
|
Expr, MutRefCType, OptionalCType,
|
|
NamedCType, SpecialArgName, tensorT,
|
|
memoryFormatT, tensorOptionsT, scalarTypeT,
|
|
boolT, deviceT, layoutT)
|
|
|
|
# This file implements a small program synthesis engine that implements
|
|
# conversions between one API to another.
|
|
#
|
|
# The key data type in this file is NamedCType, short for Named C++ semantic type. A NamedCType
|
|
# represents a C++ type, plus semantic information about what it represents.
|
|
# For example, consider the argument "bool pin_memory"; its normal C++ type is
|
|
# "bool", but its C++ semantic type also keeps track that this represents a
|
|
# "pin_memory"; you can't just use a random other boolean in a context where you
|
|
# need a "pin_memory"!
|
|
#
|
|
# The translator takes a list of needed NamedCTypes, and then figures out how
|
|
# to construct expressions with these NamedCTypes from the given bindings. Many
|
|
# of these expressions are trivial (I need a Tensor other; there's a Tensor
|
|
# other in scope); others are more nontrivial and may require packing/unpacking.
|
|
# Some examples of non-trivial action:
|
|
#
|
|
# - Need the "dtype" binding? Well, maybe "dtype" isn't available
|
|
# in the context, instead, "options" is, and you need to extract
|
|
# it from there. (Gather)
|
|
#
|
|
# - Need the "context" binding? Well, maybe "context" isn't available
|
|
# in the context, and you need to construct it from "dtype", "device",
|
|
# etc. (Scatter)
|
|
#
|
|
# - Need the "memory_format" binding? Well, actually, it's available
|
|
# from both "memory_format" and "options", so you had better make sure
|
|
# they are consistent. (Join)
|
|
|
|
# The "const TensorOptions& options" binding.  The rules in translate() look
# this up whenever they need to read a field out of, or cross-check against,
# a packed TensorOptions argument.
options_ctype = NamedCType(
    "options",
    ConstRefCType(BaseCType(tensorOptionsT)),
)
|
|
|
|
class UnsatError(RuntimeError):
    """Raised when a goal cannot be synthesized from the bindings in context."""
|
|
|
|
# Given a set of in-scope bindings and a set of target bindings, synthesize
|
|
# a list of expressions that uses only the in-scope bindings (bindings) that
|
|
# have all of the types of goals. You may want to use this function if
|
|
# you're generating code for a function like:
|
|
#
|
|
# void f({args}) {
|
|
# g({exprs}); // g is a different API
|
|
# }
|
|
#
|
|
# and you need to generate "exprs".
|
|
#
|
|
# Typically, a list of Bindings is convenient to get (you usually call something
|
|
# like arguments() to get them); but technically you only need less information:
|
|
# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
|
|
# 'goals', an (ordered) list of NamedCType goals is sufficient. If you are doing
|
|
# something more complicated, e.g., tracking the set of bindings in a context,
|
|
# you may find using these smaller types more convenient.
|
|
def translate(
    bindings: Sequence[Union[Expr, Binding]],
    goals: Sequence[Union[NamedCType, Binding]],
    *, method: bool = False
) -> List[Expr]:
    """Synthesize one expression per goal, using only the in-scope bindings.

    Args:
        bindings: What is available in scope.  A ``Binding`` is treated as
            an ``Expr`` whose expression is the binding's name and whose
            type is the binding's ``nctype``.
        goals: The (ordered) named types to produce.  A ``Binding`` stands
            for its ``nctype``.
        method: If True, implicit ``self`` bindings (built from ``*this``)
            are added to the context.

    Returns:
        A list with one ``Expr`` per goal, in the same order as ``goals``.

    Raises:
        UnsatError: if a goal is neither directly available nor derivable by
            one of the rules in ``solve`` below.
    """
    # Normalize both inputs so the rest of the function only deals with
    # Exprs (for what we have) and NamedCTypes (for what we want).
    binding_exprs: List[Expr] = [
        Expr(expr=b.name, type=b.nctype) if isinstance(b, Binding) else b
        for b in bindings
    ]
    goal_ctypes: List[NamedCType] = [
        g.nctype if isinstance(g, Binding) else g
        for g in goals
    ]

    optional_tensor_ref = ConstRefCType(OptionalCType(BaseCType(tensorT)))
    tensor_ref = ConstRefCType(BaseCType(tensorT))

    # Add all the bindings to the context
    ctx: Dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors: a "const optional<Tensor>& x" can also serve as a
        # "const Tensor& x" (an undefined Tensor stands in for nullopt).
        #
        # b.type is a NamedCType, so the C++ type to inspect is b.type.type
        # and the argument name is b.type.name.  (Testing the NamedCType
        # itself against ConstRefCType can never match.)
        #
        # setdefault: an explicitly bound "const Tensor& x" must not be
        # shadowed by the inferred expression, whatever the binding order.
        #
        # TODO: This could get us in recomputation trouble if b.expr is nontrivial
        if b.type.type == optional_tensor_ref:
            ctx.setdefault(
                NamedCType(b.type.name, tensor_ref),
                f'({b.expr}.has_value() ? *{b.expr} : at::Tensor())',
            )

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        # NOTE(review): the original comment reads "This is better!
        # Byte-for-byte compat" above the line below; presumably plain
        # "*this" is preferred for the const& case and the const_cast form is
        # kept only so the generated output does not change.
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        """Raise UnsatError for ``goal``, listing everything in the context."""
        ctx_desc = '\n'.join(f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items())
        raise UnsatError(f'''
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of tools.codegen.api.translate.
Check this module for more information.
''')

    # A deliberately naive "backtracking search": it doesn't actually do
    # backtracking or search.  In particular, if direct=True, we won't try
    # to do any fancy synthesis, just trivial conversions (e.g., "T a" is OK
    # for "const T& a").  So all of the existing rules in this function
    # simply try to solve immediately, and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format, OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                # No "options" in scope, so there is nothing to cross-check.
                return memory_format

        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            # Scatter: build TensorOptions from its individual fields.
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'

        # Gather: each remaining rule reads one field out of "options".
        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
|