Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Use presence of _symint in kernel name to generate symint sig or not (#84579)
Something people found confusing was that the dispatch key determined whether a native:: signature used SymInt in its type. This changes it so that SymInt-ness of the type is based on whether or not the kernel name contains _symint. This means that even when we make operators support SymInt, you no longer have to go and update all the preexisting definitions; instead, you selectively write _symint to opt individual kernels into SymInt support. I then go and update a bunch of kernels that don't have proper SymInt support to make use of this convention. There is some hacking around for the view generation code. I also add support for external backends to specify 'symint' operators, for which we generate SymInt signatures instead of regular signatures.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: [D39310060](https://our.internmc.facebook.com/intern/diff/D39310060)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84579
Approved by: https://github.com/wconstab
This commit is contained in:
parent 18a31cc044
commit 93aef3a010
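The convention is purely lexical: the kernel name is the whole signal. A minimal sketch of the rule (the helper names and C++ type strings below are illustrative, not torchgen's real API):

```python
def kernel_supports_symint(kernel_name: str) -> bool:
    # After this change, the opt-in signal is the kernel name itself:
    # a kernel containing "_symint" gets a SymInt-typed signature.
    return "_symint" in kernel_name


def size_argument_type(kernel_name: str) -> str:
    # Illustrative C++ type strings for a `SymInt[] size` schema argument.
    if kernel_supports_symint(kernel_name):
        return "c10::SymIntArrayRef"
    return "at::IntArrayRef"


for kernel in ["zeros", "zeros_symint", "empty_meta_symint", "view"]:
    print(f"{kernel}: size is {size_argument_type(kernel)}")
```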
.github/ci_commit_pins/xla.txt (vendored, 2 changes)

@@ -1 +1 @@
-2ba7616e9070bd14ea34a5ef5459bac571198926
+f00dd2f35ecf6455d97237d63c70c9c8ec190940

@@ -12,7 +12,7 @@
 namespace at {
 namespace native {

-Tensor empty_meta(
+Tensor empty_meta_symint(
    SymIntArrayRef size,
    c10::optional<ScalarType> dtype_opt,
    c10::optional<Layout> layout_opt,

@@ -214,12 +214,9 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<Sca
   return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
 }

-Tensor& empty_out(SymIntArrayRef sym_size,
+Tensor& empty_out(IntArrayRef size,
     c10::optional<c10::MemoryFormat> optional_memory_format,
     Tensor& result) {
-  // TODO: support empty_out properly (I was forced to change this immediately
-  // with empty so that empty/empty.out had the same type signature)
-  auto size = c10::asIntArrayRefSlow(sym_size);
   // Preferably, this argument would not be accepted by _out, but the code
   // generator requires the out and non-out overloads to match exactly
   TORCH_CHECK(
@@ -386,7 +383,7 @@ Tensor empty_like_quantized(
 }
 }

-Tensor new_empty(
+Tensor new_empty_symint(
     const Tensor& self,
     SymIntArrayRef size,
     c10::optional<ScalarType> dtype_opt,
@@ -1077,7 +1074,7 @@ Tensor triu_indices_cpu(

 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-Tensor zeros(SymIntArrayRef size,
+Tensor zeros_symint(SymIntArrayRef size,
     c10::optional<ScalarType> dtype,
     c10::optional<Layout> layout,
     c10::optional<Device> device,
@@ -1107,8 +1104,7 @@ Tensor& zeros_sparse_out(IntArrayRef size, Tensor& result) {
   return result;
 }

-Tensor& zeros_out(SymIntArrayRef sym_size, Tensor& result) {
-  auto size = c10::asIntArrayRefSlow(sym_size);
+Tensor& zeros_out(IntArrayRef size, Tensor& result) {
   if (result.is_sparse()) {
     // TODO: I think this branch should be dead, but we don't have an easy
     // way to cover all sparse kernels with zeros_sparse_out, so retain this

@@ -844,9 +844,7 @@ Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim
   return result;
 }

-Tensor expand(const Tensor& self, c10::SymIntArrayRef sym_size, bool /*unused*/) {
-  // TODO: properly support SymInt expand
-  auto size = asIntArrayRefSlow(sym_size);
+Tensor expand(const Tensor& self, c10::IntArrayRef size, bool /*unused*/) {
   TORCH_CHECK(size.size() >= (size_t)self.dim(),
               "expand(", self.toString(), "{", self.sizes(), "}, size=", size,
               "): the number of sizes provided (", size.size(), ") ",
@@ -925,9 +923,8 @@ const Tensor &as_strided_(const Tensor& self, IntArrayRef size, IntArrayRef stri
   return self;
 }

-Tensor narrow_copy_dense(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
-  // TODO: properly support SymInt narrow_copy
-  return self.narrow(dim, start.expect_int(), length.expect_int()).clone(at::MemoryFormat::Contiguous);
+Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
+  return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
 }

 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
@@ -3204,12 +3201,6 @@ Tensor adjoint(const Tensor &self) {
   return _adjoint(self, /*transpose=*/false, "adjoint()");
 }

-Tensor view_meta(const Tensor& self,
-                 at::SymIntArrayRef size) {
-  // TODO: Properly support SymInt view
-  return view_impl(self, c10::asIntArrayRefSlow(size));
-}
-
 Tensor view(const Tensor& self,
             at::IntArrayRef size) {
   return view_impl(self, size);
@@ -3592,7 +3583,7 @@ at::Tensor& expand_copy_SymInt_out(const at::Tensor & self, c10::SymIntArrayRef
 }


-at::Tensor& expand_copy_out(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) {
+at::Tensor& expand_copy_out_symint(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) {
   auto tmp = self.expand_symint(size, implicit);
   out.copy_(tmp);
   return out;
@@ -3748,7 +3739,7 @@ void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList o
 }


-at::Tensor& view_copy_out(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) {
+at::Tensor& view_copy_out_symint(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) {
   auto tmp = self.view_symint(size);
   out.copy_(tmp);
   return out;

@@ -2054,7 +2054,7 @@
     CPU: empty_cpu
     CUDA: empty_cuda
     MPS: empty_mps
-    Meta: empty_meta
+    Meta: empty_meta_symint
     MkldnnCPU: empty_mkldnn
     SparseCPU, SparseCUDA, SparseMeta: empty_sparse
     SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
@@ -2065,7 +2065,7 @@
 - func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   variants: method
   dispatch:
-    CompositeExplicitAutograd: new_empty
+    CompositeExplicitAutograd: new_empty_symint
   autogen: new_empty.out

 - func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@@ -5548,7 +5548,7 @@

 - func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   dispatch:
-    CompositeExplicitAutograd: zeros
+    CompositeExplicitAutograd: zeros_symint

 - func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -6889,8 +6889,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    Meta: view_meta
-    ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
+    ZeroTensor, Meta, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
     MkldnnCPU: mkldnn_view
     NestedTensorCPU, NestedTensorCUDA: view_nested

@@ -12938,7 +12937,7 @@
 - func: expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
-    CompositeExplicitAutograd: expand_copy_out
+    CompositeExplicitAutograd: expand_copy_out_symint


 - func: permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)
@@ -13058,7 +13057,7 @@
 - func: view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
-    CompositeExplicitAutograd: view_copy_out
+    CompositeExplicitAutograd: view_copy_out_symint


 - func: view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)

@@ -199,6 +199,13 @@ supported:
 - _trilinear
 - linalg_pinv.atol_rtol_tensor
 - logsumexp.out
+symint:
+- empty.memory_format
+- expand
+- expand_copy
+- narrow_copy
+- view
+- view_copy
 autograd:
 - max_pool3d
 - native_group_norm
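The hunk above is the user-facing side for external backends: ops listed under the new symint: key get a _symint-suffixed kernel name, which in turn selects a SymInt signature. A hedged sketch of the effect (the backend name and op list below are invented for illustration; the suffixing mirrors create_backend_index later in this commit):

```python
import yaml

# Hypothetical external-backend yaml using the new `symint:` key.
backend_yaml = """
backend: ExampleBackend
cpp_namespace: example
supported:
- view
- expand
- narrow_copy
symint:
- view
- expand
"""

config = yaml.safe_load(backend_yaml)
symint_ops = set(config.get("symint") or [])
for op in config["supported"]:
    # Mirrors create_backend_index: symint ops get a "_symint" suffix,
    # so their generated signatures use SymInt types.
    suffix = "_symint" if op in symint_ops else ""
    print(f"{op} -> {op}{suffix}")
```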

@@ -119,7 +119,7 @@ TEST(MathKernelTest, NarrowCopy) {
   for (const auto dim : c10::irange(3)) {
     const int64_t start = 1, length = 4;
     auto y_ref = x.narrow(dim, start, length);
-    auto y_test = at::native::narrow_copy_dense(x, dim, c10::SymInt(start), c10::SymInt(length));
+    auto y_test = at::native::narrow_copy_dense(x, dim, start, length);
     ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0);
   }
 }

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import distutils.command.clean
+import sys
 import shutil
 import glob
 import os
@@ -129,21 +130,25 @@ class BuildExtension_(BuildExtension):
 if __name__ == '__main__':
     print("Building wheel {}-{}".format(package_name, version))
     write_version_file()
-    setup(
-        # Metadata
-        name=package_name,
-        version=version,
-        author='PyTorch Core Team',
-        url="https://github.com/pytorch/functorch",
-        description='JAX-like composable function transforms for PyTorch',
-        license='BSD',
+    try:
+        setup(
+            # Metadata
+            name=package_name,
+            version=version,
+            author='PyTorch Core Team',
+            url="https://github.com/pytorch/functorch",
+            description='JAX-like composable function transforms for PyTorch',
+            license='BSD',

-        # Package info
-        packages=find_packages(),
-        install_requires=requirements,
-        extras_require=extras,
-        ext_modules=get_extensions(),
-        cmdclass={
-            "build_ext": BuildExtension_.with_options(no_python_abi_suffix=True),
-            'clean': clean,
-        })
+            # Package info
+            packages=find_packages(),
+            install_requires=requirements,
+            extras_require=extras,
+            ext_modules=get_extensions(),
+            cmdclass={
+                "build_ext": BuildExtension_.with_options(no_python_abi_suffix=True),
+                'clean': clean,
+            })
+    except Exception as e:
+        print(e, file=sys.stderr)
+        sys.exit(1)

@@ -584,7 +584,8 @@ def gen_inplace_or_view_type(
         [fn for fn in fns_with_infos if use_derived(fn)],
         key_fn=lambda fn: fn.func.root_name,
         base_env={
-            "generated_comment": f"@generated from {template_path}/ADInplaceOrViewType.cpp",
+            "generated_comment": "@"
+            + f"generated from {template_path}/ADInplaceOrViewType.cpp",
         },
         env_callable=gen_inplace_or_view_type_env,
         num_shards=2,

@@ -535,7 +535,7 @@ def gen_trace_type(
         [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],
         key_fn=lambda fn: fn.root_name,
         base_env={
-            "generated_comment": f"@generated from {template_path}/TraceType.cpp",
+            "generated_comment": "@" + f"generated from {template_path}/TraceType.cpp",
        },
         env_callable=gen_trace_type_func,
         num_shards=5,
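The two hunks above split the literal marker into "@" plus "generated ...". A short sketch of the likely reason (an assumption, not stated in the commit: tooling treats any file containing the contiguous marker as machine-generated, so the generator's own source should not spell it out):

```python
# The emitted file still contains the full marker, but the generator's
# source never contains it contiguously, so scanners that grep for the
# marker don't mistake the generator itself for generated output.
template_path = "templates"  # illustrative value
generated_comment = "@" + f"generated from {template_path}/TraceType.cpp"
print(generated_comment)  # @generated from templates/TraceType.cpp
```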

@@ -314,7 +314,6 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
             dispatch_key=k,
             use_out_as_primary=True,
             external=False,
-            symint=False,
             device_guard=False,
             index=backend_indices[k],
         )

@@ -238,7 +238,7 @@ invalid_key: invalid_val"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
         self.assertExpectedInline(
             output_error,
-            """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen""",  # noqa: B950
+            """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen, symint""",  # noqa: B950
         )

     # if use_out_as_primary is provided, it must be a bool

@@ -269,7 +269,7 @@ at::Tensor LazyNativeFunctions::_to_copy(
   }
 };

-at::Tensor LazyNativeFunctions::empty(
+at::Tensor LazyNativeFunctions::empty_symint(
     at::SymIntArrayRef sym_size,
     c10::optional<at::ScalarType> dtype,
     c10::optional<at::Layout> layout,
@@ -307,7 +307,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
     c10::optional<at::Device> device,
     c10::optional<bool> pin_memory) {
   TORCH_LAZY_FN_COUNTER("lazy::");
-  at::Tensor t = empty(
+  at::Tensor t = empty_symint(
       c10::SymIntArrayRef::fromIntArrayRef(size),
       dtype,
       layout,
@@ -409,7 +409,7 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
     const at::Tensor& self,
     at::IntArrayRef size) {
   TORCH_LAZY_FN_COUNTER("lazy::");
-  return LazyNativeFunctions::view_copy(
+  return LazyNativeFunctions::view_copy_symint(
       self, c10::SymIntArrayRef::fromIntArrayRef(size));
 }

@@ -449,7 +449,7 @@ at::Tensor LazyNativeFunctions::new_empty_strided(
       self, size, stride, dtype, layout, device, pin_memory);
 }

-at::Tensor LazyNativeFunctions::narrow_copy(
+at::Tensor LazyNativeFunctions::narrow_copy_symint(
     const at::Tensor& self,
     int64_t dim,
     c10::SymInt start,

@@ -35,7 +35,12 @@ def name(func: FunctionSchema) -> str:


 def argumenttype_type(
-    t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
+    t: Type,
+    *,
+    mutable: bool,
+    binds: ArgName,
+    remove_non_owning_ref_types: bool = False,
+    symint: bool = True,
 ) -> NamedCType:
     # This is a faux amis. If it makes sense in the future to add
     # more special cases here, or invert things so cpp.argument_type
@@ -45,25 +50,30 @@ def argumenttype_type(
         t,
         mutable=mutable,
         binds=binds,
-        symint=True,
+        symint=symint,
         remove_non_owning_ref_types=remove_non_owning_ref_types,
     )


 def argument_type(
-    a: Argument, *, binds: ArgName, remove_non_owning_ref_types: bool = False
+    a: Argument,
+    *,
+    binds: ArgName,
+    remove_non_owning_ref_types: bool = False,
+    symint: bool = True,
 ) -> NamedCType:
     return argumenttype_type(
         a.type,
         mutable=a.is_write,
         binds=binds,
         remove_non_owning_ref_types=remove_non_owning_ref_types,
+        symint=symint,
     )


-def returns_type(rs: Sequence[Return]) -> CType:
+def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
     # At present, there is no difference. But there could be!
-    return cpp.returns_type(rs, symint=True)
+    return cpp.returns_type(rs, symint=symint)


 def jit_arguments(func: FunctionSchema) -> List[Argument]:
@@ -89,15 +99,20 @@ def jit_arguments(func: FunctionSchema) -> List[Argument]:
     )


-def argument(a: Argument, *, remove_non_owning_ref_types: bool = False) -> Binding:
+def argument(
+    a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
+) -> Binding:
     return Binding(
         nctype=argument_type(
-            a, binds=a.name, remove_non_owning_ref_types=remove_non_owning_ref_types
+            a,
+            binds=a.name,
+            remove_non_owning_ref_types=remove_non_owning_ref_types,
+            symint=symint,
         ),
         name=a.name,
         argument=a,
     )


-def arguments(func: FunctionSchema) -> List[Binding]:
-    return [argument(a) for a in jit_arguments(func)]
+def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
+    return [argument(a, symint=symint) for a in jit_arguments(func)]

@@ -577,8 +577,10 @@ class DispatcherSignature:
     # and need to avoid naming collisions.
     prefix: str = ""

+    symint: bool = True
+
     def arguments(self) -> List[Binding]:
-        return dispatcher.arguments(self.func)
+        return dispatcher.arguments(self.func, symint=self.symint)

     def name(self) -> str:
         return self.prefix + dispatcher.name(self.func)
@@ -604,7 +606,7 @@ class DispatcherSignature:
         return [Expr(a.name, a.nctype) for a in self.arguments()]

     def returns_type(self) -> CType:
-        return dispatcher.returns_type(self.func.returns)
+        return dispatcher.returns_type(self.func.returns, symint=self.symint)

     def ptr_type(self) -> str:
         dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
@@ -616,8 +618,10 @@ class DispatcherSignature:
         return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})"

     @staticmethod
-    def from_schema(func: FunctionSchema, *, prefix: str = "") -> "DispatcherSignature":
-        return DispatcherSignature(func, prefix)
+    def from_schema(
+        func: FunctionSchema, *, prefix: str = "", symint: bool = True
+    ) -> "DispatcherSignature":
+        return DispatcherSignature(func, prefix, symint)


 @dataclass(frozen=True)
@@ -778,15 +782,16 @@ def kernel_signature(
     # so we'd like to keep the differences as small as possible.
     # With external backends, we'd like to enforce that they write their kernels with schemas
     # that match the Dispatcher API directly, if they can.
+    meta = backend_index.get_kernel(f)
+    symint = meta is not None and meta.supports_symint()
+    if symint:
+        assert (
+            f.func.has_symint()
+        ), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
     if backend_index.external:
-        # Dispatcher signature faithfully does SymInt, which is good for XLA,
-        # not so good for more conventional backends but we don't have any of
-        # those. If we do, that's time to add a new Signature that is a cross
-        # between DispatcherSignature and NativeSignature
-        assert backend_index.symint
-        return DispatcherSignature.from_schema(f.func, prefix=prefix)
+        return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
     else:
-        return NativeSignature(f.func, prefix=prefix, symint=backend_index.symint)
+        return NativeSignature(f.func, prefix=prefix, symint=symint)


 # Functions only, no types
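Condensed, the new kernel_signature decision can be modeled like this (toy stand-ins for torchgen's classes, for illustration only):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class KernelMeta:
    kernel: str

    def supports_symint(self) -> bool:
        # Same lexical rule as BackendMetadata.supports_symint later in this commit.
        return "_symint" in self.kernel


def pick_signature(meta: Optional[KernelMeta], schema_has_symint: bool, external: bool) -> str:
    symint = meta is not None and meta.supports_symint()
    if symint and not schema_has_symint:
        raise AssertionError("symint kernel requires SymInt in the schema")
    # External backends follow the Dispatcher API; in-tree kernels get a
    # native signature. Either way, symint-ness comes from the kernel name.
    kind = "DispatcherSignature" if external else "NativeSignature"
    return f"{kind}(symint={symint})"


print(pick_signature(KernelMeta("empty_symint"), True, external=True))
print(pick_signature(KernelMeta("empty_strided"), True, external=False))
```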

@@ -751,8 +751,11 @@ resize_out(out, sizes, strides, options);
         )

         # Signature of the wrapper function we'll register to the dispatcher
+        kern = self.backend_index.get_kernel(f)
         sig = NativeSignature(
-            f.func, prefix="wrapper_", symint=self.backend_index.symint
+            f.func,
+            prefix="wrapper_",
+            symint=kern is not None and kern.supports_symint(),
         )

         if self.target is Target.NAMESPACED_DECLARATION:

@@ -20,6 +20,7 @@ from torchgen.api import cpp
 from torchgen.api.translate import translate
 from torchgen.api.types import (
     Binding,
+    CppSignature,
     CppSignatureGroup,
     DispatcherSignature,
     NamedCType,
@@ -161,7 +162,6 @@ def parse_native_yaml_struct(
                     device_guard=False,
                     # I'm actually not sure about this; undefined could be hit on
                     # empty TensorList, hypothetically that could have sizes in it
-                    symint=False,
                     index={},
                 )
             )
@@ -176,16 +176,6 @@ def parse_native_yaml_struct(
             # Only cuda-like devices in tree require device guards
             device_guard=is_cuda_dispatch_key(k),
             index=v,
-            # Which dispatch keys natively support symint
-            # Note: DispatchKey.CompositeExplicitAutograd has to match out
-            # composites; I think there's some factoring problem here
-            symint=k
-            in [
-                DispatchKey.Meta,
-                DispatchKey.CompositeImplicitAutograd,
-                DispatchKey.CompositeExplicitAutograd,
-                DispatchKey.CompositeExplicitAutogradNonFunctional,
-            ],
         )
     return ParsedYaml(rs, indices)

@@ -363,6 +353,8 @@ def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
 # supporting usecases even when there is a memory_format argument along with tensor_option arguments.
 # This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch
 def translate_args_dispatcher_to_cpp(
+    sig: DispatcherSignature,
+    cpp_sig: CppSignature,
     f: NativeFunction,
 ) -> str:

@@ -385,10 +377,7 @@ def translate_args_dispatcher_to_cpp(
             output_bindings.append(binding)
         return output_bindings

-    disp_sig = DispatcherSignature.from_schema(f.func)
-    cpp_sig = CppSignatureGroup.from_native_function(
-        f, method=False, fallback_binding=False
-    ).signature
+    disp_sig = sig
     disp_bindings = disp_sig.arguments()
     # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
     # get memory_format bindings of dispatcher signature to have the same NCType as well
@@ -401,11 +390,20 @@ def translate_args_dispatcher_to_cpp(


 def generate_static_dispatch_backend_call(
+    sig: DispatcherSignature,
     f: NativeFunction,
     backend_index: BackendIndex,
 ) -> str:
-    name = DispatcherSignature.from_schema(f.func).name()
-    exprs = translate_args_dispatcher_to_cpp(f)
+    cpp_sigs = CppSignatureGroup.from_native_function(
+        f, method=False, fallback_binding=False
+    )
+    if sig.symint and f.func.has_symint():
+        cpp_sig = cpp_sigs.symint_signature
+    else:
+        cpp_sig = cpp_sigs.signature
+    assert cpp_sig is not None
+    name = cpp_sig.name()
+    exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
     backend_metadata = backend_index.get_kernel(f)
     kernel_ns = (
         backend_metadata.cpp_namespace
@@ -417,11 +415,20 @@ def generate_static_dispatch_backend_call(


 def generate_static_dispatch_fallback_call(
+    sig: DispatcherSignature,
     f: NativeFunction,
     backend_indices: List[BackendIndex],
 ) -> str:
-    name = DispatcherSignature.from_schema(f.func).name()
-    exprs = translate_args_dispatcher_to_cpp(f)
+    cpp_sigs = CppSignatureGroup.from_native_function(
+        f, method=False, fallback_binding=False
+    )
+    if sig.symint and f.func.has_symint():
+        cpp_sig = cpp_sigs.symint_signature
+    else:
+        cpp_sig = cpp_sigs.signature
+    assert cpp_sig is not None
+    name = cpp_sig.name()
+    exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
     ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
     if f.has_composite_explicit_autograd_kernel:
         return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
@@ -437,6 +444,7 @@ def generate_static_dispatch_fallback_call(


 def static_dispatch(
+    sig: DispatcherSignature,
     f: NativeFunction,
     backend_indices: List[BackendIndex],
 ) -> str:
@@ -453,11 +461,10 @@ def static_dispatch(
         )
     ]
     if len(keys) == 1:
-        return generate_static_dispatch_backend_call(f, keys[0])
+        return generate_static_dispatch_backend_call(sig, f, keys[0])
     elif len(keys) == 0:
-        return generate_static_dispatch_fallback_call(f, backend_indices)
+        return generate_static_dispatch_fallback_call(sig, f, backend_indices)

-    sig = DispatcherSignature.from_schema(f.func)
     native_tensor_args = [
         a.name
         for a in sig.arguments()
@@ -483,10 +490,10 @@ def static_dispatch(
     for index in keys:
         dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
         dispatch_code.append(
-            f"""\t{generate_static_dispatch_backend_call(f, index)};"""
+            f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
         )

-    fallback = generate_static_dispatch_fallback_call(f, backend_indices)
+    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
     connector = "\n\t\t"

     return f"""
@@ -528,8 +535,6 @@ class ComputeOperators:
     def __call__(self, f: NativeFunction) -> str:
         sig = DispatcherSignature.from_schema(f.func)
         name = f.func.name.unambiguous_name()
-        call_method_name = "call"
-        redispatch_method_name = "redispatch"

         if self.target is Target.DECLARATION:
             # Note [The ATen Operators API]
@@ -563,8 +568,8 @@ struct TORCH_API {name} {{
   STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
   STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
   STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
-  static {sig.defn(name=call_method_name, is_redispatching_fn=False)};
-  static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};
+  static {sig.defn(name="call", is_redispatching_fn=False)};
+  static {sig.defn(name="redispatch", is_redispatching_fn=True)};
 }};"""

         elif self.target is Target.DEFINITION:
@@ -585,12 +590,13 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
             dispatcher_exprs_str = ", ".join(
                 ["dispatchKeySet"] + [a.name for a in sig.arguments()]
             )
-            dispatcher_call = "redispatch"
-            method_name = f"{name}::{redispatch_method_name}"
+            method_base = "redispatch"
         else:
-            method_name = f"{name}::{call_method_name}"
             dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
-            dispatcher_call = "call"
+            method_base = "call"

+        dispatcher_call = method_base
+        method_name = f"{name}::{method_base}"
+
         fn_body = f"""
static auto op = create_{name}_typed_handle();
@@ -602,7 +608,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
         ):
             # call() should go through static dispatch
             fn_body = static_dispatch(
-                f, backend_indices=self.static_dispatch_backend_indices
+                sig, f, backend_indices=self.static_dispatch_backend_indices
             )
         defns += f"""
// aten::{f.func}

@@ -3,7 +3,7 @@ import os
 import pathlib
 import re
 from collections import Counter, defaultdict, namedtuple
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Dict, List, Optional, Sequence, Set, Union

 import yaml

@@ -68,6 +68,7 @@ def parse_backend_yaml(
         "full_codegen",
         "non_native",
         "ir_gen",
+        "symint",
     ]

     backend = yaml_values.pop("backend", None)
@@ -96,6 +97,14 @@ def parse_backend_yaml(
         supported, list
     ), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'

+    symint = yaml_values.pop("symint", [])
+    if symint is None:
+        symint = []  # Allow an empty list of symint ops
+    assert isinstance(
+        symint, list
+    ), f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})'
+    symint_set = set(symint)
+
     supported_autograd = yaml_values.pop("autograd", [])
     assert isinstance(
         supported_autograd, list
@@ -118,6 +127,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'

     def create_backend_index(
         backend_ops: List[str],
+        symint_ops: Set[str],
         dispatch_key: DispatchKey,
         *,
         use_out_as_primary: bool,
@@ -131,6 +141,8 @@ Only the following keys are supported: {", ".join(valid_keys)}'
            ), f"Found an invalid operator name: {op_name}"
            # See Note [External Backends Follow Dispatcher API]
            kernel_name = dispatcher.name(native_functions_map[op_name].func)
+            if op in symint_ops:
+                kernel_name += "_symint"
            # TODO: allow structured external backends later.
            m = BackendMetadata(
                kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
@@ -140,7 +152,6 @@ Only the following keys are supported: {", ".join(valid_keys)}'
             dispatch_key=dispatch_key,
             use_out_as_primary=use_out_as_primary,
             external=True,
-            symint=True,  # TODO: make this configurable
             device_guard=use_device_guard,
             index=metadata,
         )
@@ -154,6 +165,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'

     backend_idx = create_backend_index(
         supported,
+        symint_set,
         backend_key,
         use_out_as_primary=use_out_as_primary,
         use_device_guard=use_device_guard,
@@ -171,6 +183,7 @@ the behavior of autograd for some operators on your backend. However "Autograd{b

     autograd_idx = create_backend_index(
         supported_autograd,
+        symint_set,
         autograd_key,
         use_out_as_primary=use_out_as_primary,
         use_device_guard=use_device_guard,

@@ -84,21 +84,21 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
     # clone() calls in their graph (which is normally needed by reshape).
     if str(g.view_copy.func.name) == "view_copy":
         return """\
-at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) {
-  // TODO: don't cast to int array ref
-  auto int_size = c10::asIntArrayRefSlow(size);
-  DimVector shape = infer_size_dv(int_size, self.numel());
+at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
+  DimVector shape = infer_size_dv(size, self.numel());
   if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
-    return self.reshape(int_size);
+    return self.reshape(size);
   } else {
-    auto output = at::_ops::view::call(self, size);
+    auto output = at::_ops::view::call(self, c10::SymIntArrayRef::fromIntArrayRef(size));
     return output.clone();
   }
 }
 """
     # view_copy is a native signature, since we're generating an at::native:: kernel
     # Functionalization always operates on symints though
-    view_copy_sig = NativeSignature(g.view_copy.func, symint=True)
+    view_copy_sig = NativeSignature(
+        g.view_copy.func, symint=False
+    )  # TODO: flag day this True

     # view is a dispatcher signature, since we're calling into the at::_ops API
     view_sig = DispatcherSignature(g.view.func)
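The generated view_copy kernel above leans on size inference (infer_size_dv) to resolve the target shape before choosing between reshape and clone. A minimal Python sketch of that inference step, written from the conventional view/reshape semantics rather than the ATen source:

```python
from typing import List, Sequence


def infer_size(size: Sequence[int], numel: int) -> List[int]:
    # Resolve at most one -1 placeholder so the product matches numel,
    # mirroring the usual semantics of view/reshape size inference.
    known = 1
    infer_dim = None
    for i, s in enumerate(size):
        if s == -1:
            if infer_dim is not None:
                raise ValueError("only one dimension can be inferred")
            infer_dim = i
        else:
            known *= s
    out = list(size)
    if infer_dim is not None:
        if known == 0 or numel % known != 0:
            raise ValueError(f"shape {list(size)} is invalid for input of size {numel}")
        out[infer_dim] = numel // known
    elif known != numel:
        raise ValueError(f"shape {list(size)} is invalid for input of size {numel}")
    return out


print(infer_size([2, -1], 6))  # [2, 3]
```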

@@ -641,7 +641,7 @@ def gen_functionalization_registration(
         metadata = composite_implicit_autograd_index.get_kernel(f)
         assert metadata is not None
         native_api_name = metadata.kernel
-        sig = DispatcherSignature.from_schema(f.func)
+        sig = NativeSignature(f.func, symint=metadata.supports_symint())
         # Note [Composite view ops in the functionalization pass]
         # We don't need to worry about implemententing functionalization kernels for views with
         # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.

@@ -79,7 +79,7 @@ def gen_unwraps(


 def gen_case_where_all_bdims_are_none(
-    schema: FunctionSchema, cur_level_var: str
+    outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
 ) -> str:
     conditions = []
     flat_args = schema.arguments.flat_all
@@ -90,7 +90,7 @@ def gen_case_where_all_bdims_are_none(

     sig = DispatcherSignature.from_schema(schema)
     translated_args = ", ".join(
-        e.expr for e in translate(sig.arguments(), sig.arguments())
+        e.expr for e in translate(outer_sig.arguments(), sig.arguments())
     )
     return f"""\
if ({' && '.join(conditions)}) {{
@@ -160,7 +160,7 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
     cur_level_var = "cur_level"

     unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
-    bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

     return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
@@ -182,7 +182,7 @@ def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
     cur_level_var = "cur_level"

     unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
-    bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

     return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
@@ -224,7 +224,7 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
     cur_level_var = "cur_level"

     unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
-    bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

     wrapped_returns = gen_returns(returns, cur_level_var, results_var)
     return f"""\

@@ -1098,6 +1098,9 @@ class BackendMetadata:
     # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
     cpp_namespace: str

+    def supports_symint(self) -> bool:
+        return "_symint" in self.kernel
+

 @dataclass(frozen=True)
 class UfuncInnerLoop:
@@ -1141,8 +1144,6 @@ class BackendIndex:
     external: bool
     # Other backend-specific information that is on a per-operator basis
     index: Dict["OperatorName", BackendMetadata]
-    # Whether or not this backend handles symbolic ints or not
-    symint: bool

     @staticmethod
     def grow_index(

@@ -304,6 +304,8 @@ def generate_function(
         if func.kind() == SchemaKind.out
         else cpp.name(func)
     )
+    if f.func.has_symint():
+        kernel_name += "_symint"
     backend_metadata = {
         DispatchKey.CompositeExplicitAutograd: {
             func.name: BackendMetadata(
@@ -555,7 +557,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:

     clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
     return f"""
-{sig.defn()} {{
+{sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
  {clone_mutable_inputs_str}
  {maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
  {ret_str}
@@ -615,7 +617,7 @@ def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:

     # Kernel name needs to follow the naming convention defined in `generate_function()`
     return f"""
-{sig.defn(name=g.out.func.name.unambiguous_name())} {{
+{sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
  auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
  {copy_outs_str}
  {return_str(g.out.func.returns, rets)}