mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[torchgen] Refactor types (#90589)
A retry of #89487. Accidentally closed. ## Split `torchgen.api.types` into `types_base`, `types` and `signatures`. In `types_base`: * Created base class `CType`. `BaseCType` and `ConstRefCType` etc are inheriting `CType`. * Only keep abstract type model definitions, such as `BaseCppType`. In `types`: * Define `BaseCppType` with `at` and `c10` namespaces. * All the signatures using these types. In `signatures`: * Define all the signatures. In `__init__`: * `from ... import *`, suppress flake8 error. Differential Revision: [D41455634](https://our.internmc.facebook.com/intern/diff/D41455634/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41455634/)! Pull Request resolved: https://github.com/pytorch/pytorch/pull/90589 Approved by: https://github.com/iseeyuan
This commit is contained in:
parent
0457020d2c
commit
453ff96029
2
.flake8
2
.flake8
|
|
@ -12,7 +12,7 @@ ignore =
|
|||
B007,B008,
|
||||
# these ignores are from flake8-comprehensions; please fix!
|
||||
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
|
||||
per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950
|
||||
per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950 torchgen/api/types/__init__.py: F401,F403
|
||||
optional-ascii-coding = True
|
||||
exclude =
|
||||
./.git,
|
||||
|
|
|
|||
3
torchgen/api/types/__init__.py
Normal file
3
torchgen/api/types/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .types import *
|
||||
from .types_base import *
|
||||
from .signatures import * # isort:skip
|
||||
|
|
@ -1,419 +1,27 @@
|
|||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypeVar, Union
|
||||
|
||||
from typing import Iterator, List, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BackendIndex,
|
||||
BaseTy,
|
||||
FunctionSchema,
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
NativeFunctionsViewGroup,
|
||||
ScalarType,
|
||||
SelfArgument,
|
||||
TensorOptionsArguments,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
from .types_base import Binding, CType, Expr
|
||||
|
||||
TENSOR_LIST_LIKE_CTYPES = [
|
||||
"at::TensorList",
|
||||
"const c10::List<c10::optional<at::Tensor>> &",
|
||||
"const at::ITensorListRef &",
|
||||
]
|
||||
|
||||
# An ArgName is just the str name of the argument in schema;
|
||||
# but in some special circumstances, we may add a little extra
|
||||
# context. The Enum SpecialArgName covers all of these cases;
|
||||
# grep for their construction sites to see when they can occr.
|
||||
|
||||
SpecialArgName = Enum("SpecialArgName", ("possibly_redundant_memory_format",))
|
||||
ArgName = Union[str, SpecialArgName]
|
||||
|
||||
# This class shouldn't be created directly; instead, use/create one of the singletons below.
|
||||
@dataclass(frozen=True)
|
||||
class BaseCppType:
|
||||
ns: Optional[str]
|
||||
name: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.ns is None or self.ns == "":
|
||||
return self.name
|
||||
return f"{self.ns}::{self.name}"
|
||||
|
||||
|
||||
# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
|
||||
# Templated types get their own dataclass, mainly to make namespace parsing easier.
|
||||
byteT = BaseCppType("", "uint8_t")
|
||||
charT = BaseCppType("", "int8_t")
|
||||
shortT = BaseCppType("", "int16_t")
|
||||
# It would be more symmetric for this to be called intT, but it easy to mix
|
||||
# this up with JIT int (which is int64_t in C++), so we intentionally don't
|
||||
# define intT to make it obvious when you've stuffed it up
|
||||
int32T = BaseCppType("", "int32_t")
|
||||
longT = BaseCppType("", "int64_t")
|
||||
halfT = BaseCppType("at", "Half")
|
||||
doubleT = BaseCppType("", "double")
|
||||
floatT = BaseCppType("", "float")
|
||||
complexHalfT = BaseCppType(
|
||||
"c10", "complex<c10::Half>"
|
||||
) # stuffing template param here is an abuse
|
||||
complexFloatT = BaseCppType("c10", "complex<float>")
|
||||
complexDoubleT = BaseCppType("c10", "complex<double>")
|
||||
boolT = BaseCppType("", "bool")
|
||||
bfloat16T = BaseCppType("at", "BFloat16")
|
||||
voidT = BaseCppType("", "void")
|
||||
stringT = BaseCppType("c10", "string_view")
|
||||
generatorT = BaseCppType("at", "Generator")
|
||||
scalarTypeT = BaseCppType("at", "ScalarType")
|
||||
tensorT = BaseCppType("at", "Tensor")
|
||||
optionalTensorRefT = BaseCppType("at", "OptionalTensorRef")
|
||||
tensorListT = BaseCppType("at", "TensorList")
|
||||
iTensorListRefT = BaseCppType("at", "ITensorListRef")
|
||||
iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")
|
||||
dimnameT = BaseCppType("at", "Dimname")
|
||||
dimnameListT = BaseCppType("at", "DimnameList")
|
||||
dimVectorT = BaseCppType("at", "DimVector")
|
||||
layoutT = BaseCppType("at", "Layout")
|
||||
deviceT = BaseCppType("at", "Device")
|
||||
scalarT = BaseCppType("at", "Scalar")
|
||||
optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
|
||||
memoryFormatT = BaseCppType("at", "MemoryFormat")
|
||||
qschemeT = BaseCppType("at", "QScheme")
|
||||
storageT = BaseCppType("at", "Storage")
|
||||
streamT = BaseCppType("at", "Stream")
|
||||
intArrayRefT = BaseCppType("at", "IntArrayRef")
|
||||
optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef")
|
||||
optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef")
|
||||
tensorOptionsT = BaseCppType("at", "TensorOptions")
|
||||
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
|
||||
tensorGeometryT = BaseCppType("at", "TensorGeometry")
|
||||
SymIntT = BaseCppType("c10", "SymInt")
|
||||
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
|
||||
|
||||
# Types representing template parameters. Technically, we probably shouldn't
|
||||
# represent them this way in codegen, but it was pretty convenient.
|
||||
scalar_t = BaseCppType("", "scalar_t")
|
||||
opmath_t = BaseCppType("", "opmath_t")
|
||||
|
||||
ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
|
||||
ScalarType.Byte: byteT,
|
||||
ScalarType.Char: charT,
|
||||
ScalarType.Short: shortT,
|
||||
ScalarType.Int: int32T,
|
||||
ScalarType.Long: longT,
|
||||
ScalarType.Half: halfT,
|
||||
ScalarType.Float: floatT,
|
||||
ScalarType.Double: doubleT,
|
||||
ScalarType.ComplexHalf: complexHalfT,
|
||||
ScalarType.ComplexFloat: complexFloatT,
|
||||
ScalarType.ComplexDouble: complexDoubleT,
|
||||
ScalarType.Bool: boolT,
|
||||
ScalarType.BFloat16: bfloat16T,
|
||||
}
|
||||
|
||||
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
|
||||
BaseTy.int: longT,
|
||||
BaseTy.float: doubleT,
|
||||
BaseTy.bool: boolT,
|
||||
BaseTy.str: stringT,
|
||||
BaseTy.Generator: generatorT,
|
||||
BaseTy.ScalarType: scalarTypeT,
|
||||
BaseTy.Tensor: tensorT,
|
||||
BaseTy.Dimname: dimnameT,
|
||||
BaseTy.DimVector: dimVectorT,
|
||||
BaseTy.Layout: layoutT,
|
||||
BaseTy.Device: deviceT,
|
||||
BaseTy.Scalar: scalarT,
|
||||
BaseTy.MemoryFormat: memoryFormatT,
|
||||
BaseTy.QScheme: qschemeT,
|
||||
BaseTy.Storage: storageT,
|
||||
BaseTy.Stream: streamT,
|
||||
BaseTy.SymInt: SymIntT,
|
||||
}
|
||||
|
||||
# CTypes encode C++ type structure as needed for translation.
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseCType:
|
||||
type: BaseCppType
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
return str(self.type)
|
||||
|
||||
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
|
||||
# TODO: Kill this when we eventually remove it!
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return str(self.type).replace("at::", "")
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return self
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstRefCType:
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
if strip_ref:
|
||||
return self.elem.cpp_type(strip_ref=strip_ref)
|
||||
return f"const {self.elem.cpp_type()} &"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"const {self.elem.cpp_type_registration_declarations()} &"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return self.elem.remove_const_ref()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MutRefCType:
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
if strip_ref:
|
||||
return self.elem.cpp_type(strip_ref=strip_ref)
|
||||
return f"{self.elem.cpp_type()} &"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"{self.elem.cpp_type_registration_declarations()} &"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return self.elem.remove_const_ref()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OptionalCType:
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f"c10::optional<{self.elem.cpp_type()}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"c10::optional<{self.elem.cpp_type_registration_declarations()}>"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return OptionalCType(self.elem.remove_const_ref())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListCType:
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f"c10::List<{self.elem.cpp_type()}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return ListCType(self.elem.remove_const_ref())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ArrayRefCType:
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f"at::ArrayRef<{self.elem.cpp_type()}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return ArrayRefCType(self.elem.remove_const_ref())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VectorCType:
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f"::std::vector<{self.elem.cpp_type()}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return VectorCType(self.elem.remove_const_ref())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ArrayCType:
|
||||
elem: "CType"
|
||||
size: int
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f"::std::array<{self.elem.cpp_type()},{self.size}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return ArrayCType(self.elem.remove_const_ref(), self.size)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TupleCType:
|
||||
elems: List["CType"]
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>'
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return TupleCType([e.remove_const_ref() for e in self.elems])
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VectorizedCType:
|
||||
# This template is explicitly specialized, so the only valid
|
||||
# elems are those we have specializations for (e.g., float, double, ...)
|
||||
# scalar_t is also a common argument here (when we are codegen in
|
||||
# a templated context)
|
||||
elem: BaseCType
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
return f"at::vec::Vectorized<{self.elem.cpp_type()}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return self
|
||||
|
||||
|
||||
CType = Union[
|
||||
BaseCType,
|
||||
OptionalCType,
|
||||
ConstRefCType,
|
||||
MutRefCType,
|
||||
ListCType,
|
||||
ArrayRefCType,
|
||||
ArrayCType,
|
||||
VectorCType,
|
||||
TupleCType,
|
||||
VectorizedCType,
|
||||
]
|
||||
|
||||
# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus
|
||||
# semantic information about what it represents. For example, consider the
|
||||
# argument "bool pin_memory"; its normal C++ type is "bool", but its C++
|
||||
# semantic type also keeps track that this represents a "pin_memory"; you can't
|
||||
# just use a random other boolean in a context where you need a "pin_memory"!
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NamedCType:
|
||||
name: ArgName
|
||||
type: CType
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
return self.type.cpp_type(strip_ref=strip_ref)
|
||||
|
||||
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
|
||||
# TODO: Kill this when we eventually remove it!
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return self.type.cpp_type_registration_declarations()
|
||||
|
||||
def remove_const_ref(self) -> "NamedCType":
|
||||
return NamedCType(self.name, self.type.remove_const_ref())
|
||||
|
||||
def with_name(self, name: str) -> "NamedCType":
|
||||
return NamedCType(name, self.type)
|
||||
|
||||
|
||||
# A binding represents any C++ binding site for a formal parameter.
|
||||
# We don't distinguish between binding sites for different APIs;
|
||||
# instead, all of the important distinctions are encoded in CType,
|
||||
# which you can use to figure out if a given Binding is appropriate
|
||||
# for use in another context. (See torchgen.api.translate)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Binding:
|
||||
name: str
|
||||
nctype: NamedCType
|
||||
argument: Union[Argument, TensorOptionsArguments, SelfArgument]
|
||||
# TODO: maybe don't represent default here
|
||||
default: Optional[str] = None
|
||||
|
||||
def rename(self, name: str) -> "Binding":
|
||||
return Binding(
|
||||
name=name,
|
||||
nctype=self.nctype,
|
||||
argument=self.argument,
|
||||
default=self.default,
|
||||
)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.nctype.cpp_type()
|
||||
|
||||
def no_default(self) -> "Binding":
|
||||
return Binding(
|
||||
name=self.name,
|
||||
nctype=self.nctype,
|
||||
default=None,
|
||||
argument=self.argument,
|
||||
)
|
||||
|
||||
def decl(self, *, func_ptr_cast: bool = False) -> str:
|
||||
mb_default = ""
|
||||
if self.default is not None:
|
||||
mb_default = f"={self.default}"
|
||||
|
||||
# casting only needs to know the type
|
||||
if func_ptr_cast:
|
||||
return f"{self.type}"
|
||||
else:
|
||||
return f"{self.type} {self.name}{mb_default}"
|
||||
|
||||
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
|
||||
# TODO: Kill this when we eventually remove it!
|
||||
def decl_registration_declarations(self) -> str:
|
||||
type_s = self.nctype.cpp_type_registration_declarations()
|
||||
mb_default = ""
|
||||
if self.default is not None:
|
||||
mb_default = f"={self.default}"
|
||||
return f"{type_s} {self.name}{mb_default}"
|
||||
|
||||
def defn(self) -> str:
|
||||
return f"{self.type} {self.name}"
|
||||
|
||||
def with_name(self, name: str) -> "Binding":
|
||||
return Binding(
|
||||
name=name, nctype=self.nctype, argument=self.argument, default=self.default
|
||||
)
|
||||
|
||||
|
||||
# An Expr is a C++ expression. It has a C++ string representing its syntax,
|
||||
# as well as a CType saying what it provides.
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Expr:
|
||||
expr: str
|
||||
type: NamedCType
|
||||
|
||||
|
||||
# A CppSignature represents a single overload in the C++ API. For
|
||||
# any given function schema, there may be multiple CppSignatures
|
||||
# corresponding to it, based on how we desugar to C++. See also
|
||||
# CppSignatureGroup.
|
||||
@dataclass(frozen=True)
|
||||
class CppSignature:
|
||||
"""
|
||||
A CppSignature represents a single overload in the C++ API. For
|
||||
any given function schema, there may be multiple CppSignatures
|
||||
corresponding to it, based on how we desugar to C++. See also
|
||||
CppSignatureGroup.
|
||||
"""
|
||||
|
||||
# The schema this signature is derived from
|
||||
func: FunctionSchema
|
||||
|
||||
182
torchgen/api/types/types.py
Normal file
182
torchgen/api/types/types.py
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
"""
|
||||
Where should I add a new type? `types_base.py` vs `types.py`
|
||||
|
||||
This file defines data model classes for torchgen typing system, as well as some base types such as int32_t.
|
||||
|
||||
`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
|
||||
|
||||
The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't
|
||||
contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused
|
||||
if we want to generate code for another C++ library.
|
||||
|
||||
Add new types to `types.py` if these types are ATen/c10 related.
|
||||
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, TypeVar
|
||||
|
||||
from torchgen.model import BaseTy, ScalarType
|
||||
|
||||
from .types_base import (
|
||||
BaseCppType,
|
||||
BaseCType,
|
||||
boolT,
|
||||
byteT,
|
||||
charT,
|
||||
CType,
|
||||
doubleT,
|
||||
floatT,
|
||||
int32T,
|
||||
longT,
|
||||
shortT,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
TENSOR_LIST_LIKE_CTYPES = [
|
||||
"at::TensorList",
|
||||
"const c10::List<c10::optional<at::Tensor>> &",
|
||||
"const at::ITensorListRef &",
|
||||
]
|
||||
|
||||
|
||||
halfT = BaseCppType("at", "Half")
|
||||
complexHalfT = BaseCppType(
|
||||
"c10", "complex<c10::Half>"
|
||||
) # stuffing template param here is an abuse
|
||||
complexFloatT = BaseCppType("c10", "complex<float>")
|
||||
complexDoubleT = BaseCppType("c10", "complex<double>")
|
||||
bfloat16T = BaseCppType("at", "BFloat16")
|
||||
stringT = BaseCppType("c10", "string_view")
|
||||
generatorT = BaseCppType("at", "Generator")
|
||||
scalarTypeT = BaseCppType("at", "ScalarType")
|
||||
tensorT = BaseCppType("at", "Tensor")
|
||||
optionalTensorRefT = BaseCppType("at", "OptionalTensorRef")
|
||||
tensorListT = BaseCppType("at", "TensorList")
|
||||
iTensorListRefT = BaseCppType("at", "ITensorListRef")
|
||||
iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")
|
||||
dimnameT = BaseCppType("at", "Dimname")
|
||||
dimnameListT = BaseCppType("at", "DimnameList")
|
||||
dimVectorT = BaseCppType("at", "DimVector")
|
||||
layoutT = BaseCppType("at", "Layout")
|
||||
deviceT = BaseCppType("at", "Device")
|
||||
scalarT = BaseCppType("at", "Scalar")
|
||||
optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
|
||||
memoryFormatT = BaseCppType("at", "MemoryFormat")
|
||||
qschemeT = BaseCppType("at", "QScheme")
|
||||
storageT = BaseCppType("at", "Storage")
|
||||
streamT = BaseCppType("at", "Stream")
|
||||
intArrayRefT = BaseCppType("at", "IntArrayRef")
|
||||
optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef")
|
||||
optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef")
|
||||
tensorOptionsT = BaseCppType("at", "TensorOptions")
|
||||
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
|
||||
tensorGeometryT = BaseCppType("at", "TensorGeometry")
|
||||
SymIntT = BaseCppType("c10", "SymInt")
|
||||
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
|
||||
|
||||
# Types representing template parameters. Technically, we probably shouldn't
|
||||
# represent them this way in codegen, but it was pretty convenient.
|
||||
scalar_t = BaseCppType("", "scalar_t")
|
||||
opmath_t = BaseCppType("", "opmath_t")
|
||||
|
||||
ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
|
||||
ScalarType.Byte: byteT,
|
||||
ScalarType.Char: charT,
|
||||
ScalarType.Short: shortT,
|
||||
ScalarType.Int: int32T,
|
||||
ScalarType.Long: longT,
|
||||
ScalarType.Half: halfT,
|
||||
ScalarType.Float: floatT,
|
||||
ScalarType.Double: doubleT,
|
||||
ScalarType.ComplexHalf: complexHalfT,
|
||||
ScalarType.ComplexFloat: complexFloatT,
|
||||
ScalarType.ComplexDouble: complexDoubleT,
|
||||
ScalarType.Bool: boolT,
|
||||
ScalarType.BFloat16: bfloat16T,
|
||||
}
|
||||
|
||||
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
|
||||
BaseTy.int: longT,
|
||||
BaseTy.float: doubleT,
|
||||
BaseTy.bool: boolT,
|
||||
BaseTy.str: stringT,
|
||||
BaseTy.Generator: generatorT,
|
||||
BaseTy.ScalarType: scalarTypeT,
|
||||
BaseTy.Tensor: tensorT,
|
||||
BaseTy.Dimname: dimnameT,
|
||||
BaseTy.DimVector: dimVectorT,
|
||||
BaseTy.Layout: layoutT,
|
||||
BaseTy.Device: deviceT,
|
||||
BaseTy.Scalar: scalarT,
|
||||
BaseTy.MemoryFormat: memoryFormatT,
|
||||
BaseTy.QScheme: qschemeT,
|
||||
BaseTy.Storage: storageT,
|
||||
BaseTy.Stream: streamT,
|
||||
BaseTy.SymInt: SymIntT,
|
||||
}
|
||||
|
||||
# CTypes encode C++ type structure as needed for translation.
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OptionalCType(CType):
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f"c10::optional<{self.elem.cpp_type()}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"c10::optional<{self.elem.cpp_type_registration_declarations()}>"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return OptionalCType(self.elem.remove_const_ref())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListCType(CType):
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f"c10::List<{self.elem.cpp_type()}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return ListCType(self.elem.remove_const_ref())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ArrayRefCType(CType):
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f"at::ArrayRef<{self.elem.cpp_type()}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return ArrayRefCType(self.elem.remove_const_ref())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VectorizedCType(CType):
|
||||
# This template is explicitly specialized, so the only valid
|
||||
# elems are those we have specializations for (e.g., float, double, ...)
|
||||
# scalar_t is also a common argument here (when we are codegen in
|
||||
# a templated context)
|
||||
elem: BaseCType
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
return f"at::vec::Vectorized<{self.elem.cpp_type()}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return self
|
||||
263
torchgen/api/types/types_base.py
Normal file
263
torchgen/api/types/types_base.py
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
"""
|
||||
Where should I add a new type? `types_base.py` vs `types.py`
|
||||
|
||||
This file defines data model classes for torchgen typing system, as well as some base types such as int32_t.
|
||||
|
||||
`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
|
||||
|
||||
The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't
|
||||
contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused
|
||||
if we want to generate code for another C++ library.
|
||||
|
||||
Add new types to `types.py` if these types are ATen/c10 related.
|
||||
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
|
||||
"""
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
|
||||
|
||||
# An ArgName is just the str name of the argument in schema;
|
||||
# but in some special circumstances, we may add a little extra
|
||||
# context. The Enum SpecialArgName covers all of these cases;
|
||||
# grep for their construction sites to see when they can occr.
|
||||
|
||||
SpecialArgName = Enum("SpecialArgName", ("possibly_redundant_memory_format",))
|
||||
ArgName = Union[str, SpecialArgName]
|
||||
|
||||
|
||||
# This class shouldn't be created directly; instead, use/create one of the singletons below.
|
||||
@dataclass(frozen=True)
|
||||
class BaseCppType:
|
||||
ns: Optional[str]
|
||||
name: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.ns is None or self.ns == "":
|
||||
return self.name
|
||||
return f"{self.ns}::{self.name}"
|
||||
|
||||
|
||||
# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
|
||||
# Templated types get their own dataclass, mainly to make namespace parsing easier.
|
||||
byteT = BaseCppType("", "uint8_t")
|
||||
charT = BaseCppType("", "int8_t")
|
||||
shortT = BaseCppType("", "int16_t")
|
||||
# It would be more symmetric for this to be called intT, but it easy to mix
|
||||
# this up with JIT int (which is int64_t in C++), so we intentionally don't
|
||||
# define intT to make it obvious when you've stuffed it up
|
||||
int32T = BaseCppType("", "int32_t")
|
||||
longT = BaseCppType("", "int64_t")
|
||||
doubleT = BaseCppType("", "double")
|
||||
floatT = BaseCppType("", "float")
|
||||
boolT = BaseCppType("", "bool")
|
||||
voidT = BaseCppType("", "void")
|
||||
|
||||
|
||||
class CType(ABC):
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return self
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseCType(CType):
|
||||
type: BaseCppType
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
return str(self.type)
|
||||
|
||||
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
|
||||
# TODO: Kill this when we eventually remove it!
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return str(self.type).replace("at::", "")
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return self
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstRefCType(CType):
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
if strip_ref:
|
||||
return self.elem.cpp_type(strip_ref=strip_ref)
|
||||
return f"const {self.elem.cpp_type()} &"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"const {self.elem.cpp_type_registration_declarations()} &"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return self.elem.remove_const_ref()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VectorCType(CType):
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f"::std::vector<{self.elem.cpp_type()}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return VectorCType(self.elem.remove_const_ref())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ArrayCType(CType):
|
||||
elem: "CType"
|
||||
size: int
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f"::std::array<{self.elem.cpp_type()},{self.size}>"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return ArrayCType(self.elem.remove_const_ref(), self.size)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TupleCType(CType):
|
||||
elems: List["CType"]
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
# Do not pass `strip_ref` recursively.
|
||||
return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>'
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return TupleCType([e.remove_const_ref() for e in self.elems])
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MutRefCType(CType):
|
||||
elem: "CType"
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
if strip_ref:
|
||||
return self.elem.cpp_type(strip_ref=strip_ref)
|
||||
return f"{self.elem.cpp_type()} &"
|
||||
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return f"{self.elem.cpp_type_registration_declarations()} &"
|
||||
|
||||
def remove_const_ref(self) -> "CType":
|
||||
return self.elem.remove_const_ref()
|
||||
|
||||
|
||||
# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus
|
||||
# semantic information about what it represents. For example, consider the
|
||||
# argument "bool pin_memory"; its normal C++ type is "bool", but its C++
|
||||
# semantic type also keeps track that this represents a "pin_memory"; you can't
|
||||
# just use a random other boolean in a context where you need a "pin_memory"!
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NamedCType:
|
||||
name: ArgName
|
||||
type: CType
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
return self.type.cpp_type(strip_ref=strip_ref)
|
||||
|
||||
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
|
||||
# TODO: Kill this when we eventually remove it!
|
||||
def cpp_type_registration_declarations(self) -> str:
|
||||
return self.type.cpp_type_registration_declarations()
|
||||
|
||||
def remove_const_ref(self) -> "NamedCType":
|
||||
return NamedCType(self.name, self.type.remove_const_ref())
|
||||
|
||||
def with_name(self, name: str) -> "NamedCType":
|
||||
return NamedCType(name, self.type)
|
||||
|
||||
|
||||
# A binding represents any C++ binding site for a formal parameter.
|
||||
# We don't distinguish between binding sites for different APIs;
|
||||
# instead, all of the important distinctions are encoded in CType,
|
||||
# which you can use to figure out if a given Binding is appropriate
|
||||
# for use in another context. (See torchgen.api.translate)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Binding:
|
||||
name: str
|
||||
nctype: NamedCType
|
||||
argument: Union[Argument, TensorOptionsArguments, SelfArgument]
|
||||
# TODO: maybe don't represent default here
|
||||
default: Optional[str] = None
|
||||
|
||||
def rename(self, name: str) -> "Binding":
|
||||
return Binding(
|
||||
name=name,
|
||||
nctype=self.nctype,
|
||||
argument=self.argument,
|
||||
default=self.default,
|
||||
)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.nctype.cpp_type()
|
||||
|
||||
def no_default(self) -> "Binding":
|
||||
return Binding(
|
||||
name=self.name,
|
||||
nctype=self.nctype,
|
||||
default=None,
|
||||
argument=self.argument,
|
||||
)
|
||||
|
||||
def decl(self, *, func_ptr_cast: bool = False) -> str:
|
||||
mb_default = ""
|
||||
if self.default is not None:
|
||||
mb_default = f"={self.default}"
|
||||
|
||||
# casting only needs to know the type
|
||||
if func_ptr_cast:
|
||||
return f"{self.type}"
|
||||
else:
|
||||
return f"{self.type} {self.name}{mb_default}"
|
||||
|
||||
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
|
||||
# TODO: Kill this when we eventually remove it!
|
||||
def decl_registration_declarations(self) -> str:
|
||||
type_s = self.nctype.cpp_type_registration_declarations()
|
||||
mb_default = ""
|
||||
if self.default is not None:
|
||||
mb_default = f"={self.default}"
|
||||
return f"{type_s} {self.name}{mb_default}"
|
||||
|
||||
def defn(self) -> str:
|
||||
return f"{self.type} {self.name}"
|
||||
|
||||
def with_name(self, name: str) -> "Binding":
|
||||
return Binding(
|
||||
name=name, nctype=self.nctype, argument=self.argument, default=self.default
|
||||
)
|
||||
|
||||
|
||||
# An Expr is a C++ expression. It has a C++ string representing its syntax,
|
||||
# as well as a CType saying what it provides.
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Expr:
|
||||
expr: str
|
||||
type: NamedCType
|
||||
Loading…
Reference in New Issue
Block a user