mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Rename int to long, add more C++ types. (#66108)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66108 BC-breaking change: intT is now longT (which aligns it more accurately with how the types are referred to in C++). The benefit for this is we can idiomatically express all C++ dtypes (with intT now mapping to int32_t). These types are needed for ufunc codegen in a latter patch. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D31385761 Pulled By: ezyang fbshipit-source-id: ec6f3a0953794313470dbe14911f23ac116be425
This commit is contained in:
parent
11bc435622
commit
ece0221854
|
|
@ -11,7 +11,7 @@ from typing import List, Sequence, Tuple
|
|||
from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
|
||||
SavedAttribute, uses_retain_variables,
|
||||
uses_single_grad)
|
||||
from tools.codegen.api.types import (Binding, BaseCType, OptionalCType, tensorT, intT,
|
||||
from tools.codegen.api.types import (Binding, BaseCType, OptionalCType, tensorT, longT,
|
||||
doubleT, scalarT, stringT, boolT, intArrayRefT,
|
||||
tensorListT, MutRefCType, ListCType, ArrayRefCType)
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
|
|
@ -276,7 +276,7 @@ if (prop.isComplex()) {
|
|||
"""
|
||||
|
||||
MISC_GETTER_DEFS = {
|
||||
OptionalCType(BaseCType(intT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
|
||||
OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
|
||||
BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
|
||||
OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
|
||||
BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
|
||||
|
|
@ -430,7 +430,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
|
|||
saved_variables.append(f'c10::OptionalArray<double> {name};')
|
||||
getter_definitions.append(GETTER_DEFINITION_OPT_ARRAYREF.substitute(
|
||||
op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE))
|
||||
elif type == BaseCType(intT):
|
||||
elif type == BaseCType(longT):
|
||||
saved_variables.append(f'{type.cpp_type()} {name} = 0;')
|
||||
getter_definitions.append(GETTER_DEFINITION.substitute(
|
||||
op=info.op, name=name, body=GETTER_BODY_INT64_T))
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from tools.codegen.api.autograd import (
|
|||
dispatch_strategy,
|
||||
)
|
||||
from tools.codegen.api.types import (Binding, DispatcherSignature, CppSignatureGroup, CType,
|
||||
BaseCType, OptionalCType, intT, boolT, intArrayRefT)
|
||||
BaseCType, OptionalCType, longT, boolT, intArrayRefT)
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
from tools.codegen.context import with_native_function
|
||||
from tools.codegen.model import (
|
||||
|
|
@ -244,8 +244,8 @@ def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str
|
|||
replay_view_func = ''
|
||||
updated_unpacked_args: List[str] = []
|
||||
known_view_arg_simple_types: List[CType] = [
|
||||
BaseCType(intT),
|
||||
OptionalCType(BaseCType(intT)),
|
||||
BaseCType(longT),
|
||||
OptionalCType(BaseCType(longT)),
|
||||
BaseCType(boolT),
|
||||
BaseCType(intArrayRefT)]
|
||||
for unpacked_binding in unpacked_bindings:
|
||||
|
|
@ -266,7 +266,7 @@ def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str
|
|||
arg_vec = arg + '_vec'
|
||||
replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
|
||||
updated_unpacked_args.append(arg_vec)
|
||||
elif arg_type == OptionalCType(BaseCType(intT)):
|
||||
elif arg_type == OptionalCType(BaseCType(longT)):
|
||||
# Materialize int64_t? to int64_t
|
||||
arg_value = arg + '_val'
|
||||
replay_view_func += OPTIONAL_TO_VAL.substitute(arg=arg, val=arg_value, default='0')
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import yaml
|
|||
from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
|
||||
SavedAttribute, ForwardDerivative)
|
||||
from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, BaseCType, VectorCType,
|
||||
intArrayRefT, tensorOptionsT, typeAndSizeT, intT, boolT,
|
||||
intArrayRefT, tensorOptionsT, typeAndSizeT, longT, boolT,
|
||||
tensorGeometryT, scalarTypeT, SpecialArgName,
|
||||
OptionalCType, stringT)
|
||||
from tools.codegen.api import cpp
|
||||
|
|
@ -518,17 +518,17 @@ def saved_variables(
|
|||
# replace self.size(2) with self_size_2
|
||||
(r'{}.size\((\w+)\)', {
|
||||
'suffix': lambda m: '_argsize_{}'.format(*m.groups()),
|
||||
'nctype': lambda name: NamedCType(name, BaseCType(intT)),
|
||||
'nctype': lambda name: NamedCType(name, BaseCType(longT)),
|
||||
}),
|
||||
# replace self.numel() with self_numel
|
||||
(r'{}.numel\(\)', {
|
||||
'suffix': '_numel',
|
||||
'nctype': lambda name: NamedCType(name, BaseCType(intT)),
|
||||
'nctype': lambda name: NamedCType(name, BaseCType(longT)),
|
||||
}),
|
||||
# replace to_args_sizes(self) with self_args_sizes
|
||||
(r'to_args_sizes\({}\)', {
|
||||
'suffix': '_args_sizes',
|
||||
'nctype': lambda name: NamedCType(name, VectorCType(VectorCType(BaseCType(intT)))),
|
||||
'nctype': lambda name: NamedCType(name, VectorCType(VectorCType(BaseCType(longT)))),
|
||||
}),
|
||||
# replace to_args_scalartypes(self) with self_args_scalartypes
|
||||
(r'to_args_scalartypes\({}\)', {
|
||||
|
|
@ -547,7 +547,7 @@ def saved_variables(
|
|||
# replace self.dim() with self_dim
|
||||
(r'{}.dim\(\)', {
|
||||
'suffix': '_dim',
|
||||
'nctype': lambda name: NamedCType(name, BaseCType(intT)),
|
||||
'nctype': lambda name: NamedCType(name, BaseCType(longT)),
|
||||
}),
|
||||
# replace self.strides() with self_strides
|
||||
(r'{}.strides\(\)', {
|
||||
|
|
|
|||
|
|
@ -30,9 +30,22 @@ class BaseCppType:
|
|||
|
||||
# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
|
||||
# Templated types get their own dataclass, mainly to make namespace parsing easier.
|
||||
intT = BaseCppType('', 'int64_t')
|
||||
byteT = BaseCppType('', 'uint8_t')
|
||||
charT = BaseCppType('', 'int8_t')
|
||||
shortT = BaseCppType('', 'int16_t')
|
||||
# It would be more symmetric for this to be called intT, but it easy to mix
|
||||
# this up with JIT int (which is int64_t in C++), so we intentionally don't
|
||||
# define intT to make it obvious when you've stuffed it up
|
||||
int32T = BaseCppType('', 'int32_t')
|
||||
longT = BaseCppType('', 'int64_t')
|
||||
halfT = BaseCppType('at', 'Half')
|
||||
doubleT = BaseCppType('', 'double')
|
||||
floatT = BaseCppType('', 'float')
|
||||
complexHalfT = BaseCppType('c10', 'complex<c10::Half>') # stuffing template param here is an abuse
|
||||
complexFloatT = BaseCppType('c10', 'complex<float>')
|
||||
complexDoubleT = BaseCppType('c10', 'complex<double>')
|
||||
boolT = BaseCppType('', 'bool')
|
||||
bfloat16T = BaseCppType('at', 'BFloat16')
|
||||
voidT = BaseCppType('', 'void')
|
||||
stringT = BaseCppType('c10', 'string_view')
|
||||
generatorT = BaseCppType('at', 'Generator')
|
||||
|
|
@ -56,7 +69,7 @@ typeAndSizeT = BaseCppType('torch::autograd::generated', 'TypeAndSize')
|
|||
tensorGeometryT = BaseCppType('at', 'TensorGeometry')
|
||||
|
||||
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
|
||||
BaseTy.int: intT,
|
||||
BaseTy.int: longT,
|
||||
BaseTy.float: doubleT,
|
||||
BaseTy.bool: boolT,
|
||||
BaseTy.str: stringT,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user