[torchgen] Refactor types (#90589)

A retry of #89487. Accidentally closed.

## Split `torchgen.api.types` into `types_base`, `types` and `signatures`.

In `types_base`:
* Created base class `CType`. `BaseCType` and `ConstRefCType` etc are inheriting `CType`.
* Only keep abstract type model definitions, such as `BaseCppType`.

In `types`:
* Define `BaseCppType` with `at` and `c10` namespaces.
* All the signatures using these types.

In `signatures`:
* Define all the signatures.

In `__init__`:
* `from ... import *`, suppress flake8 error.

Differential Revision: [D41455634](https://our.internmc.facebook.com/intern/diff/D41455634/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41455634/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90589
Approved by: https://github.com/iseeyuan
This commit is contained in:
Larry Liu 2022-12-09 16:33:01 -08:00 committed by PyTorch MergeBot
parent 0457020d2c
commit 453ff96029
5 changed files with 459 additions and 403 deletions

View File

@ -12,7 +12,7 @@ ignore =
B007,B008,
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950
per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950 torchgen/api/types/__init__.py: F401,F403
optional-ascii-coding = True
exclude =
./.git,

View File

@ -0,0 +1,3 @@
from .types import *
from .types_base import *
from .signatures import * # isort:skip

View File

@ -1,419 +1,27 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypeVar, Union
from typing import Iterator, List, Optional, Sequence, Set, Tuple, Union
from torchgen.model import (
Argument,
BackendIndex,
BaseTy,
FunctionSchema,
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
ScalarType,
SelfArgument,
TensorOptionsArguments,
)
_T = TypeVar("_T")
from .types_base import Binding, CType, Expr
TENSOR_LIST_LIKE_CTYPES = [
"at::TensorList",
"const c10::List<c10::optional<at::Tensor>> &",
"const at::ITensorListRef &",
]
# An ArgName is just the str name of the argument in schema;
# but in some special circumstances, we may add a little extra
# context. The Enum SpecialArgName covers all of these cases;
# grep for their construction sites to see when they can occr.
SpecialArgName = Enum("SpecialArgName", ("possibly_redundant_memory_format",))
ArgName = Union[str, SpecialArgName]
# This class shouldn't be created directly; instead, use/create one of the singletons below.
@dataclass(frozen=True)
class BaseCppType:
ns: Optional[str]
name: str
def __str__(self) -> str:
if self.ns is None or self.ns == "":
return self.name
return f"{self.ns}::{self.name}"
# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
# Templated types get their own dataclass, mainly to make namespace parsing easier.
byteT = BaseCppType("", "uint8_t")
charT = BaseCppType("", "int8_t")
shortT = BaseCppType("", "int16_t")
# It would be more symmetric for this to be called intT, but it easy to mix
# this up with JIT int (which is int64_t in C++), so we intentionally don't
# define intT to make it obvious when you've stuffed it up
int32T = BaseCppType("", "int32_t")
longT = BaseCppType("", "int64_t")
halfT = BaseCppType("at", "Half")
doubleT = BaseCppType("", "double")
floatT = BaseCppType("", "float")
complexHalfT = BaseCppType(
"c10", "complex<c10::Half>"
) # stuffing template param here is an abuse
complexFloatT = BaseCppType("c10", "complex<float>")
complexDoubleT = BaseCppType("c10", "complex<double>")
boolT = BaseCppType("", "bool")
bfloat16T = BaseCppType("at", "BFloat16")
voidT = BaseCppType("", "void")
stringT = BaseCppType("c10", "string_view")
generatorT = BaseCppType("at", "Generator")
scalarTypeT = BaseCppType("at", "ScalarType")
tensorT = BaseCppType("at", "Tensor")
optionalTensorRefT = BaseCppType("at", "OptionalTensorRef")
tensorListT = BaseCppType("at", "TensorList")
iTensorListRefT = BaseCppType("at", "ITensorListRef")
iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")
dimnameT = BaseCppType("at", "Dimname")
dimnameListT = BaseCppType("at", "DimnameList")
dimVectorT = BaseCppType("at", "DimVector")
layoutT = BaseCppType("at", "Layout")
deviceT = BaseCppType("at", "Device")
scalarT = BaseCppType("at", "Scalar")
optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
memoryFormatT = BaseCppType("at", "MemoryFormat")
qschemeT = BaseCppType("at", "QScheme")
storageT = BaseCppType("at", "Storage")
streamT = BaseCppType("at", "Stream")
intArrayRefT = BaseCppType("at", "IntArrayRef")
optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef")
optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef")
tensorOptionsT = BaseCppType("at", "TensorOptions")
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
tensorGeometryT = BaseCppType("at", "TensorGeometry")
SymIntT = BaseCppType("c10", "SymInt")
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
# Types representing template parameters. Technically, we probably shouldn't
# represent them this way in codegen, but it was pretty convenient.
scalar_t = BaseCppType("", "scalar_t")
opmath_t = BaseCppType("", "opmath_t")
ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
ScalarType.Byte: byteT,
ScalarType.Char: charT,
ScalarType.Short: shortT,
ScalarType.Int: int32T,
ScalarType.Long: longT,
ScalarType.Half: halfT,
ScalarType.Float: floatT,
ScalarType.Double: doubleT,
ScalarType.ComplexHalf: complexHalfT,
ScalarType.ComplexFloat: complexFloatT,
ScalarType.ComplexDouble: complexDoubleT,
ScalarType.Bool: boolT,
ScalarType.BFloat16: bfloat16T,
}
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
BaseTy.int: longT,
BaseTy.float: doubleT,
BaseTy.bool: boolT,
BaseTy.str: stringT,
BaseTy.Generator: generatorT,
BaseTy.ScalarType: scalarTypeT,
BaseTy.Tensor: tensorT,
BaseTy.Dimname: dimnameT,
BaseTy.DimVector: dimVectorT,
BaseTy.Layout: layoutT,
BaseTy.Device: deviceT,
BaseTy.Scalar: scalarT,
BaseTy.MemoryFormat: memoryFormatT,
BaseTy.QScheme: qschemeT,
BaseTy.Storage: storageT,
BaseTy.Stream: streamT,
BaseTy.SymInt: SymIntT,
}
# CTypes encode C++ type structure as needed for translation.
@dataclass(frozen=True)
class BaseCType:
type: BaseCppType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return str(self.type)
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def cpp_type_registration_declarations(self) -> str:
return str(self.type).replace("at::", "")
def remove_const_ref(self) -> "CType":
return self
@dataclass(frozen=True)
class ConstRefCType:
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
return self.elem.cpp_type(strip_ref=strip_ref)
return f"const {self.elem.cpp_type()} &"
def cpp_type_registration_declarations(self) -> str:
return f"const {self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> "CType":
return self.elem.remove_const_ref()
@dataclass(frozen=True)
class MutRefCType:
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
return self.elem.cpp_type(strip_ref=strip_ref)
return f"{self.elem.cpp_type()} &"
def cpp_type_registration_declarations(self) -> str:
return f"{self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> "CType":
return self.elem.remove_const_ref()
@dataclass(frozen=True)
class OptionalCType:
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"c10::optional<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"c10::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType":
return OptionalCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ListCType:
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"c10::List<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType":
return ListCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayRefCType:
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"at::ArrayRef<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType":
return ArrayRefCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class VectorCType:
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"::std::vector<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType":
return VectorCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayCType:
elem: "CType"
size: int
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"::std::array<{self.elem.cpp_type()},{self.size}>"
def cpp_type_registration_declarations(self) -> str:
return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"
def remove_const_ref(self) -> "CType":
return ArrayCType(self.elem.remove_const_ref(), self.size)
@dataclass(frozen=True)
class TupleCType:
elems: List["CType"]
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>'
def cpp_type_registration_declarations(self) -> str:
return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'
def remove_const_ref(self) -> "CType":
return TupleCType([e.remove_const_ref() for e in self.elems])
@dataclass(frozen=True)
class VectorizedCType:
# This template is explicitly specialized, so the only valid
# elems are those we have specializations for (e.g., float, double, ...)
# scalar_t is also a common argument here (when we are codegen in
# a templated context)
elem: BaseCType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return f"at::vec::Vectorized<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
raise NotImplementedError
def remove_const_ref(self) -> "CType":
return self
CType = Union[
BaseCType,
OptionalCType,
ConstRefCType,
MutRefCType,
ListCType,
ArrayRefCType,
ArrayCType,
VectorCType,
TupleCType,
VectorizedCType,
]
# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus
# semantic information about what it represents. For example, consider the
# argument "bool pin_memory"; its normal C++ type is "bool", but its C++
# semantic type also keeps track that this represents a "pin_memory"; you can't
# just use a random other boolean in a context where you need a "pin_memory"!
#
@dataclass(frozen=True)
class NamedCType:
name: ArgName
type: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return self.type.cpp_type(strip_ref=strip_ref)
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def cpp_type_registration_declarations(self) -> str:
return self.type.cpp_type_registration_declarations()
def remove_const_ref(self) -> "NamedCType":
return NamedCType(self.name, self.type.remove_const_ref())
def with_name(self, name: str) -> "NamedCType":
return NamedCType(name, self.type)
# A binding represents any C++ binding site for a formal parameter.
# We don't distinguish between binding sites for different APIs;
# instead, all of the important distinctions are encoded in CType,
# which you can use to figure out if a given Binding is appropriate
# for use in another context. (See torchgen.api.translate)
@dataclass(frozen=True)
class Binding:
name: str
nctype: NamedCType
argument: Union[Argument, TensorOptionsArguments, SelfArgument]
# TODO: maybe don't represent default here
default: Optional[str] = None
def rename(self, name: str) -> "Binding":
return Binding(
name=name,
nctype=self.nctype,
argument=self.argument,
default=self.default,
)
@property
def type(self) -> str:
return self.nctype.cpp_type()
def no_default(self) -> "Binding":
return Binding(
name=self.name,
nctype=self.nctype,
default=None,
argument=self.argument,
)
def decl(self, *, func_ptr_cast: bool = False) -> str:
mb_default = ""
if self.default is not None:
mb_default = f"={self.default}"
# casting only needs to know the type
if func_ptr_cast:
return f"{self.type}"
else:
return f"{self.type} {self.name}{mb_default}"
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def decl_registration_declarations(self) -> str:
type_s = self.nctype.cpp_type_registration_declarations()
mb_default = ""
if self.default is not None:
mb_default = f"={self.default}"
return f"{type_s} {self.name}{mb_default}"
def defn(self) -> str:
return f"{self.type} {self.name}"
def with_name(self, name: str) -> "Binding":
return Binding(
name=name, nctype=self.nctype, argument=self.argument, default=self.default
)
# An Expr is a C++ expression. It has a C++ string representing its syntax,
# as well as a CType saying what it provides.
@dataclass(frozen=True)
class Expr:
expr: str
type: NamedCType
# A CppSignature represents a single overload in the C++ API. For
# any given function schema, there may be multiple CppSignatures
# corresponding to it, based on how we desugar to C++. See also
# CppSignatureGroup.
@dataclass(frozen=True)
class CppSignature:
"""
A CppSignature represents a single overload in the C++ API. For
any given function schema, there may be multiple CppSignatures
corresponding to it, based on how we desugar to C++. See also
CppSignatureGroup.
"""
# The schema this signature is derived from
func: FunctionSchema

182
torchgen/api/types/types.py Normal file
View File

@ -0,0 +1,182 @@
"""
Where should I add a new type? `types_base.py` vs `types.py`
This file defines data model classes for torchgen typing system, as well as some base types such as int32_t.
`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't
contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused
if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""
from dataclasses import dataclass
from typing import Dict, TypeVar
from torchgen.model import BaseTy, ScalarType
from .types_base import (
BaseCppType,
BaseCType,
boolT,
byteT,
charT,
CType,
doubleT,
floatT,
int32T,
longT,
shortT,
)
_T = TypeVar("_T")
TENSOR_LIST_LIKE_CTYPES = [
"at::TensorList",
"const c10::List<c10::optional<at::Tensor>> &",
"const at::ITensorListRef &",
]
halfT = BaseCppType("at", "Half")
complexHalfT = BaseCppType(
"c10", "complex<c10::Half>"
) # stuffing template param here is an abuse
complexFloatT = BaseCppType("c10", "complex<float>")
complexDoubleT = BaseCppType("c10", "complex<double>")
bfloat16T = BaseCppType("at", "BFloat16")
stringT = BaseCppType("c10", "string_view")
generatorT = BaseCppType("at", "Generator")
scalarTypeT = BaseCppType("at", "ScalarType")
tensorT = BaseCppType("at", "Tensor")
optionalTensorRefT = BaseCppType("at", "OptionalTensorRef")
tensorListT = BaseCppType("at", "TensorList")
iTensorListRefT = BaseCppType("at", "ITensorListRef")
iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")
dimnameT = BaseCppType("at", "Dimname")
dimnameListT = BaseCppType("at", "DimnameList")
dimVectorT = BaseCppType("at", "DimVector")
layoutT = BaseCppType("at", "Layout")
deviceT = BaseCppType("at", "Device")
scalarT = BaseCppType("at", "Scalar")
optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
memoryFormatT = BaseCppType("at", "MemoryFormat")
qschemeT = BaseCppType("at", "QScheme")
storageT = BaseCppType("at", "Storage")
streamT = BaseCppType("at", "Stream")
intArrayRefT = BaseCppType("at", "IntArrayRef")
optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef")
optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef")
tensorOptionsT = BaseCppType("at", "TensorOptions")
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
tensorGeometryT = BaseCppType("at", "TensorGeometry")
SymIntT = BaseCppType("c10", "SymInt")
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
# Types representing template parameters. Technically, we probably shouldn't
# represent them this way in codegen, but it was pretty convenient.
scalar_t = BaseCppType("", "scalar_t")
opmath_t = BaseCppType("", "opmath_t")
ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
ScalarType.Byte: byteT,
ScalarType.Char: charT,
ScalarType.Short: shortT,
ScalarType.Int: int32T,
ScalarType.Long: longT,
ScalarType.Half: halfT,
ScalarType.Float: floatT,
ScalarType.Double: doubleT,
ScalarType.ComplexHalf: complexHalfT,
ScalarType.ComplexFloat: complexFloatT,
ScalarType.ComplexDouble: complexDoubleT,
ScalarType.Bool: boolT,
ScalarType.BFloat16: bfloat16T,
}
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
BaseTy.int: longT,
BaseTy.float: doubleT,
BaseTy.bool: boolT,
BaseTy.str: stringT,
BaseTy.Generator: generatorT,
BaseTy.ScalarType: scalarTypeT,
BaseTy.Tensor: tensorT,
BaseTy.Dimname: dimnameT,
BaseTy.DimVector: dimVectorT,
BaseTy.Layout: layoutT,
BaseTy.Device: deviceT,
BaseTy.Scalar: scalarT,
BaseTy.MemoryFormat: memoryFormatT,
BaseTy.QScheme: qschemeT,
BaseTy.Storage: storageT,
BaseTy.Stream: streamT,
BaseTy.SymInt: SymIntT,
}
# CTypes encode C++ type structure as needed for translation.
@dataclass(frozen=True)
class OptionalCType(CType):
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"c10::optional<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"c10::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType":
return OptionalCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ListCType(CType):
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"c10::List<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType":
return ListCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayRefCType(CType):
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"at::ArrayRef<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType":
return ArrayRefCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class VectorizedCType(CType):
# This template is explicitly specialized, so the only valid
# elems are those we have specializations for (e.g., float, double, ...)
# scalar_t is also a common argument here (when we are codegen in
# a templated context)
elem: BaseCType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return f"at::vec::Vectorized<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
raise NotImplementedError
def remove_const_ref(self) -> "CType":
return self

View File

@ -0,0 +1,263 @@
"""
Where should I add a new type? `types_base.py` vs `types.py`
This file defines data model classes for torchgen typing system, as well as some base types such as int32_t.
`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't
contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused
if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""
from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
# An ArgName is just the str name of the argument in schema;
# but in some special circumstances, we may add a little extra
# context. The Enum SpecialArgName covers all of these cases;
# grep for their construction sites to see when they can occr.
SpecialArgName = Enum("SpecialArgName", ("possibly_redundant_memory_format",))
ArgName = Union[str, SpecialArgName]
# This class shouldn't be created directly; instead, use/create one of the singletons below.
@dataclass(frozen=True)
class BaseCppType:
ns: Optional[str]
name: str
def __str__(self) -> str:
if self.ns is None or self.ns == "":
return self.name
return f"{self.ns}::{self.name}"
# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
# Templated types get their own dataclass, mainly to make namespace parsing easier.
byteT = BaseCppType("", "uint8_t")
charT = BaseCppType("", "int8_t")
shortT = BaseCppType("", "int16_t")
# It would be more symmetric for this to be called intT, but it easy to mix
# this up with JIT int (which is int64_t in C++), so we intentionally don't
# define intT to make it obvious when you've stuffed it up
int32T = BaseCppType("", "int32_t")
longT = BaseCppType("", "int64_t")
doubleT = BaseCppType("", "double")
floatT = BaseCppType("", "float")
boolT = BaseCppType("", "bool")
voidT = BaseCppType("", "void")
class CType(ABC):
def cpp_type(self, *, strip_ref: bool = False) -> str:
raise NotImplementedError
def cpp_type_registration_declarations(self) -> str:
raise NotImplementedError
def remove_const_ref(self) -> "CType":
return self
@dataclass(frozen=True)
class BaseCType(CType):
type: BaseCppType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return str(self.type)
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def cpp_type_registration_declarations(self) -> str:
return str(self.type).replace("at::", "")
def remove_const_ref(self) -> "CType":
return self
@dataclass(frozen=True)
class ConstRefCType(CType):
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
return self.elem.cpp_type(strip_ref=strip_ref)
return f"const {self.elem.cpp_type()} &"
def cpp_type_registration_declarations(self) -> str:
return f"const {self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> "CType":
return self.elem.remove_const_ref()
@dataclass(frozen=True)
class VectorCType(CType):
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"::std::vector<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType":
return VectorCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayCType(CType):
elem: "CType"
size: int
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"::std::array<{self.elem.cpp_type()},{self.size}>"
def cpp_type_registration_declarations(self) -> str:
return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"
def remove_const_ref(self) -> "CType":
return ArrayCType(self.elem.remove_const_ref(), self.size)
@dataclass(frozen=True)
class TupleCType(CType):
elems: List["CType"]
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>'
def cpp_type_registration_declarations(self) -> str:
return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'
def remove_const_ref(self) -> "CType":
return TupleCType([e.remove_const_ref() for e in self.elems])
@dataclass(frozen=True)
class MutRefCType(CType):
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
return self.elem.cpp_type(strip_ref=strip_ref)
return f"{self.elem.cpp_type()} &"
def cpp_type_registration_declarations(self) -> str:
return f"{self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> "CType":
return self.elem.remove_const_ref()
# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus
# semantic information about what it represents. For example, consider the
# argument "bool pin_memory"; its normal C++ type is "bool", but its C++
# semantic type also keeps track that this represents a "pin_memory"; you can't
# just use a random other boolean in a context where you need a "pin_memory"!
#
@dataclass(frozen=True)
class NamedCType:
name: ArgName
type: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return self.type.cpp_type(strip_ref=strip_ref)
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def cpp_type_registration_declarations(self) -> str:
return self.type.cpp_type_registration_declarations()
def remove_const_ref(self) -> "NamedCType":
return NamedCType(self.name, self.type.remove_const_ref())
def with_name(self, name: str) -> "NamedCType":
return NamedCType(name, self.type)
# A binding represents any C++ binding site for a formal parameter.
# We don't distinguish between binding sites for different APIs;
# instead, all of the important distinctions are encoded in CType,
# which you can use to figure out if a given Binding is appropriate
# for use in another context. (See torchgen.api.translate)
@dataclass(frozen=True)
class Binding:
name: str
nctype: NamedCType
argument: Union[Argument, TensorOptionsArguments, SelfArgument]
# TODO: maybe don't represent default here
default: Optional[str] = None
def rename(self, name: str) -> "Binding":
return Binding(
name=name,
nctype=self.nctype,
argument=self.argument,
default=self.default,
)
@property
def type(self) -> str:
return self.nctype.cpp_type()
def no_default(self) -> "Binding":
return Binding(
name=self.name,
nctype=self.nctype,
default=None,
argument=self.argument,
)
def decl(self, *, func_ptr_cast: bool = False) -> str:
mb_default = ""
if self.default is not None:
mb_default = f"={self.default}"
# casting only needs to know the type
if func_ptr_cast:
return f"{self.type}"
else:
return f"{self.type} {self.name}{mb_default}"
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def decl_registration_declarations(self) -> str:
type_s = self.nctype.cpp_type_registration_declarations()
mb_default = ""
if self.default is not None:
mb_default = f"={self.default}"
return f"{type_s} {self.name}{mb_default}"
def defn(self) -> str:
return f"{self.type} {self.name}"
def with_name(self, name: str) -> "Binding":
return Binding(
name=name, nctype=self.nctype, argument=self.argument, default=self.default
)
# An Expr is a C++ expression. It has a C++ string representing its syntax,
# as well as a CType saying what it provides.
@dataclass(frozen=True)
class Expr:
expr: str
type: NamedCType