mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61443 For more information, see #55070. This PR also adds a new type, `OptionalTensorRef` as a replacement for `c10::optional<Tensor>&` in order to avoid the reference count manipulations that are inevitable with the latter. I have confirmed using Godbolt/Compiler Explorer that this class does indeed avoid manipulating the reference count of the `intrusive_ptr` inside the `Tensor` it refers to: 1. [P429709479](https://www.internalfb.com/phabricator/paste/view/P429709479) - Given a `const Tensor&` in scope, an `OptionalTensorRef` can be constructed without bumping refcount. 2. [P429709883](https://www.internalfb.com/phabricator/paste/view/P429709883) - Given an `OptionalTensorRef`, a `const Tensor&` can be produced without bumping refcount. 3. [P429710335](https://www.internalfb.com/phabricator/paste/view/P429710335) - When `OptionalTensorRef` is destructed, the refcount should not be decremented. 4. [P429769525](https://www.internalfb.com/phabricator/paste/view/P429769525) - `OptionalTensorRef` can be assigned without refcount manipulation. 5. [P429769882](https://www.internalfb.com/phabricator/paste/view/P429769882) - `OptionalTensorRef` can be move assigned without refcount manipulation. Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D29780666 Pulled By: SplitInfinity fbshipit-source-id: 7af157215300e9254d635433cbd583f7329fe064
195 lines
8.7 KiB
Python
195 lines
8.7 KiB
Python
from typing import Dict, Sequence, List, NoReturn, Union
|
|
from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
|
|
Expr, MutRefCType, OptionalCType,
|
|
NamedCType, SpecialArgName, tensorT,
|
|
memoryFormatT, tensorOptionsT, scalarTypeT,
|
|
boolT, deviceT, layoutT, optionalTensorRefT)
|
|
|
|
# This file implements a small program synthesis engine that implements
|
|
# conversions from one API to another.
|
|
#
|
|
# The key data type in this file is NamedCType, short for Named C++ semantic type. A NamedCType
|
|
# represents a C++ type, plus semantic information about what it represents.
|
|
# For example, consider the argument "bool pin_memory"; its normal C++ type is
|
|
# "bool", but its C++ semantic type also keeps track that this represents a
|
|
# "pin_memory"; you can't just use a random other boolean in a context where you
|
|
# need a "pin_memory"!
|
|
#
|
|
# The translator takes a list of needed NamedCTypes, and then figures out how
|
|
# to construct expressions with these NamedCTypes from the given bindings. Many
|
|
# of these expressions are trivial (I need a Tensor other; there's a Tensor
|
|
# other in scope); others are more nontrivial and may require packing/unpacking.
|
|
# Some examples of non-trivial action:
|
|
#
|
|
# - Need the "dtype" binding? Well, maybe "dtype" isn't available
|
|
# in the context, instead, "options" is, and you need to extract
|
|
# it from there. (Gather)
|
|
#
|
|
# - Need the "options" binding? Well, maybe "options" isn't available
|
|
# in the context, and you need to construct it from "dtype", "device",
|
|
# etc. (Scatter)
|
|
#
|
|
# - Need the "memory_format" binding? Well, actually, it's available
|
|
# from both "memory_format" and "options", so you had better make sure
|
|
# they are consistent. (Join)
|
|
|
|
# The "options" binding taken by const reference (a ConstRefCType around
# tensorOptionsT).  translate() solves for this type when it extracts
# dtype/layout/device/pin_memory from "options", and when it cross-checks
# "memory_format" against "options".
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
|
|
|
|
class UnsatError(RuntimeError):
    """Raised when a goal cannot be synthesized from the bindings in the context."""
|
|
|
|
# Given a set of in-scope bindings and a set of target bindings, synthesize
|
|
# a list of expressions that uses only the in-scope bindings (bindings) that
|
|
# have all of the types of goals. You may want to use this function if
|
|
# you're generating code for a function like:
|
|
#
|
|
# void f({args}) {
|
|
# g({exprs}); // g is a different API
|
|
# }
|
|
#
|
|
# and you need to generate "exprs".
|
|
#
|
|
# Typically, a list of Bindings is convenient to get (you usually call something
|
|
# like arguments() to get them); but technically you only need less information:
|
|
# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
|
|
# 'goals', an (ordered) list of NamedCType goals is sufficient. If you are doing
|
|
# something more complicated, e.g., tracking the set of bindings in a context,
|
|
# you may find using these smaller types more convenient.
|
|
def translate(
    bindings: Sequence[Union[Expr, Binding]],
    goals: Sequence[Union[NamedCType, Binding]],
    *, method: bool = False
) -> List[Expr]:
    """Synthesize one expression per goal using only the bindings in scope.

    Args:
        bindings: What is available in scope.  A Binding is converted to an
            Expr whose expression is the binding's name and whose type is the
            binding's nctype; an Expr is used as given.
        goals: The types that are needed, in order.  A Binding contributes
            its nctype; a NamedCType is used as given.
        method: When True, implicit "self" bindings (mutable and const
            reference to a tensor) that refer to ``*this`` are added to the
            context.

    Returns:
        A list of Expr, one per goal and in the same order as ``goals``.

    Raises:
        UnsatError: If some goal cannot be produced from the context by any
            of the rules below.
    """
    # Normalize both inputs so the rest of the function only deals with
    # Expr (for what we have) and NamedCType (for what we need).
    binding_exprs: List[Expr] = [
        Expr(expr=b.name, type=b.nctype) if isinstance(b, Binding) else b
        for b in bindings
    ]
    goal_ctypes: List[NamedCType] = [
        g.nctype if isinstance(g, Binding) else g
        for g in goals
    ]

    # The goal types that the synthesis rules in solve() know about, named
    # once here so each rule reads as "if the goal is X, build it from Y".
    memory_format_goal = NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT)))
    redundant_memory_format_goal = NamedCType(
        SpecialArgName.possibly_redundant_memory_format,
        OptionalCType(BaseCType(memoryFormatT)),
    )
    options_value_goal = NamedCType("options", BaseCType(tensorOptionsT))
    dtype_goal = NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
    pin_memory_goal = NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
    device_goal = NamedCType("device", OptionalCType(BaseCType(deviceT)))
    layout_goal = NamedCType("layout", OptionalCType(BaseCType(layoutT)))
    optional_tensor_cref = ConstRefCType(OptionalCType(BaseCType(tensorT)))

    # Seed the context with every binding, keyed by its named type.
    ctx: Dict[NamedCType, str] = {}
    for b in binding_exprs:
        nctype = b.type
        ctx[nctype] = b.expr

        # Simple forward inference, looking through constructors, done while
        # we are visiting the binding anyway.
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        # TODO: This could get us in recomputation trouble if b.expr is nontrivial
        #
        # NOTE(review): `nctype` is used as a NamedCType everywhere else in
        # this loop (it is the ctx key and exposes .type/.name below), so this
        # isinstance test against ConstRefCType looks like it can never be
        # true -- confirm before relying on or removing this rule.
        if (
            isinstance(nctype, ConstRefCType)
            and isinstance(nctype.elem, OptionalCType)
            and isinstance(nctype.elem.elem, BaseCType)
            and str(nctype.elem.elem.type) == 'at::Tensor'
        ):
            unwrapped = NamedCType(nctype.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
            ctx[unwrapped] = f'({b.expr}.has_value() ? *{b.expr} : at::Tensor())'

        # A const reference to an optional tensor can also be offered as an
        # OptionalTensorRef of the same name.
        if nctype.type == optional_tensor_cref:
            as_tensor_ref = NamedCType(nctype.name, BaseCType(optionalTensorRefT))
            ctx[as_tensor_ref] = \
                f'(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())'

    # Add implicit bindings if the generated code is inside a Tensor method.
    # For the const-reference case a plain "*this" would be nicer, but the
    # const_cast form is kept for byte-for-byte compatibility of the output.
    if method:
        for ref_kind in (MutRefCType, ConstRefCType):
            ctx[NamedCType("self", ref_kind(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"

    def unsat(goal: NamedCType) -> NoReturn:
        """Raise UnsatError for `goal`, listing every binding in the context."""
        available = [f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()]
        ctx_desc = '\n'.join(available)
        raise UnsatError(f'''
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of tools.codegen.api.translate.
Check this module for more information.
''')

    # A very naive "backtracking search": it does not actually backtrack or
    # search.  With direct=True no fancy synthesis is attempted, only trivial
    # conversions (e.g., "T a" is OK for "const T& a").  Every rule below
    # tries to solve immediately and bails out if that does not work.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        """Return an expression for `goal`, or raise UnsatError."""
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        # Trivial: the goal is literally in the context.
        if goal in ctx:
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            # WARNING: not strictly decreasing; be careful not to add a
            # direct conversion that satisfies mutable& with const&.
            as_mutable_ref = NamedCType(goal.name, MutRefCType(goal.type.elem))
            try:
                return solve(as_mutable_ref, direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            as_value = NamedCType(goal.name, goal.type.elem)
            try:
                return solve(as_value, direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.

        # Join: memory_format may be available both on its own and via options.
        if goal == memory_format_goal:
            memory_format = direct_solve(redundant_memory_format_goal)
            # No need to join "memory_format" and "options" if the target API
            # takes "options" directly; doing so would cause the redundant
            # memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
            except UnsatError:
                return memory_format
            return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"

        # Build "options" by value out of its individual fields.
        if goal == options_value_goal:
            dtype = direct_solve(dtype_goal)
            pin_memory = direct_solve(pin_memory_goal)
            device = direct_solve(device_goal)
            layout = direct_solve(layout_goal)
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'

        # Extract an individual field out of "options".
        if goal == dtype_goal:
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'

        if goal == layout_goal:
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'

        if goal == device_goal:
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'

        if goal == pin_memory_goal:
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
|