Revert "Don't introduce new overload for SymInt (#83628)"

This reverts commit 9790d90e4b.

Reverted https://github.com/pytorch/pytorch/pull/83628 on behalf of https://github.com/malfet because it breaks internal builds; see D39076487
PyTorch MergeBot 2022-08-27 01:23:17 +00:00
parent 38e5e4a85f
commit c7edcd6968
81 changed files with 729 additions and 766 deletions

View File

@ -1 +1 @@
a668569f7f9b7ecd946cf2551d30d482799d597d
9b2f7929c2dae841888a836449c25b04c8cf4045

View File

@ -186,8 +186,7 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
}
Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
// TODO: properly support this
return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
return self.expand(asIntArrayRefSlow(psize), implicit);
}
std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
@ -470,8 +469,7 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
}
Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
// TODO: properly support this
return view_batching_rule(self, asIntArrayRefSlow(size));
return self.view(asIntArrayRefSlow(size));
}
Tensor view_as_complex_batching_rule(const Tensor& self) {
@ -1011,7 +1009,6 @@ Tensor new_empty_symint_batching_rule(
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
// TODO: properly support this
return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
}
@ -1112,7 +1109,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
m.impl("diagonal", diagonal_batching_rule);
m.impl("expand", expand_symint_batching_rule);
m.impl("expand", expand_batching_rule);
m.impl("expand.SymInt", expand_symint_batching_rule);
m.impl("expand_as", native::expand_as); // composite wrt autograd
m.impl("movedim.intlist", movedim_batching_rule);
m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
@ -1140,7 +1138,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("unbind.int", unbind_batching_rule);
m.impl("unfold", unfold_batching_rule);
m.impl("unsqueeze", unsqueeze_batching_rule);
m.impl("view", view_symint_batching_rule);
m.impl("view", view_batching_rule);
m.impl("view.SymInt", view_symint_batching_rule);
m.impl("view_as", native::view_as); // composite wrt autograd
// clamp operations
@ -1278,7 +1277,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("diagonal_backward", diagonal_backward_batching_rule);
// Tensor.new_* operators
m.impl("new_empty", new_empty_symint_batching_rule);
m.impl("new_empty", new_empty_batching_rule);
m.impl("new_empty.SymInt", new_empty_symint_batching_rule);
m.impl("new_empty_strided", new_empty_strided_batching_rule);
m.impl("new_zeros", new_zeros_batching_rule);

View File

@ -137,8 +137,12 @@ Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tenso
return base.diagonal_scatter(mutated_view, offset, dim1, dim2);
}
Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, bool implicit) {
return at::sum_to(mutated_view, base.sym_sizes(),/*always_return_non_view=*/!reapply_views);
Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, bool implicit) {
return at::sum_to(mutated_view, base.sizes(),/*always_return_non_view=*/!reapply_views);
}
Tensor FunctionalInverses::expand_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size, bool implicit) {
return at::sum_to(mutated_view, c10::asIntArrayRefSlow(base.sym_sizes()),/*always_return_non_view=*/!reapply_views);
}
Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef dims) {
@ -287,7 +291,15 @@ Tensor FunctionalInverses::unbind_copy_int_inverse(const Tensor& base, const Ten
return base.select_scatter(mutated_view, dim, mutated_view_idx);
}
Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size) {
Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size) {
if (reapply_views) {
return mutated_view.view(base.sizes());
} else {
return at::view_copy(mutated_view, base.sizes());
}
}
Tensor FunctionalInverses::view_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size) {
if (reapply_views) {
return mutated_view.view_symint(base.sym_sizes());
} else {
@ -295,7 +307,6 @@ Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& m
}
}
Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) {
if (reapply_views) {
return mutated_view.view(base.scalar_type());

View File

@ -179,6 +179,7 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("exp.out", CppFunction::makeFallthrough());
m.impl("exp_", CppFunction::makeFallthrough());
m.impl("expand", CppFunction::makeFallthrough());
m.impl("expand.SymInt", CppFunction::makeFallthrough());
m.impl("expm1", CppFunction::makeFallthrough());
m.impl("expm1.out", CppFunction::makeFallthrough());
m.impl("expm1_", CppFunction::makeFallthrough());

View File

@ -353,14 +353,7 @@ namespace impl {
template<bool AllowDeprecatedTypes>
struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
static std::vector<c10::SymInt> call(IValue& v) {
if (v.isIntList()) {
std::vector<c10::SymInt> r;
auto src = v.toIntList();
std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
return r;
} else {
return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
}
return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
}
};
template<class T, bool AllowDeprecatedTypes>

View File

@ -35,8 +35,7 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
namespace {
void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) {
// TODO: figure out if we can just directly save real schema at def time
c10::optional<std::string> schema_difference = findSchemaDifferences(from_def.cloneWithRealTypes(), inferred);
c10::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
if (schema_difference.has_value()) {
TORCH_CHECK(false,
"Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"

View File

@ -231,6 +231,8 @@ TypePtr DynamicType::fallback() const {
return BoolType::get();
case Tag::Int:
return IntType::get();
case Tag::SymInt:
return SymIntType::get();
case Tag::Float:
return FloatType::get();
case Tag::Complex:
@ -324,6 +326,8 @@ DynamicType::Ptr IValue::TagType<c10::DynamicType>::get(const c10::IValue& v) {
return DynamicTypeTrait<ComplexType>::getBaseType();
case Tag::Int:
return DynamicTypeTrait<IntType>::getBaseType();
case Tag::SymInt:
return DynamicTypeTrait<SymIntType>::getBaseType();
case Tag::Bool:
return DynamicTypeTrait<BoolType>::getBaseType();
case Tag::String:

View File

@ -16,6 +16,7 @@ constexpr DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30);
constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1);
constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3);
constexpr DynamicTypeBits kDynamicSymIntTypeBit = DYNAMIC_TYPE_BIT(23);
constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4);
constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5);
constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7);
@ -28,6 +29,7 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
_(Bool, DYNAMIC_TYPE_BIT(2), 1) \
_(Int, kDynamicIntTypeBit, 1) \
_(Float, kDynamicFloatTypeBit, 1) \
_(SymInt, kDynamicSymIntTypeBit, 1) \
_(Complex, kDynamicComplexTypeBit, 1) \
_(Number, \
(kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit), \
@ -61,7 +63,6 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
#define FORALL_DYNAMIC_TYPES_FAKE(_) \
_(ScalarType, kDynamicIntTypeBit, 1) \
_(Layout, kDynamicIntTypeBit, 1) \
_(SymInt, kDynamicIntTypeBit, 1) \
_(MemoryFormat, kDynamicIntTypeBit, 1)
#define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;

View File

@ -17,22 +17,6 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type)
}
}
FunctionSchema FunctionSchema::cloneWithRealTypes() const {
auto cloneWithRealTypes = [](const Argument& a) {
return a.cloneWithType(a.real_type());
};
std::vector<Argument> new_arguments, new_returns;
std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), cloneWithRealTypes);
return FunctionSchema(
name(),
overload_name(),
std::move(new_arguments),
std::move(new_returns),
is_vararg(),
is_varret());
}
bool FunctionSchema::canAliasTypeSetsAlias(const c10::optional<AliasTypeSet> &lhs, const c10::optional<AliasTypeSet> &rhs) const {
if (!lhs || !rhs) {
return false;

View File

@ -44,7 +44,7 @@ struct Argument {
c10::optional<AliasInfo> alias_info = c10::nullopt)
: name_(std::move(name)),
type_(fake_type ? std::move(fake_type) : TensorType::get()),
real_type_(real_type ? std::move(real_type) : type_),
real_type_(real_type ? std::move(real_type) : TensorType::get()),
N_(std::move(N)),
default_value_(std::move(default_value)),
alias_info_(alias_info ? std::make_unique<AliasInfo>(std::move(*alias_info)) : nullptr),
@ -88,8 +88,6 @@ struct Argument {
const TypePtr& type() const {
return type_;
}
// if type() is non-null, this is guaranteed to be non-null (if no real
// type was provided, this takes on type()'s value)
const TypePtr& real_type() const {
return real_type_;
}
@ -474,8 +472,6 @@ struct TORCH_API FunctionSchema {
FunctionSchema cloneWithRemappedTypes(
const std::function<TypePtr(TypePtr)> type_map) const;
FunctionSchema cloneWithRealTypes() const;
// Check that inputs have the correct types and appends any missing default
// values.
template <typename T = c10::PlatformType>

View File

@ -1789,12 +1789,30 @@ struct getTypePtr_<SymInt> final {
}
};
template <>
struct getTypePtr_<c10::ScalarType> final {
static decltype(auto) call() {
return IntType::get();
}
};
template <>
struct getTypePtr_<c10::Device> final {
static decltype(auto) call() {
return DeviceObjType::get();
}
};
template <>
struct getTypePtr_<c10::Layout> final {
static decltype(auto) call() {
return IntType::get();
}
};
template <>
struct getTypePtr_<c10::MemoryFormat> final {
static decltype(auto) call() {
return IntType::get();
}
};
template <>
struct getTypePtr_<bool> final {
static decltype(auto) call() {
return BoolType::get();
@ -2115,27 +2133,6 @@ private:
LayoutType() : EnumerationType() {}
};
namespace detail {
template <>
struct getTypePtr_<c10::ScalarType> final {
static decltype(auto) call() {
return ScalarTypeType::get();
}
};
template <>
struct getTypePtr_<c10::Layout> final {
static decltype(auto) call() {
return LayoutType::get();
}
};
template <>
struct getTypePtr_<c10::MemoryFormat> final {
static decltype(auto) call() {
return MemoryFormatType::get();
}
};
} // namespace detail
// the common supertype of all lists,
// List[T] <: AnyList for all T
struct AnyListType;

View File

@ -54,6 +54,7 @@ namespace at {
#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \
m.impl("empty_like", torch::CppFunction::makeFallthrough()); \
m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \
m.impl("empty.SymInt", torch::CppFunction::makeFallthrough()); \
m.impl("empty.out", torch::CppFunction::makeFallthrough()); \
m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \
m.impl("full_like", torch::CppFunction::makeFallthrough()); \

View File

@ -13,6 +13,18 @@ namespace at {
namespace native {
Tensor empty_meta(
IntArrayRef size,
c10::optional<ScalarType> dtype_opt,
c10::optional<Layout> layout_opt,
c10::optional<Device> device_opt,
c10::optional<bool> pin_memory_opt,
c10::optional<c10::MemoryFormat> memory_format_opt
) {
return at::detail::empty_meta(
size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}
Tensor empty_symint_meta(
SymIntArrayRef size,
c10::optional<ScalarType> dtype_opt,
c10::optional<Layout> layout_opt,

View File

@ -20,7 +20,7 @@ Tensor _bincount_cpu_template(
AT_ERROR("minlength should be >= 0");
}
if (self.dim() == 1 && self.numel() == 0) {
return at::zeros({minlength}, kLong);
return native::zeros({minlength}, kLong);
}
if (self.dim() != 1 || *self.min().data_ptr<input_t>() < 0) {
AT_ERROR("bincount only supports 1-d non-negative integral inputs.");
@ -38,7 +38,7 @@ Tensor _bincount_cpu_template(
const input_t* self_p = self.data_ptr<input_t>();
if (has_weights) {
output = at::zeros(
output = native::zeros(
{nbins},
optTypeMetaToScalarType(weights.options().dtype_opt()),
weights.options().layout_opt(),
@ -50,7 +50,7 @@ Tensor _bincount_cpu_template(
output_p[self_p[i]] += weights_p[i];
}
} else {
output = at::zeros({nbins}, kLong);
output = native::zeros({nbins}, kLong);
int64_t* output_p = output.data_ptr<int64_t>();
for (const auto i : c10::irange(self_size)) {
output_p[self_p[i]] += 1L;

View File

@ -186,7 +186,12 @@ Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::opt
return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}
Tensor empty_names(
Tensor empty_symint_cpu(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
return at::native::empty_cpu(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}
Tensor empty(
IntArrayRef size,
c10::optional<DimnameList> names,
c10::optional<ScalarType> dtype,
@ -214,12 +219,9 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<Sca
return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
}
Tensor& empty_out(SymIntArrayRef sym_size,
Tensor& empty_out(IntArrayRef size,
c10::optional<c10::MemoryFormat> optional_memory_format,
Tensor& result) {
// TODO: support empty_out properly (I was forced to change this immediately
// with empty so that empty/empty.out had the same type signature)
auto size = c10::asIntArrayRefSlow(sym_size);
// Preferably, this argument would not be accepted by _out, but the code
// generator requires the out and non-out overloads to match exactly
TORCH_CHECK(
@ -387,6 +389,17 @@ Tensor empty_like_quantized(
}
Tensor new_empty(
const Tensor& self,
IntArrayRef size,
c10::optional<ScalarType> dtype_opt,
c10::optional<Layout> layout_opt,
c10::optional<Device> device_opt,
c10::optional<bool> pin_memory_opt
) {
return self.new_empty_symint(c10::SymIntArrayRef::fromIntArrayRef(size), dtype_opt, layout_opt, device_opt, pin_memory_opt);
}
Tensor new_empty_symint(
const Tensor& self,
SymIntArrayRef size,
c10::optional<ScalarType> dtype_opt,
@ -1077,7 +1090,15 @@ Tensor triu_indices_cpu(
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tensor zeros(SymIntArrayRef size,
Tensor zeros(IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
return at::zeros_symint(c10::SymIntArrayRef::fromIntArrayRef(size), dtype, layout, device, pin_memory);
}
Tensor zeros_symint(SymIntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
@ -1102,17 +1123,8 @@ Tensor _efficientzerotensor(IntArrayRef size,
return out;
}
Tensor& zeros_sparse_out(IntArrayRef size, Tensor& result) {
result.sparse_resize_and_clear_(size, size.size(), 0.);
return result;
}
Tensor& zeros_out(SymIntArrayRef sym_size, Tensor& result) {
auto size = c10::asIntArrayRefSlow(sym_size);
Tensor& zeros_out(IntArrayRef size, Tensor& result) {
if (result.is_sparse()) {
// TODO: I think this branch should be dead, but we don't have an easy
// way to cover all sparse kernels with zeros_sparse_out, so retain this
// for now
result.sparse_resize_and_clear_(size, size.size(), 0.);
return result;
} else {

View File

@ -300,7 +300,7 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) {
new_values_size[0] = new_indices_size[1];
Tensor new_values = values.expand(broadcast_dense_sizes).repeat_interleave(nnz_factor, 0);
Tensor new_indices = indices.new_empty(new_indices_size);
Tensor new_indices = at::native::new_empty(indices, new_indices_size);
if (broadcast_sizes.size()>0) {
// ones(broadcast_sizes).nonzero() is equivalent to
// product(map(arange, broadcast_sizes)) but avoids creating
@ -542,14 +542,14 @@ static Tensor cat_sparse_impl(TensorList tensors, int64_t dim) {
zeros_sizes[0] = t._values().size(0);
zeros_sizes[values_dim] = cumulative_size;
cumulative_size += t._values().size(values_dim);
auto z1 = at::zeros(
auto z1 = native::zeros(
zeros_sizes,
optTypeMetaToScalarType(t._values().options().dtype_opt()),
t._values().options().layout_opt(),
t._values().options().device_opt(),
t._values().options().pinned_memory_opt());
zeros_sizes[values_dim] = total_size - cumulative_size;
auto z2 = at::zeros(
auto z2 = native::zeros(
zeros_sizes,
optTypeMetaToScalarType(t._values().options().dtype_opt()),
t._values().options().layout_opt(),
@ -843,9 +843,12 @@ Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim
return result;
}
Tensor expand(const Tensor& self, c10::SymIntArrayRef sym_size, bool /*unused*/) {
// TODO: properly support SymInt expand
auto size = asIntArrayRefSlow(sym_size);
Tensor expand_symint(const Tensor& self, c10::SymIntArrayRef packed_size, bool implicit) {
auto size = asIntArrayRefSlow(packed_size);
return self.expand(size, implicit);
}
Tensor expand(const Tensor& self, IntArrayRef size, bool /*unused*/) {
TORCH_CHECK(size.size() >= (size_t)self.dim(),
"expand(", self.toString(), "{", self.sizes(), "}, size=", size,
"): the number of sizes provided (", size.size(), ") ",
@ -924,9 +927,12 @@ const Tensor &as_strided_(const Tensor& self, IntArrayRef size, IntArrayRef stri
return self;
}
Tensor narrow_copy_dense(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
// TODO: properly support SymInt narrow_copy
return self.narrow(dim, start.expect_int(), length.expect_int()).clone(at::MemoryFormat::Contiguous);
Tensor narrow_copy_symint(const Tensor& self, int64_t dim, int64_t start, SymInt sym_length) {
return self.narrow_copy(dim, start, sym_length.expect_int());
}
Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
}
Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
@ -2776,7 +2782,7 @@ Tensor unsqueeze_sparse(Tensor const &self, int64_t dim) {
if (dim <= sparse_dim) {
auto new_indices = at::cat(
{indices.narrow(0, 0, dim),
at::zeros(
native::zeros(
{1, indices.size(1)},
kLong,
indices.options().layout_opt(),
@ -3112,15 +3118,14 @@ Tensor adjoint(const Tensor &self) {
return _adjoint(self, /*transpose=*/false, "adjoint()");
}
Tensor view_meta(const Tensor& self,
at::SymIntArrayRef size) {
// TODO: Properly support SymInt view
return view_impl(self, c10::asIntArrayRefSlow(size));
Tensor view(const Tensor& self,
IntArrayRef size) {
return view_impl(self, size);
}
Tensor view(const Tensor& self,
at::IntArrayRef size) {
return view_impl(self, size);
Tensor view_symint(const Tensor& self,
c10::SymIntArrayRef size) {
return self.view(c10::asIntArrayRefSlow(size));
}
Tensor alias(const Tensor& self) {
@ -3500,8 +3505,8 @@ at::Tensor& expand_copy_SymInt_out(const at::Tensor & self, c10::SymIntArrayRef
}
at::Tensor& expand_copy_out(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) {
auto tmp = self.expand_symint(size, implicit);
at::Tensor& expand_copy_out(const at::Tensor & self, at::IntArrayRef size, bool implicit, at::Tensor & out) {
auto tmp = self.expand(size, implicit);
out.copy_(tmp);
return out;
}
@ -3656,8 +3661,8 @@ void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList o
}
at::Tensor& view_copy_out(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) {
auto tmp = self.view_symint(size);
at::Tensor& view_copy_out(const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) {
auto tmp = self.view(size);
out.copy_(tmp);
return out;
}

View File

@ -15,7 +15,7 @@
#include <ATen/ops/bincount_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/histc_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_native.h>
#endif
namespace at {
@ -271,7 +271,7 @@ bool CUDA_tensor_histogram(
detail::TensorInfo<output_t, IndexType> pInfo(nullptr, 0, {}, {});
Tensor partial_output;
if (memType == CUDAHistogramMemoryType::MULTI_BLOCK) {
partial_output = at::zeros(
partial_output = native::zeros(
{grid.x, nbins},
optTypeMetaToScalarType(a.options().dtype_opt()),
a.options().layout_opt(),
@ -313,7 +313,7 @@ Tensor _bincount_cuda_template(
AT_ERROR("minlength should be >= 0");
}
if (self.dim() == 1 && self.numel() == 0) {
return at::zeros(
return native::zeros(
{minlength},
kLong,
c10::nullopt /* layout */,
@ -342,7 +342,7 @@ Tensor _bincount_cuda_template(
// alloc output counter on GPU
Tensor output;
if (has_weights) {
output = at::zeros(
output = native::zeros(
{nbins},
optTypeMetaToScalarType(weights.options().dtype_opt()),
weights.options().layout_opt(),
@ -351,7 +351,7 @@ Tensor _bincount_cuda_template(
cuda::CUDA_tensor_histogram<weights_t, input_t, true>(
output, self, weights, nbins, minvalue, maxvalue);
} else {
output = at::zeros(
output = native::zeros(
{nbins},
kLong,
c10::nullopt /* layout */,
@ -373,7 +373,7 @@ Tensor _histc_cuda_template(
if (nbins <= 0) {
AT_ERROR("bins must be > 0");
}
Tensor output = at::zeros(
Tensor output = native::zeros(
{nbins},
self.scalar_type(),
c10::nullopt /* layout */,

View File

@ -55,6 +55,10 @@ Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::op
return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}
Tensor empty_symint_cuda(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
return at::native::empty_cuda(asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}
Tensor _efficientzerotensor_cuda(IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,

View File

@ -436,7 +436,7 @@ Tensor cudnn_convolution_relu(
bool allow_tf32 = ctx.allowTF32CuDNN();
auto _bias = bias_t.has_value()
? bias_t.value()
: at::zeros(
: at::native::zeros(
{output_t.size(1)},
optTypeMetaToScalarType(output_t.options().dtype_opt()),
output_t.options().layout_opt(),
@ -514,7 +514,7 @@ Tensor cudnn_convolution_add_relu(
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
auto _bias = bias_t.has_value()
? bias_t.value()
: at::zeros(
: at::native::zeros(
{output_t.size(1)},
optTypeMetaToScalarType(output_t.options().dtype_opt()),
output_t.options().layout_opt(),

View File

@ -1570,7 +1570,7 @@ Tensor miopen_convolution_add_relu(
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
auto _bias = bias.has_value()
? bias.value()
: at::zeros(
: at::native::zeros(
{contig_output.size(1)},
optTypeMetaToScalarType(contig_output.options().dtype_opt()),
contig_output.options().layout_opt(),
@ -1614,7 +1614,7 @@ Tensor miopen_convolution_relu(
auto _bias = bias.has_value()
? bias.value()
: at::zeros(
: at::native::zeros(
{output_t.size(1)},
optTypeMetaToScalarType(output_t.options().dtype_opt()),
output_t.options().layout_opt(),
@ -1661,7 +1661,7 @@ Tensor miopen_convolution_relu(
auto _bias = bias.has_value()
? bias.value()
: at::zeros(
: at::native::zeros(
{contig_output.size(1)},
optTypeMetaToScalarType(contig_output.options().dtype_opt()),
contig_output.options().layout_opt(),

View File

@ -2,6 +2,10 @@
namespace at { namespace native {
Tensor empty_symint_mkldnn(c10::SymIntArrayRef sizes, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {
return at::native::empty_mkldnn(c10::asIntArrayRefSlow(sizes), dtype, layout, device, pin_memory, optional_memory_format);
}
#if AT_MKLDNN_ENABLED()
Tensor empty_mkldnn(IntArrayRef sizes, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {

View File

@ -71,6 +71,17 @@ Tensor empty_mps(
return at::detail::empty_mps(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}
Tensor empty_symint_mps(
c10::SymIntArrayRef size,
c10::optional<ScalarType> dtype_opt,
c10::optional<Layout> layout_opt,
c10::optional<Device> device_opt,
c10::optional<bool> pin_memory_opt,
c10::optional<c10::MemoryFormat> memory_format_opt) {
return at::native::empty_mps(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}
Tensor empty_strided_mps(
IntArrayRef size,
IntArrayRef stride,

View File

@ -2046,10 +2046,10 @@
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: empty_names
CompositeExplicitAutograd: empty
autogen: empty.names_out
- func: empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
- func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
dispatch:
CPU: empty_cpu
CUDA: empty_cuda
@ -2060,14 +2060,39 @@
SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized
# all calls to empty() in python used to go through the symint overload
# even if all arguments were concrete integers.
# adding symint overloads of kernels for every dispatch key allowed us
# to skip redispatching to `empty.memory_format` and hit backend kernels directly
# we recently updated signature parsing to dispatch `empty()` calls in python
# to `empty.SymInt` iff there is a symint node argument
# hopefully, we can simplify this entry soon
- func: empty.SymInt(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
dispatch:
CPU: empty_symint_cpu
CUDA: empty_symint_cuda
MPS: empty_symint_mps
Meta: empty_symint_meta
MkldnnCPU: empty_symint_mkldnn
SparseCPU, SparseCUDA, SparseMeta: empty_symint_sparse
SparseCsrCPU, SparseCsrCUDA: empty_symint_sparse_compressed
QuantizedCPU, QuantizedCUDA: empty_symint_unknown_quantized
autogen: empty.SymInt_out
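
The comment above captures why this revert restores per-backend `empty.SymInt` kernels: dispatching a symbolic `empty()` call directly to the backend avoids a redispatch through `empty.memory_format`. Every backend kernel re-added in this commit follows the same thin forwarding pattern, sketched below (modeled on the CPU kernel appearing elsewhere in this diff; illustrative only, not the verbatim file contents):

// Minimal sketch of the forwarding pattern used by empty_symint_cpu,
// empty_symint_cuda, empty_symint_mps, etc.: materialize the SymInts as
// concrete ints and call the existing concrete kernel.
namespace at { namespace native {
Tensor empty_symint_cpu(
    c10::SymIntArrayRef size,
    c10::optional<ScalarType> dtype_opt,
    c10::optional<Layout> layout_opt,
    c10::optional<Device> device_opt,
    c10::optional<bool> pin_memory_opt,
    c10::optional<c10::MemoryFormat> memory_format_opt) {
  // asIntArrayRefSlow expects every SymInt to be backed by a real integer.
  return at::native::empty_cpu(
      c10::asIntArrayRefSlow(size),
      dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}
}} // namespace at::native
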
# We do not make new_empty a composite that calls into new_empty_strided, as the strided version
# is significantly more difficult to implement by different backends
- func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: method
dispatch:
CompositeExplicitAutograd: new_empty
autogen: new_empty.out
- func: new_empty.SymInt(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: method
dispatch:
CompositeExplicitAutograd: new_empty_symint
autogen: new_empty.SymInt_out
- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: method
dispatch:
@ -2145,7 +2170,7 @@
QuantizedCPU, QuantizedCUDA: empty_quantized
autogen: empty_quantized.out
- func: empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
- func: empty.out(int[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
device_guard: False
@ -2269,7 +2294,14 @@
SparseCPU, SparseCUDA: expm1_sparse_out
SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr_out
- func: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
- func: expand.SymInt(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: expand_symint
- func: expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
device_check: NoCheck
device_guard: False
@ -3665,7 +3697,7 @@
dispatch:
CompositeExplicitAutograd: mvlgamma_
- func: narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor
- func: narrow_copy(Tensor self, int dim, int start, int length) -> Tensor
variants: function, method
dispatch:
CPU: narrow_copy_dense_cpu
@ -3673,7 +3705,13 @@
CompositeExplicitAutogradNonFunctional: narrow_copy_dense
tags: view_copy
- func: narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)
- func: narrow_copy.SymInt(Tensor self, int dim, int start, SymInt length) -> Tensor
variants: function, method
dispatch:
CompositeExplicitAutograd: narrow_copy_symint
autogen: narrow_copy.SymInt_out
- func: narrow_copy.out(Tensor self, int dim, int start, int length, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: narrow_copy_dense_cpu_out
@ -5542,14 +5580,19 @@
CUDA: _efficientzerotensor_cuda
autogen: _efficientzerotensor.out
- func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
dispatch:
CompositeExplicitAutograd: zeros
- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
- func: zeros.SymInt(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
dispatch:
CompositeExplicitAutograd: zeros_symint
autogen: zeros.SymInt_out
- func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CompositeExplicitAutograd: zeros_out
SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out
SparseCPU, SparseCUDA, SparseMeta: zeros_out
- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
dispatch:
@ -6876,13 +6919,20 @@
CPU: masked_softmax_backward_cpu
autogen: _masked_softmax_backward.out
- func: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
- func: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a)
variants: method
device_check: NoCheck
device_guard: False
dispatch:
Meta: view_meta
ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
CompositeExplicitAutograd: view_symint
MkldnnCPU: mkldnn_view_symint
- func: view(Tensor(a) self, int[] size) -> Tensor(a)
variants: method
device_check: NoCheck
device_guard: False
dispatch:
ZeroTensor, CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, MPS: view
MkldnnCPU: mkldnn_view
# Warning: If you want to change the name or overload name of this
@ -12714,12 +12764,18 @@
CompositeExplicitAutogradNonFunctional: diagonal_copy
tags: view_copy
- func: expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
- func: expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor
variants: function
dispatch:
CompositeExplicitAutogradNonFunctional: expand_copy
tags: view_copy
- func: expand_copy.SymInt(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: expand_copy_SymInt
tags: view_copy
- func: permute_copy(Tensor self, int[] dims) -> Tensor
variants: function
dispatch:
@ -12848,7 +12904,7 @@
CompositeExplicitAutogradNonFunctional: unbind_copy_int
tags: view_copy
- func: view_copy(Tensor self, SymInt[] size) -> Tensor
- func: view_copy(Tensor self, int[] size) -> Tensor
variants: function
dispatch:
CompositeExplicitAutogradNonFunctional: view_copy
@ -12908,6 +12964,14 @@
CompositeExplicitAutograd: _neg_view_copy_out
- func: view_copy.SymInt(Tensor self, SymInt[] size) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: view_copy_SymInt
tags: view_copy
autogen: view_copy.SymInt_out
- func: as_strided_copy.out(Tensor self, int[] size, int[] stride, int? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
@ -12926,7 +12990,13 @@
CompositeExplicitAutograd: diagonal_copy_out
- func: expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
- func: expand_copy.SymInt_out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CompositeExplicitAutograd: expand_copy_SymInt_out
- func: expand_copy.out(Tensor self, int[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CompositeExplicitAutograd: expand_copy_out
@ -13046,7 +13116,7 @@
CompositeExplicitAutograd: unbind_copy_int_out
- func: view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
- func: view_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CompositeExplicitAutograd: view_copy_out

View File

@ -66,6 +66,16 @@ Tensor empty_per_channel_affine_quantized(
quantizer);
}
Tensor empty_symint_unknown_quantized(
c10::SymIntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory,
c10::optional<c10::MemoryFormat> optional_memory_format) {
return at::native::empty_unknown_quantized(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format);
}
Tensor empty_unknown_quantized(
IntArrayRef size,
c10::optional<ScalarType> dtype,

View File

@ -488,6 +488,16 @@ SPARSE_COMPRESSED_TENSOR(csc, kSparseCsc)
SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr)
SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc)
Tensor empty_symint_sparse_compressed(
c10::SymIntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory,
c10::optional<MemoryFormat> optional_memory_format) {
return at::native::empty_sparse_compressed(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format);
}
Tensor empty_sparse_compressed(
IntArrayRef size,
c10::optional<ScalarType> dtype,

View File

@ -207,6 +207,17 @@ Tensor empty_sparse(
size.size(), 0, size, dtype, layout, device, pin_memory);
}
/** Empty init **/
Tensor empty_symint_sparse(
c10::SymIntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory,
c10::optional<MemoryFormat> optional_memory_format) {
return at::native::empty_sparse(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format);
}
/* Shape init */
Tensor sparse_coo_tensor(IntArrayRef size,
c10::optional<ScalarType> dtype,

View File

@ -1524,7 +1524,7 @@ SparseTensor& _sspaddmm_out_cpu(
int64_t t_nnz = t._nnz();
int64_t r_nnz = nnz * dim_k + t_nnz;
Tensor newi = at::empty({2, r_nnz}, kLong);
Tensor newv = at::zeros(
Tensor newv = native::zeros(
{r_nnz},
optTypeMetaToScalarType(values.options().dtype_opt()),
values.options().layout_opt(),

View File

@ -141,6 +141,7 @@ full_codegen:
- upsample_nearest2d
- upsample_nearest2d_backward
- zero
- narrow_copy.SymInt
- alias_copy
- as_strided_copy
- diagonal_copy
@ -174,6 +175,7 @@ supported:
- _copy_from
- _copy_from_and_resize
- empty.memory_format
- empty.SymInt
- empty_strided
- fill_.Scalar
- normal_

View File

@ -29,13 +29,12 @@ Tensor _empty_affine_quantized(
}
Tensor empty_memory_format(
const SymIntArrayRef sym_sizes,
const IntArrayRef sizes,
const c10::optional<ScalarType> dtype,
const c10::optional<c10::Layout> layout,
const c10::optional<Device> device,
const c10::optional<bool> pin_memory,
const optional<MemoryFormat> memory_format) {
auto sizes = c10::asIntArrayRefSlow(sym_sizes);
return convert(vTensor{
api::context(),
sizes,
@ -56,12 +55,7 @@ Tensor empty_strided(
const optional<Device> device,
const optional<bool> pin_memory) {
return empty_memory_format(
c10::SymIntArrayRef::fromIntArrayRef(sizes),
dtype,
layout,
device,
pin_memory,
c10::MemoryFormat::Contiguous);
sizes, dtype, layout, device, pin_memory, c10::MemoryFormat::Contiguous);
}
#ifdef USE_VULKAN_API

View File

@ -42,8 +42,7 @@ Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) {
return convert(v_output);
}
inline Tensor view(const Tensor& self_arg, const SymIntArrayRef sym_shape) {
auto shape = c10::asIntArrayRefSlow(sym_shape);
inline Tensor view(const Tensor& self_arg, const IntArrayRef shape) {
return view_internal(self_arg, shape);
}

View File

@ -55,6 +55,8 @@ void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {
${CompositeViewCopyKernel_Definitions}
${SymIntViewCopyKernel_Definitions}
${GeneratedCompositeFunctional_Definitions}
${GeneratedCompositeOut_Definitions}

View File

@ -28,7 +28,7 @@ T getSampleValue();
template <>
at::Tensor getSampleValue() {
return at::zeros({2, 2}).to(at::kCPU);
return at::native::zeros({2, 2}).to(at::kCPU);
}
template <>

View File

@ -105,7 +105,7 @@ void assertOwn(
template<>
Tensor getSampleValue() {
return at::zeros({2, 2}).to(at::kCPU);
return at::native::zeros({2, 2}).to(at::kCPU);
}
template<>

View File

@ -15,7 +15,7 @@ using namespace at;
static int test_int;
Tensor empty_override(SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout,
Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout,
c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> optional_memory_format) {
test_int = 1;
auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
@ -44,7 +44,7 @@ Tensor empty_strided_override(
c10::optional<c10::Device> device,
c10::optional<bool> pin_memory) {
return empty_override(SymIntArrayRef::fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt);
return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt);
}
TORCH_LIBRARY_IMPL(aten, ORT, m) {

View File

@ -119,7 +119,7 @@ TEST(MathKernelTest, NarrowCopy) {
for (const auto dim : c10::irange(3)) {
const int64_t start = 1, length = 4;
auto y_ref = x.narrow(dim, start, length);
auto y_test = at::native::narrow_copy_dense(x, dim, c10::SymInt(start), c10::SymInt(length));
auto y_test = at::native::narrow_copy_dense(x, dim, start, length);
ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0);
}
}

View File

@ -9,25 +9,6 @@
namespace at { namespace functorch {
template <typename A, A a, typename C>
struct NewBlahBatchRuleHelperSymInt;
template <typename F, F Func, typename A, typename B, typename... T>
struct NewBlahBatchRuleHelperSymInt<F, Func, typelist<A, B, T...>> {
static std::tuple<Tensor,optional<int64_t>> apply(
const Tensor& tensor,
optional<int64_t> batch_dim,
SymIntArrayRef shape,
T... extra_args) {
const auto bdim_size = tensor.sym_size(batch_dim.value());
c10::SmallVector<c10::SymInt> new_shape;
new_shape.reserve(shape.size() + 1);
new_shape.emplace_back(bdim_size);
new_shape.insert(new_shape.end(), shape.begin(), shape.end());
return std::make_tuple(Func(tensor, new_shape, std::forward<T>(extra_args)...), 0);
}
};
template <typename A, A a, typename C>
struct NewBlahBatchRuleHelper;
@ -56,12 +37,6 @@ struct NewBlahBatchRuleHelper<F, Func, typelist<A, B, T...>> {
&fn,\
c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
#define NEW_BLAH_BATCH_RULE_SYMINT(fn) SINGLE_ARG(\
NewBlahBatchRuleHelperSymInt<\
decltype(&fn),\
&fn,\
c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
std::tuple<Tensor,optional<int64_t>> _new_zeros_with_same_feature_meta_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
const Tensor& other, optional<int64_t> other_bdim,
@ -107,6 +82,17 @@ bool _has_same_storage_numel_batch_rule(const Tensor& a, const Tensor& b) {
return true;
}
Tensor new_empty_symint_decomp(
const Tensor& self,
SymIntArrayRef size,
c10::optional<ScalarType> dtype_opt,
c10::optional<Layout> layout_opt,
c10::optional<Device> device_opt,
c10::optional<bool> pin_memory_opt
) {
return self.new_empty(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt);
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
m.impl("_has_same_storage_numel", _has_same_storage_numel_batch_rule);
VMAP_SUPPORT(ones_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(ones_like)));
@ -115,7 +101,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT(randn_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(randn_like)));
VMAP_SUPPORT(rand_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(rand_like)));
VMAP_SUPPORT(full_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(full_like)));
VMAP_SUPPORT(new_empty, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_empty)));
VMAP_SUPPORT(new_empty, NEW_BLAH_BATCH_RULE(ATEN_FN(new_empty)));
m.impl("new_empty.SymInt", new_empty_symint_decomp);
VMAP_SUPPORT(new_zeros, NEW_BLAH_BATCH_RULE(ATEN_FN(new_zeros)));
VMAP_SUPPORT(new_ones, NEW_BLAH_BATCH_RULE(ATEN_FN(new_ones)));
VMAP_SUPPORT(new_full, NEW_BLAH_BATCH_RULE(ATEN_FN(new_full)));

View File

@ -427,15 +427,15 @@ std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
}
std::tuple<Tensor, optional<int64_t>> view_batching_rule(
const Tensor &self, optional<int64_t> self_bdim, SymIntArrayRef sym_size)
const Tensor &self, optional<int64_t> self_bdim, IntArrayRef size)
{
TORCH_INTERNAL_ASSERT(self_bdim.has_value());
auto self_ = moveBatchDimToFront(self, self_bdim);
c10::SmallVector<c10::SymInt> size_(sym_size.size() + 1);
VmapDimVector size_(size.size() + 1);
// copy batch size
size_[0] = self_.size(0);
std::copy(sym_size.cbegin(), sym_size.cend(), size_.begin() + 1);
return std::make_tuple(self_.view_symint(size_), 0);
std::copy(size.cbegin(), size.cend(), size_.begin() + 1);
return std::make_tuple(self_.view(size_), 0);
}
Tensor view_symint_decomposition(const Tensor& self,
@ -446,7 +446,7 @@ Tensor view_symint_decomposition(const Tensor& self,
template <typename F, F Func>
std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
const Tensor &self, optional<int64_t> self_bdim, SymIntArrayRef size, bool implicit)
const Tensor &self, optional<int64_t> self_bdim, IntArrayRef size, bool implicit)
{
auto self_dim = self.dim();
TORCH_CHECK(static_cast<uint64_t>(self_dim - 1) <= size.size(),
@ -457,7 +457,7 @@ std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
auto self_sizes = self_.sizes();
auto batch_size = self_sizes[0];
c10::SmallVector<c10::SymInt> size_(size.size() + 1);
c10::SmallBuffer<int64_t, 5> size_(size.size() + 1);
size_[0] = batch_size;
std::copy(size.cbegin(), size.cend(), size_.begin() + 1);
@ -471,12 +471,12 @@ std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
// so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
// then expand.
auto extra_dims = size.size() - (self_dim - 1);
c10::SmallVector<c10::SymInt> view_shape(size_.size(), /*init_value*/1);
VmapDimVector view_shape(size_.size(), /*init_value*/1);
view_shape[0] = batch_size;
std::copy(self_sizes.cbegin() + 1, self_sizes.cend(),
view_shape.begin() + 1 + extra_dims);
return std::make_tuple(Func(self_.view_symint(view_shape), size_, implicit), 0);
return std::make_tuple(Func(self_.view(view_shape), size_, implicit), 0);
}
std::tuple<Tensor, optional<int64_t>> unfold_batch_rule(
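
The comment in this hunk ("view it first as a tensor of size [B0, 1, 3] and then expand") is the heart of the expand batching rule: the batch dimension is moved to the front, a singleton dimension is inserted for each new leading size, and the expand is applied to the physical tensor. A minimal illustration with concrete shapes (values chosen for the example, not taken from the diff):

// Per-example tensor of shape [3], vmapped over a batch of B0 = 7, is stored
// physically as [7, 3]. Expanding each example to (2, 3) is done by viewing
// the physical tensor as [7, 1, 3] and then expanding to [7, 2, 3].
#include <ATen/ATen.h>

at::Tensor batched_expand_example() {
  at::Tensor physical = at::randn({7, 3});       // [B0, 3]
  at::Tensor viewed = physical.view({7, 1, 3});  // [B0, 1, 3]
  return viewed.expand({7, 2, 3});               // [B0, 2, 3]
}
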
@ -549,6 +549,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT2(slice, Tensor, slice_batch_rule);
VMAP_SUPPORT2(transpose, int, transpose_int_batch_rule);
VMAP_SUPPORT(diag_embed, diag_embed_batch_rule);
m.impl("expand.SymInt", expand_symint_decomp_hack);
m.impl("view.SymInt", view_symint_decomposition);
}
}}

View File

@ -84,6 +84,35 @@ static inline at::DeviceType DefaultDevice() {
} // namespace
#ifndef C10_MOBILE
TEST(LazyDynamicOpsTest, NarrowCopy) {
auto x = torch::rand({5, 10, 10}).to(kLazy);
const size_t Y_DIM = 3;
const size_t X_DIM_INDEX = 2;
auto y = torch::rand({Y_DIM}).to(kLazy);
auto ly = torch::lazy::TryGetLtcTensor(y);
auto dim_node = MakeNode<SizeNode>(ly->GetIrValue(), 0);
auto lmn = c10::make_intrusive<torch::lazy::SymIntNodeImpl>(dim_node);
auto z = x.narrow_copy_symint(X_DIM_INDEX, 0, lmn->toSymInt());
AllClose(z.cpu(), x.cpu().narrow_copy(X_DIM_INDEX, 0, Y_DIM));
}
TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
FLAGS_ltc_enable_symbolic_shapes = true;
auto xc = torch::rand({10});
auto x = xc.to(kLazy);
const size_t Y_DIM = 3;
const size_t X_DIM_INDEX = 0;
auto y = torch::rand({Y_DIM}).to(kLazy);
auto z = x.narrow_copy_symint(X_DIM_INDEX, 0, y.sym_sizes()[0]);
auto zc = xc.narrow_copy(X_DIM_INDEX, 0, Y_DIM);
ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc
// shape inference assumes narrow_copy can copy the whole tensor
AllClose(z.cpu(), zc);
FLAGS_ltc_enable_symbolic_shapes = false;
}
#endif
TEST_F(LazyOpsTest, TestScalarTensor) {
torch::Tensor scalar_tensor = torch::scalar_tensor(
1., torch::TensorOptions(torch::kFloat).device(DefaultDevice()));

View File

@ -85,7 +85,8 @@ at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("add.Tensor", &custom_add_Tensor);
m.impl("empty.memory_format", &custom_empty_symint);
m.impl("empty.memory_format", &custom_empty_memory_format);
m.impl("empty.SymInt", &custom_empty_symint);
m.impl("fill_.Scalar", &custom_fill__scalar);
m.impl("_copy_from", &custom__copy_from);
}

View File

@ -20,10 +20,15 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
return Tensor(std::move(tensor_impl));
}
Tensor empty_override(SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {
test_int = 0;
return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), c10::asIntArrayRefSlow(size));
return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size);
}
Tensor empty_symint_override(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
return empty_override(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}
Tensor& add_out_override(const Tensor & a, const Tensor & b , const Scalar& c, Tensor & out) {
@ -53,6 +58,7 @@ std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
}
TORCH_LIBRARY_IMPL(aten, ORT, m) {
m.impl("empty.SymInt", empty_symint_override);
m.impl("empty.memory_format", empty_override);
m.impl("add.out", add_out_override);
m.impl("convolution_overrideable", fake_convolution);

View File

@ -131,18 +131,6 @@ ALLOW_LIST = [
("aten::sum.SymInt", datetime.date(2022, 11, 30)),
("aten::mps_linear", datetime.date(9999, 1, 1)),
("aten::_mps_linear", datetime.date(9999, 1, 1)),
("aten::view_copy.SymInt", datetime.date(2022, 11, 30)),
("aten::view_copy.SymInt_out", datetime.date(2022, 11, 30)),
("aten::expand_copy.SymInt", datetime.date(2022, 11, 30)),
("aten::expand_copy.SymInt_out", datetime.date(2022, 11, 30)),
("aten::expand.SymInt", datetime.date(2022, 11, 30)),
("aten::narrow_copy.SymInt", datetime.date(2022, 11, 30)),
("aten::narrow_copy.SymInt_out", datetime.date(2022, 11, 30)),
("aten::view.SymInt", datetime.date(2022, 11, 30)),
("aten::new_empty.SymInt", datetime.date(2022, 11, 30)),
("aten::new_empty.SymInt_out", datetime.date(2022, 11, 30)),
("aten::zeros.SymInt", datetime.date(2022, 11, 30)),
("aten::zeros.SymInt_out", datetime.date(2022, 11, 30)),
# TODO: FIXME: prims shouldn't be checked
("prims::.*", datetime.date(9999, 1, 1)),
("aten::_amp_foreach_non_finite_check_and_unscale.out", datetime.date(2022, 9, 1)),

View File

@ -391,7 +391,8 @@ class TestDecomp(TestCase):
if func not in decomposition_table or func in [
torch.ops.aten.detach.default,
# non-deterministic ops
torch.ops.aten.new_empty.default
torch.ops.aten.new_empty.default,
torch.ops.aten.new_empty.SymInt
] or any_unsupported(args, kwargs):
return func(*args, **kwargs)

View File

@ -54,7 +54,7 @@ def cat_meta(tensors, dim=0):
return tensors[0].new_empty(new_shape)
@register_meta([aten.narrow_copy.default])
@register_meta([aten.narrow_copy.SymInt])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
shape = []
for i, x in enumerate(a.shape):
@ -65,7 +65,7 @@ def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
return a.new_empty(tuple(shape))
@register_meta([aten.expand.default])
@register_meta([aten.expand.SymInt])
def expand_symint_meta(a, size, implicit=False):
return a.new_empty(size)
@ -293,11 +293,11 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.shape[0])
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
torch.ops.aten.expand.SymInt(x, [x.shape[0], x.shape[1], x.shape[2]])
def test_fx_trace_intlist(self):
class CustomModule(torch.nn.Module):

View File

@ -727,6 +727,7 @@ meta_dispatch_skips = {
aten.linalg_pinv.atol_rtol_tensor: {f32, f64},
aten.linalg_pinv.atol_rtol_tensor_out: {f32, f64},
aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
aten.empty.SymInt: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
}
meta_dispatch_device_expected_failures = defaultdict(dict)

View File

@ -15949,7 +15949,6 @@ class TestNNDeviceType(NNTestCase):
torch.cuda.synchronize()
issue_24823_2()
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/945")
@dtypes(torch.float, torch.double)
@largeTensorTest(lambda self, device, dtype:
# Compute sum of the large tensor sizes:

View File

@ -317,21 +317,24 @@ class TestProfilerTree(TestCase):
ProfilerTree.format(p.profiler, 12),
"""\
aten::zeros
aten::empty
aten::zero_
Top level Annotation
aten::empty
aten::zeros
aten::empty
aten::zero_
Top level Annotation
aten::empty
aten::zeros
aten::zeros
aten::empty
aten::zero_
First Annotation
aten::empty
aten::ones
aten::empty
aten::fill_
aten::zeros
aten::empty
aten::zero_
aten::zeros
aten::empty
aten::zero_
Second Annotation
aten::empty
aten::add
@ -340,8 +343,9 @@ class TestProfilerTree(TestCase):
aten::empty_strided
aten::copy_
aten::zeros
aten::empty
aten::zero_
aten::zeros
aten::empty
aten::zero_
Third Annotation
aten::empty
aten::ones_like
@ -712,7 +716,6 @@ class TestProfilerTree(TestCase):
torch/profiler/profiler.py(...): stop
...""")
@unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
@ProfilerTree.test
def test_profiler_experimental_tree_cuda(self):
@ -810,7 +813,6 @@ class TestProfilerTree(TestCase):
allow_failure=ALLOW_CUDA_FAILURE,
)
@unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
@ProfilerTree.test
def test_profiler_experimental_tree_cuda_with_stream(self):

View File

@ -762,7 +762,7 @@ class TestSymbolicTracing(TestCase):
def forward(self, a_1):
sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
mul = sym_size * 2; sym_size = None
empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
empty = torch.ops.aten.empty.SymInt([mul], device = device(type='cpu'), pin_memory = False); mul = None
sym_size_1 = torch.ops.aten.sym_size(empty, 0)
return empty""")

View File

@ -622,9 +622,12 @@
self: grad * (result + 1)
result: auto_element_wise
# TODO: this derivative is not SymInt safe, need sum_to support
- name: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
self: at::sum_to(grad, self.sym_sizes())
- name: expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)
self: at::sum_to(grad, self.sizes())
result: auto_linear
- name: expand.SymInt(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
self: at::sum_to(grad, c10::asIntArrayRefSlow(self.sym_sizes()))
result: auto_linear
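
The `expand` derivatives above rely on the fact that the backward of broadcasting is a reduction: `at::sum_to` sums the incoming gradient over the expanded dimensions until it matches the input's shape (for the SymInt overload, after materializing the symbolic sizes with `asIntArrayRefSlow`). A small self-contained illustration (shapes are hypothetical, not from the diff):

// sum_to reverses broadcasting: a [3, 1] input expanded to [3, 4] receives
// a [3, 1] gradient whose entries are the row sums of the incoming gradient.
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>

at::Tensor expand_backward_example() {
  at::Tensor self = at::randn({3, 1});
  at::Tensor out = self.expand({3, 4});   // forward: broadcast to [3, 4]
  at::Tensor grad = at::ones_like(out);   // incoming gradient, shape [3, 4]
  return at::sum_to(grad, self.sizes());  // shape [3, 1], each entry == 4.0
}
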
- name: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
@ -1732,11 +1735,16 @@
# linear
result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim)
# TODO: this derivative is not SymInt safe, need reshape_symint
- name: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
- name: view(Tensor(a) self, int[] size) -> Tensor(a)
self: grad.reshape(self.sizes())
result: auto_linear
- name: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a)
# TODO: add proper double backward for view.SymInt
# by SymIntizing `reshape`
self: grad.reshape(c10::asIntArrayRefSlow(self.sym_sizes()))
result: auto_linear
- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
output_differentiability: [False]

View File

@ -249,7 +249,6 @@ def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]:
for r in cpp.argument(
a,
method=False,
symint=True,
cpp_no_default_args=set(),
faithful=False,
has_tensor_options=False,
@ -495,7 +494,7 @@ def gen_formals(f: NativeFunction) -> str:
# See Note [Plumbing Keys Through The Dispatcher] for details.
["c10::DispatchKeySet ks"]
+ [
f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}'
f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
for a in f.func.schema_order_arguments()
]
)
@ -515,7 +514,7 @@ def inplace_or_view_method_definition(
):
return None
return METHOD_DEFINITION.substitute(
return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(),
return_type=cpp.returns_type(f.func.returns).cpp_type(),
type_wrapper_name=type_wrapper_name(f),
formals=gen_formals(f),
type_definition_body=emit_inplace_or_view_body(fn),

View File

@ -238,8 +238,6 @@ def gen(
tags_yaml_path: str,
deprecated_yaml_path: str,
template_path: str,
*,
symint: bool = True,
) -> None:
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
native_functions = parse_native_yaml(
@ -255,7 +253,6 @@ def gen(
None,
"python_variable_methods.cpp",
method=True,
symint=symint,
)
# NOTE: num_shards here must be synced with gatherTorchFunctions in
@ -269,7 +266,6 @@ def gen(
"python_torch_functions.cpp",
method=False,
num_shards=3,
symint=symint,
)
create_python_bindings(
@ -279,7 +275,6 @@ def gen(
"torch.nn",
"python_nn_functions.cpp",
method=False,
symint=symint,
)
create_python_bindings(
@ -289,7 +284,6 @@ def gen(
"torch.fft",
"python_fft_functions.cpp",
method=False,
symint=symint,
)
create_python_bindings(
@ -299,7 +293,6 @@ def gen(
"torch.linalg",
"python_linalg_functions.cpp",
method=False,
symint=symint,
)
create_python_bindings(
@ -309,7 +302,6 @@ def gen(
"torch.sparse",
"python_sparse_functions.cpp",
method=False,
symint=symint,
)
create_python_bindings(
@ -319,7 +311,6 @@ def gen(
"torch.special",
"python_special_functions.cpp",
method=False,
symint=symint,
)
# Currently, we only use `functions` to generate `return_types` bindings.
@ -363,7 +354,6 @@ def create_python_bindings(
filename: str,
*,
method: bool,
symint: bool = True,
) -> None:
"""Generates Python bindings to ATen functions"""
py_methods: List[str] = []
@ -375,9 +365,7 @@ def create_python_bindings(
for name in sorted(grouped.keys(), key=lambda x: str(x)):
overloads = grouped[name]
py_methods.append(
method_impl(name, module, overloads, method=method, symint=symint)
)
py_methods.append(method_impl(name, module, overloads, method=method))
py_method_defs.append(method_def(name, module, overloads, method=method))
py_forwards.extend(forward_decls(name, overloads, method=method))
ops_headers.append(f"#include <ATen/ops/{name.base}.h>")
@ -440,7 +428,6 @@ def create_python_bindings_sharded(
*,
method: bool,
num_shards: int,
symint: bool = True,
) -> None:
"""Generates Python bindings to ATen functions"""
grouped = group_filter_overloads(pairs, pred)
@ -457,9 +444,7 @@ def create_python_bindings_sharded(
return {
"ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
"py_forwards": list(forward_decls(name, fn_pairs, method=method)),
"py_methods": [
method_impl(name, module, fn_pairs, method=method, symint=symint)
],
"py_methods": [method_impl(name, module, fn_pairs, method=method)],
"py_method_defs": [method_def(name, module, fn_pairs, method=method)],
}
@ -788,7 +773,6 @@ def method_impl(
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
symint: bool = True,
) -> str:
"""
Generate a python binding for all overloads of an op.
@ -807,18 +791,14 @@ def method_impl(
traceable = "true" if all(should_trace(o.function) for o in overloads) else "false"
grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(
overloads, symint=symint
)
grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads)
is_singleton = len(grouped_overloads) == 1
signatures: List[str] = []
dispatch: List[str] = []
for overload_index, overload in enumerate(grouped_overloads):
signature = overload.signature.signature_str(symint=symint)
signature = overload.signature.signature_str()
signatures.append(f"{cpp_string(str(signature))},")
dispatch_body = emit_dispatch_case(
overload, namedtuple_typenames, symint=symint
)
dispatch_body = emit_dispatch_case(overload, namedtuple_typenames)
dispatch.append(
PY_VARIABLE_CASE.substitute(
overload_index=overload_index, body=dispatch_body
@ -902,8 +882,6 @@ if (_r.isNone(${out_idx})) {
def emit_dispatch_case(
overload: PythonSignatureGroup,
namedtuple_typenames: Dict[str, str],
*,
symint: bool = True,
) -> str:
"""
Emit dispatch code for a single parsed signature. This corresponds to either
@ -916,19 +894,18 @@ def emit_dispatch_case(
return PY_VARIABLE_OUT.substitute(
out_idx=overload.signature.output_idx(),
call_dispatch=emit_single_dispatch(
overload.signature, overload.base, namedtuple_typenames, symint=symint
overload.signature, overload.base, namedtuple_typenames
),
call_dispatch_out=emit_single_dispatch(
overload.signature,
overload.outplace,
namedtuple_typenames,
symint=symint,
),
)
else:
# no-output version only
return emit_single_dispatch(
overload.signature, overload.base, namedtuple_typenames, symint=symint
overload.signature, overload.base, namedtuple_typenames
)
@ -1010,14 +987,14 @@ def method_def(
def group_overloads(
overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Sequence[PythonSignatureGroup]:
bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
# first group by signature ignoring out arguments
for overload in overloads:
sig = overload.signature.signature_str(skip_outputs=True, symint=symint)
sig = overload.signature.signature_str(skip_outputs=True)
if overload.function.func.is_out_fn():
if sig in outplaces:
raise RuntimeError(
@ -1044,11 +1021,9 @@ def group_overloads(
and not overload.signature.deprecated
):
candidates.append(
overload.signature.signature_str(
skip_outputs=True, symint=symint
)
overload.signature.signature_str(skip_outputs=True)
)
out_sig = out.signature.signature_str(symint=symint)
out_sig = out.signature.signature_str()
raise RuntimeError(
f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
@ -1063,7 +1038,7 @@ def group_overloads(
)
for sig, base in bases.items()
]
return sort_overloads(grouped, symint=symint)
return sort_overloads(grouped)
# This function declares a partial order on declarations, and sorts them according
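
A much-reduced Python sketch of the grouping step above, keyed on the signature string with output arguments stripped so that an out= overload is paired with its functional base. The Overload and strip_out names are illustrative stand-ins, not the real torchgen types:

from collections import defaultdict
from typing import Dict, List, NamedTuple

class Overload(NamedTuple):
    signature: str   # schema-style signature, out arguments included
    is_out: bool

def strip_out(sig: str) -> str:
    # stand-in for signature.signature_str(skip_outputs=True)
    return sig.replace(", *, Tensor out", "")

def group(overloads: List[Overload]) -> Dict[str, List[Overload]]:
    grouped: Dict[str, List[Overload]] = defaultdict(list)
    for o in overloads:
        grouped[strip_out(o.signature)].append(o)
    return grouped

pairs = group([
    Overload("add(Tensor self, Tensor other)", False),
    Overload("add(Tensor self, Tensor other, *, Tensor out)", True),
])
assert len(pairs) == 1   # the base and its out= variant land in the same group
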
@ -1112,7 +1087,7 @@ def group_overloads(
def sort_overloads(
grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True
grouped_overloads: Sequence[PythonSignatureGroup],
) -> Sequence[PythonSignatureGroup]:
# NB: Smaller here means lower priority
@ -1157,7 +1132,7 @@ def sort_overloads(
# First sort by signature
grouped_overloads = sorted(
grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint)
grouped_overloads, key=lambda x: x.signature.signature_str()
)
# Construct the relation graph
@ -1195,11 +1170,7 @@ def sort_overloads(
def emit_single_dispatch(
ps: PythonSignature,
f: NativeFunction,
namedtuple_typenames: Dict[str, str],
*,
symint: bool = True,
ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str]
) -> str:
"""
Emit dispatch code for a single native function.
@ -1218,10 +1189,7 @@ def emit_single_dispatch(
# dispatch lambda signature
name = cpp.name(f.func)
lambda_formals = ", ".join(
map(
lambda a: f"{a.type_str} {a.name}",
dispatch_lambda_args(ps, f, symint=symint),
)
map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f))
)
lambda_return = dispatch_lambda_return_str(f)
@ -1230,8 +1198,8 @@ def emit_single_dispatch(
dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))
# from arg parser outputs to dispatch lambda arguments
parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint)
parser_outputs = arg_parser_output_exprs(ps, f)
lambda_arg_exprs = dispatch_lambda_exprs(ps, f)
inits = "\n".join(lambda_arg_exprs.inits)
lambda_args = ", ".join(lambda_arg_exprs.exprs)

View File

@ -383,7 +383,7 @@ def declare_returned_variables(f: NativeFunction) -> str:
return ""
if len(f.func.returns) == 1:
return ""
types = [cpp.return_type(r, symint=True) for r in f.func.returns]
types = map(cpp.return_type, f.func.returns)
names = cpp.return_names(f)
return "\n".join(f"{type.cpp_type()} {name};" for type, name in zip(types, names))
@ -483,13 +483,13 @@ def method_definition(f: NativeFunction) -> str:
# See Note [Plumbing Keys Through The Dispatcher] for details.
["c10::DispatchKeySet ks"]
+ [
f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}'
f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
for a in f.func.schema_order_arguments()
]
)
return METHOD_DEFINITION.substitute(
return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(),
return_type=cpp.returns_type(f.func.returns).cpp_type(),
type_wrapper_name=type_wrapper_name(f),
formals=formals,
type_definition_body=emit_trace_body(f),

View File

@ -76,39 +76,31 @@ def process_function(f: NativeFunction) -> Optional[str]:
if Variant.function not in f.variants or not is_factory:
return None
cpp_sigs = CppSignatureGroup.from_native_function(f, method=False)
sigs = [cpp_sigs.signature]
if cpp_sigs.symint_signature is not None:
sigs.append(cpp_sigs.symint_signature)
r = ""
for sig in sigs:
formals: List[str] = []
exprs: List[str] = []
requires_grad = "false"
for arg in sig.arguments():
qualified_type = fully_qualified_type(arg.type)
if arg.default:
formals.append(f"{qualified_type} {arg.name} = {arg.default}")
else:
formals.append(f"{qualified_type} {arg.name}")
sig = CppSignatureGroup.from_native_function(f, method=False).signature
formals: List[str] = []
exprs: List[str] = []
requires_grad = "false"
for arg in sig.arguments():
qualified_type = fully_qualified_type(arg.type)
if arg.default:
formals.append(f"{qualified_type} {arg.name} = {arg.default}")
else:
formals.append(f"{qualified_type} {arg.name}")
if isinstance(arg.argument, TensorOptionsArguments):
# note: we remove the requires_grad setting from the TensorOptions because
# it is ignored anyways (and we actually have an assertion that it isn't set
# which would fail otherwise). We handle requires_grad explicitly here
# instead of passing it through to the kernel.
exprs.append(
f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)"
)
# Manually set the requires_grad bit on the result tensor.
requires_grad = f"{arg.name}.requires_grad()"
else:
exprs.append(arg.name)
if isinstance(arg.argument, TensorOptionsArguments):
# note: we remove the requires_grad setting from the TensorOptions because
# it is ignored anyways (and we actually have an assertion that it isn't set
# which would fail otherwise). We handle requires_grad explicitly here
# instead of passing it through to the kernel.
exprs.append(f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)")
# Manually set the requires_grad bit on the result tensor.
requires_grad = f"{arg.name}.requires_grad()"
else:
exprs.append(arg.name)
r += f"""\
inline at::Tensor {sig.name()}({', '.join(formals)}) {{
return f"""\
inline at::Tensor {name}({', '.join(formals)}) {{
at::AutoDispatchBelowADInplaceOrView guard;
return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
return autograd::make_variable(at::{name}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
}}
"""
return r
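
The comment above notes that requires_grad is stripped from the TensorOptions passed to the kernel and set on the returned variable instead. The user-visible effect, in standard PyTorch:

import torch

t = torch.ones(2, 3, requires_grad=True)
# the flag is applied to the result after the factory kernel runs,
# so the returned tensor is a leaf that requires grad
assert t.requires_grad and t.is_leaf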

View File

@ -849,9 +849,7 @@ def gen_variable_type_func(
if not fn.info:
key = "Default"
type_definition = METHOD_DEFINITION.substitute(
return_type=cpp.returns_type(
f.func.returns, symint=True
).cpp_type(),
return_type=cpp.returns_type(f.func.returns).cpp_type(),
type_wrapper_name=type_wrapper_name(f, key),
type_definition_body=emit_body(fn, key),
formals=formals,
@ -862,9 +860,7 @@ def gen_variable_type_func(
else:
for key, _ in fn.info.items():
type_definition = METHOD_DEFINITION.substitute(
return_type=cpp.returns_type(
f.func.returns, symint=True
).cpp_type(),
return_type=cpp.returns_type(f.func.returns).cpp_type(),
type_wrapper_name=type_wrapper_name(f, key),
type_definition_body=emit_body(fn, key),
formals=formals,
@ -917,7 +913,7 @@ def emit_body(
# TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove.
# NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are
# not handled properly as they are irrelevant for this codegen.
cpp_type = cpp.argument_type(a, binds=a.name, symint=True).cpp_type()
cpp_type = cpp.argument_type(a, binds=a.name).cpp_type()
if not is_differentiable(a.name, a.type, info):
return None
@ -1208,7 +1204,6 @@ def emit_body(
api_name=cpp.name(
f.func,
faithful_name_for_out_overloads=True,
symint_overload=f.func.has_symint(),
),
unpacked_args=[dispatch_key_set] + list(unpacked_args),
)
@ -1290,7 +1285,7 @@ def emit_body(
for i, (ret, ret_name) in enumerate(
zip(f.func.returns, cpp.return_names(f))
):
noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref()
noref_cpp_type = cpp.return_type(ret).remove_const_ref()
if noref_cpp_type == BaseCType(tensorT):
if aliased_arg_name is not None:
assert (

View File

@ -184,9 +184,7 @@ def create_derivative(
]
return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))
return_types = tuple(
cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns
)
return_types = tuple(cpp.return_type(r).remove_const_ref() for r in f.func.returns)
named_returns = [
NamedCType(name, type) for name, type in zip(return_names, return_types)
@ -377,7 +375,7 @@ def postprocess_forward_derivatives(
new_args.append(arg_name)
# TODO we are trolling
if f.func.has_symint():
if f.func.is_symint_fn():
defn_name += "_symint"
# Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.

View File

@ -314,7 +314,6 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
dispatch_key=k,
use_out_as_primary=True,
external=False,
symint=False,
device_guard=False,
index=backend_indices[k],
)

View File

@ -612,7 +612,7 @@ class FakeTensorMode(TorchDispatchMode):
torch._C._remove_meta_from_tls_dispatch_include()
if has_symbolic_sizes:
constructors = [aten.empty.memory_format]
constructors = [aten.empty.SymInt]
if func not in constructors:
raise RuntimeError(
f"{func} - couldn't find symbolic meta function/decomposition"

View File

@ -843,7 +843,8 @@ RegisterOperators reg_expand_copy({
"alias ops, should be restored after fusion pass!");
IValue self, size, implicit;
pop(stack, self, size, implicit);
push(stack, self.toTensor().expand(size.toIntVector()));
push(
stack, at::native::expand(self.toTensor(), size.toIntVector()));
};
},
aliasAnalysisFromSchema()),

View File

@ -411,8 +411,7 @@ SchemaTypeParser::parseFakeAndRealType() {
real_value = parseBaseType();
if (real_value->kind() == ScalarTypeType::Kind ||
real_value->kind() == MemoryFormatType::Kind ||
real_value->kind() == LayoutType::Kind ||
real_value->kind() == SymIntType::Kind) {
real_value->kind() == LayoutType::Kind) {
fake_value = c10::TypeFactory::get<IntType>();
} else {
fake_value = real_value;

View File

@ -80,57 +80,31 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
auto c_obj = py::cast<std::complex<double>>(obj.ptr());
return static_cast<c10::complex<double>>(c_obj);
}
case TypeKind::IntType:
// TODO: Properly fake this type
if (THPQScheme_Check(obj.ptr())) {
auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
return static_cast<uint8_t>(qscheme->qscheme);
}
// For backwards compatibility
if (THPDtype_Check(obj.ptr())) {
auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
return static_cast<int64_t>(dtype->scalar_type);
}
if (THPQScheme_Check(obj.ptr())) {
auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
return static_cast<uint8_t>(qscheme->qscheme);
}
if (THPLayout_Check(obj.ptr())) {
auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
return static_cast<int8_t>(layout->layout);
}
if (THPMemoryFormat_Check(obj.ptr())) {
auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
return static_cast<int8_t>(memory_format->memory_format);
}
return py::cast<int64_t>(obj);
case TypeKind::LayoutType: {
if (THPLayout_Check(obj.ptr())) {
auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
return static_cast<int8_t>(layout->layout);
}
// For backwards compatibility
return py::cast<int64_t>(obj);
}
case TypeKind::ScalarTypeType: {
if (THPDtype_Check(obj.ptr())) {
auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
return static_cast<int64_t>(dtype->scalar_type);
}
// For backwards compatibility
return py::cast<int64_t>(obj);
}
case TypeKind::MemoryFormatType: {
if (THPMemoryFormat_Check(obj.ptr())) {
auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
return static_cast<int8_t>(memory_format->memory_format);
}
// For backwards compatibility
return py::cast<int64_t>(obj);
}
case TypeKind::SymIntType:
if (torch::is_symint_node(obj.ptr())) {
return py::cast<c10::SymInt>(obj);
return py::cast<c10::SymInt>(obj);
case TypeKind::IntType:
// NB: Typically, these switches are completely dead, because
// Argument::type() will always report IntType for these types.
// So this is a bit overly permissive: we'll accept a dtype
// passed to an int argument, for example.
case TypeKind::LayoutType:
case TypeKind::ScalarTypeType:
case TypeKind::MemoryFormatType:
if (THPDtype_Check(obj.ptr())) {
auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
return static_cast<int64_t>(dtype->scalar_type);
}
if (THPQScheme_Check(obj.ptr())) {
auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
return static_cast<uint8_t>(qscheme->qscheme);
}
if (THPLayout_Check(obj.ptr())) {
auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
return static_cast<int8_t>(layout->layout);
}
if (THPMemoryFormat_Check(obj.ptr())) {
auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
return static_cast<int8_t>(memory_format->memory_format);
}
return py::cast<int64_t>(obj);
case TypeKind::NoneType:

View File

@ -646,7 +646,7 @@ inline IValue argumentToIValue(
py::handle object) {
const auto& argument = schema.arguments().at(argumentPosition);
try {
return toIValue(object, argument.real_type(), argument.N());
return toIValue(object, argument.type(), argument.N());
} catch (const py::cast_error& error) {
throw schema_match_error(c10::str(
schema.formatTypeMismatchMsg(

View File

@ -39,8 +39,6 @@
#include <mutex>
#include <unordered_map>
#include <ATen/CompositeExplicitAutogradFunctions.h>
C10_DEFINE_bool(
static_runtime_enable_fast_math,
true,
@ -77,7 +75,7 @@ void repeat_out(at::Tensor& result, const Tensor& self, IntArrayRef repeats) {
return;
}
Tensor xtensor = at::compositeexplicitautograd::expand(self, padded_size);
Tensor xtensor = at::native::expand(self, padded_size);
Tensor urtensor = at::native::alias(result);
for (const auto i : c10::irange(xtensor.dim())) {
// can't unfold with step 0, so make sure step is at least 1
@ -2528,13 +2526,12 @@ REGISTER_OPERATOR_FUNCTOR(aten::zeros, aten_zeros, [](Node* n) -> SROperator {
const auto dtype = p_node->Input(1).toOptional<c10::ScalarType>();
const auto layout = p_node->Input(2).toOptional<c10::Layout>();
if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) {
p_node->Output(0) = at::compositeexplicitautograd::zeros(
size, dtype, layout, c10::nullopt, c10::nullopt);
p_node->Output(0) = at::native::zeros(size, dtype, layout);
return;
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::compositeexplicitautograd::zeros_out(out_t, size);
at::native::zeros_out(size, out_t);
};
});

View File

@ -5,7 +5,6 @@
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/trie.h>
#include <vector>
@ -267,21 +266,5 @@ static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {
return getIrBuilder()->MakeSizeDiv(a, b);
}
inline Value GetSymIntValue(c10::SymInt a) {
return Value(
dynamic_cast<torch::lazy::SymIntNodeImpl*>(a.toSymIntNodeImpl().get())
->node_,
0);
}
// TODO: this should return Value
inline std::vector<int64_t> GetSymIntArrayRefValue(c10::SymIntArrayRef arr) {
std::vector<int64_t> r;
for (const auto& a : arr) {
r.emplace_back(a.expect_int());
}
return r;
}
} // namespace lazy
} // namespace torch

View File

@ -269,15 +269,30 @@ at::Tensor LazyNativeFunctions::_to_copy(
}
};
at::Tensor LazyNativeFunctions::empty(
at::SymIntArrayRef sym_size,
at::Tensor LazyNativeFunctions::empty_symint(
c10::SymIntArrayRef size,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory,
c10::optional<at::MemoryFormat> memory_format) {
// TODO: support SymIntNodes as well
return empty(
c10::asIntArrayRefSlow(size),
dtype,
layout,
device,
pin_memory,
memory_format);
}
at::Tensor LazyNativeFunctions::empty(
at::IntArrayRef size,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory,
c10::optional<at::MemoryFormat> memory_format) {
// TODO: support this directly
auto size = c10::asIntArrayRefSlow(sym_size);
const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType();
at::TensorOptions options = at::TensorOptions()
.device(c10::Device(device_type))
@ -307,13 +322,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER("lazy::");
at::Tensor t = empty(
c10::SymIntArrayRef::fromIntArrayRef(size),
dtype,
layout,
device,
pin_memory,
c10::nullopt);
at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt);
return t.as_strided(size, stride, /*storage_offset=*/0);
}
@ -409,8 +418,7 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
const at::Tensor& self,
at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
return LazyNativeFunctions::view_copy(
self, c10::SymIntArrayRef::fromIntArrayRef(size));
return LazyNativeFunctions::view_copy(self, size);
}
// This is needed by the torch.tensor constructor.
@ -452,8 +460,8 @@ at::Tensor LazyNativeFunctions::new_empty_strided(
at::Tensor LazyNativeFunctions::narrow_copy(
const at::Tensor& self,
int64_t dim,
c10::SymInt start,
c10::SymInt length) {
int64_t start,
int64_t length) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
narrow_copy)>::call(self, dim, start, length);
}

View File

@ -538,9 +538,7 @@ def gen_differentiable_outputs(
info = fn.info[key] if fn.info else None
outputs: List[DifferentiableOutput] = [
DifferentiableOutput(
name=name,
type=ret.type,
cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
name=name, type=ret.type, cpp_type=cpp.return_type(ret).cpp_type()
)
for name, ret in zip(cpp.return_names(f), f.func.returns)
]

View File

@ -64,14 +64,9 @@ from torchgen.utils import assert_never
# collisions, but functions are fair game to collide
def name(
func: FunctionSchema,
*,
faithful_name_for_out_overloads: bool = False,
symint_overload: bool = False,
) -> str:
def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
name = str(func.name.name)
if symint_overload:
if func.is_symint_fn():
name += "_symint"
if func.is_out_fn():
if faithful_name_for_out_overloads:
@ -86,20 +81,11 @@ def name(
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
t: Type,
*,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = False,
t: Type, *, binds: ArgName, remove_non_owning_ref_types: bool = False
) -> Optional[NamedCType]:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
elif str(t) == "SymInt":
if symint:
return NamedCType(binds, BaseCType(SymIntT))
else:
return NamedCType(binds, BaseCType(longT))
if remove_non_owning_ref_types:
if t.name == BaseTy.str:
raise AssertionError(
@ -108,7 +94,7 @@ def valuetype_type(
# All other BaseType currently map directly to BaseCppTypes.
return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
elif isinstance(t, OptionalType):
elem = valuetype_type(t.elem, binds=binds, symint=symint)
elem = valuetype_type(t.elem, binds=binds)
if elem is None:
return None
return NamedCType(binds, OptionalCType(elem.type))
@ -127,19 +113,11 @@ def valuetype_type(
# For example, we'll return std::vector<int> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
t: Type,
*,
mutable: bool,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = False,
t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
) -> NamedCType:
# If it's a value type, do the value type translation
r = valuetype_type(
t,
binds=binds,
symint=symint,
remove_non_owning_ref_types=remove_non_owning_ref_types,
t, binds=binds, remove_non_owning_ref_types=remove_non_owning_ref_types
)
if r is not None:
return r
@ -168,7 +146,7 @@ def argumenttype_type(
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
return NamedCType(binds, BaseCType(optionalIntArrayRefT))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
# TODO: remove these special cases, ArrayRef fallthrough works fine
@ -179,16 +157,10 @@ def argumenttype_type(
return NamedCType(binds, BaseCType(intArrayRefT))
if str(t.elem) == "SymInt":
if remove_non_owning_ref_types:
if symint:
return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
else:
return NamedCType(binds, VectorCType(BaseCType(longT)))
return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
else:
if symint:
return NamedCType(binds, BaseCType(symIntArrayRefT))
else:
return NamedCType(binds, BaseCType(intArrayRefT))
if str(t.elem) == "Tensor":
return NamedCType(binds, BaseCType(symIntArrayRefT))
elif str(t.elem) == "Tensor":
return NamedCType(binds, BaseCType(tensorListT))
elif str(t.elem) == "Scalar":
return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
@ -198,15 +170,15 @@ def argumenttype_type(
return NamedCType(
binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
)
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, ArrayRefCType(elem.type))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
# Translation of a (non-multi) return type from JIT to C++
@ -214,9 +186,9 @@ def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> Named
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
def returntype_type(t: Type, *, mutable: bool) -> CType:
# placeholder is ignored
r = valuetype_type(t, binds="__placeholder__", symint=symint)
r = valuetype_type(t, binds="__placeholder__")
if r is not None:
return r.type
@ -239,7 +211,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
assert (
not mutable
), "Native functions should never return a mutable tensor list. They should return void."
elem = returntype_type(t.elem, mutable=False, symint=symint)
elem = returntype_type(t.elem, mutable=False)
assert t.size is None, f"fixed size list returns not supported: {t}"
return VectorCType(elem)
@ -247,18 +219,18 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
# Translation of a single return to its C++ type
def return_type(r: Return, *, symint: bool = False) -> CType:
return returntype_type(r.type, mutable=r.is_write, symint=symint)
def return_type(r: Return) -> CType:
return returntype_type(r.type, mutable=r.is_write)
# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
def returns_type(rs: Sequence[Return]) -> CType:
if len(rs) == 0:
return BaseCType(voidT)
elif len(rs) == 1:
return return_type(rs[0], symint=symint)
return return_type(rs[0])
else:
return TupleCType([return_type(r, symint=symint) for r in rs])
return TupleCType([return_type(r) for r in rs])
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
@ -353,7 +325,6 @@ def argument(
cpp_no_default_args: Set[str],
method: bool,
faithful: bool,
symint: bool = False,
has_tensor_options: bool,
) -> List[Binding]:
def sub_argument(
@ -364,7 +335,6 @@ def argument(
cpp_no_default_args=cpp_no_default_args,
method=method,
faithful=faithful,
symint=symint,
has_tensor_options=has_tensor_options,
)
@ -379,7 +349,7 @@ def argument(
default = default_expr(a.default, a.type)
return [
Binding(
nctype=argument_type(a, binds=binds, symint=symint),
nctype=argument_type(a, binds=binds),
name=a.name,
default=default,
argument=a,
@ -420,12 +390,7 @@ def argument(
def arguments(
arguments: Arguments,
*,
faithful: bool,
symint: bool = False,
method: bool,
cpp_no_default_args: Set[str],
arguments: Arguments, *, faithful: bool, method: bool, cpp_no_default_args: Set[str]
) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
if faithful:
@ -440,7 +405,6 @@ def arguments(
for r in argument(
a,
faithful=faithful,
symint=symint,
method=method,
has_tensor_options=arguments.tensor_options is not None,
cpp_no_default_args=cpp_no_default_args,

View File

@ -45,7 +45,6 @@ def argumenttype_type(
t,
mutable=mutable,
binds=binds,
symint=True,
remove_non_owning_ref_types=remove_non_owning_ref_types,
)
@ -63,7 +62,7 @@ def argument_type(
def returns_type(rs: Sequence[Return]) -> CType:
# At present, there is no difference. But there could be!
return cpp.returns_type(rs, symint=True)
return cpp.returns_type(rs)
def jit_arguments(func: FunctionSchema) -> List[Argument]:

View File

@ -37,14 +37,6 @@ from torchgen.model import (
_valueT = None
# A ValueT is an IR type which represents the computation of a Tensor. In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations. (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
global _valueT
if not _valueT:
@ -121,27 +113,12 @@ def process_ir_type(
elif str(typ.elem) == "Tensor":
# this is a TensorList which comes in from GetTensorList as a Value
return BaseCType(tensorListValueT)
elif typ.elem == BaseType(BaseTy.SymInt):
# TODO: return a value type. The problem here is analogous to
# the problem with tensorListValueT: if you have SymInt[] you
# cannot conveniently save the list of Value directly, as nodes
# expect to save values as a vector for ALL arguments. So you
# need a separate IR node that represents all of the size nodes
# assembled into a list. I'm not an LTC dev so I don't want to
# figure it out right now. Y'all figure it out...
return VectorCType(BaseCType(longT))
else:
return VectorCType(process_ir_type(typ.elem, properties))
else:
raise AssertionError(f"unrecognized type {repr(typ)}")
# TODO: Determining this based off of CType is bad; this should be computed
# from Type directly; then the same logic as process_ir_type can be used
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
"""
Given a type, determine if it is a Value-like type. This is equivalent to
@ -156,9 +133,6 @@ def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) ->
or (typ.type == scalarT and not treat_scalars_as_constants)
or typ.type == SymIntT
)
elif typ == VectorCType(BaseCType(SymIntT)):
# TODO: report True for this
return False
elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
return isValueType(typ.elem, properties)
return False
@ -183,7 +157,6 @@ def isWrappedScalarType(typ: Type) -> bool:
return False
# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
if isinstance(typ, BaseType):
return typ.name == BaseTy.Generator
@ -192,15 +165,12 @@ def isGeneratorType(typ: Type) -> bool:
return False
# This class caches a few derived properties computed from an Argument
# and LazyIrProperties
class LazyArgument:
name: str
orig_type: Type
lazy_type_: Optional[CType]
is_wrapped_scalar: bool
is_generator: bool
# TODO: this is lies, it is false for symint list
is_symint_or_list: bool
# true if this argument is or contains a lazy IR value
@ -222,11 +192,7 @@ class LazyArgument:
else:
self.lazy_type_ = process_ir_type(arg.type, properties)
self.is_wrapped_scalar = isWrappedScalarType(arg.type)
self.is_symint_or_list = (
isSymIntType(arg.type)
# TODO: lists of symints are not currently treated as value types
# or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
)
self.is_symint_or_list = isSymIntType(arg.type)
self.is_lazy_value = not self.is_generator and isValueType(
self.lazy_type, properties
@ -302,8 +268,6 @@ class LazyIrProperties:
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# and preserving original argument names.
#
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
# The name of the operator this function schema describes.
name: "OperatorName"

View File

@ -37,9 +37,6 @@ from torchgen.utils import assert_never
# native:: kernels. The intention is to make native API and dispatcher API
# line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).
#
# NB: this is symint aware, you will get the non-SymInt variant for some
# dispatch entries and SymInt for others.
def name(func: FunctionSchema) -> str:
@ -52,9 +49,7 @@ def name(func: FunctionSchema) -> str:
return name
def argumenttype_type(
t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
if str(t) == "Tensor?":
tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
if mutable and not local.use_const_ref_for_mutable_tensors():
@ -69,22 +64,19 @@ def argumenttype_type(
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
elif str(t) == "Scalar?":
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
return cpp.returns_type(rs, symint=symint)
def returns_type(rs: Sequence[Return]) -> CType:
return cpp.returns_type(rs)
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
def argument(
a: Union[Argument, SelfArgument, TensorOptionsArguments],
*,
is_out: bool,
symint: bool,
a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool
) -> List[Binding]:
# Ideally, we NEVER default native functions. However, there are a number
# of functions that call native:: directly and rely on the defaulting
@ -98,7 +90,7 @@ def argument(
default = cpp.default_expr(a.default, a.type)
return [
Binding(
nctype=argument_type(a, binds=a.name, symint=symint),
nctype=argument_type(a, binds=a.name),
name=a.name,
default=default,
argument=a,
@ -106,7 +98,7 @@ def argument(
]
elif isinstance(a, SelfArgument):
# Erase SelfArgument from the distinction
return argument(a.argument, is_out=is_out, symint=symint)
return argument(a.argument, is_out=is_out)
elif isinstance(a, TensorOptionsArguments):
default = None
if should_default:
@ -144,10 +136,8 @@ def argument(
assert_never(a)
def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
def arguments(func: FunctionSchema) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
args.extend(func.arguments.non_out)
args.extend(func.arguments.out)
return [
r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn())
]
return [r for arg in args for r in argument(arg, is_out=func.is_out_fn())]

View File

@ -216,12 +216,8 @@ class PythonArgument:
# Compute argument formal for python argument parsing.
# Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
type_str = (
argument_type_str(self.type, symint=symint)
.replace("const ", "")
.replace(" &", "")
)
def argument_str(self, *, method: bool = False) -> str:
type_str = argument_type_str(self.type).replace("const ", "").replace(" &", "")
name = self.name
# s/self/input/ outside method bindings
@ -388,10 +384,10 @@ class PythonSignature:
#
# For a translation to mypy-valid type signatures, see
# signature_str_pyi().
def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
def signature_str(self, *, skip_outputs: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs)
schema_formals: List[str] = list(
map(lambda a: a.argument_str(method=self.method, symint=symint), args)
map(lambda a: a.argument_str(method=self.method), args)
)
positional_argc = len(self.input_args)
if len(schema_formals) > positional_argc:
@ -430,7 +426,7 @@ class PythonSignature:
vararg_type = args[0].type
if (
isinstance(vararg_type, ListType)
and str(vararg_type.elem) in ["int", "SymInt"]
and str(vararg_type.elem) == "int"
and num_positionalargs == 1
):
have_vararg_version = True
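
An illustration, using standard tensor methods, of the vararg signature handled above: when the only positional argument is an integer list, the Python binding also accepts the sizes spread out as individual positional arguments:

import torch

x = torch.arange(6)
# the list form and the vararg form parse to the same overload
assert torch.equal(x.view([2, 3]), x.view(2, 3))
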
@ -468,11 +464,9 @@ class PythonSignatureDeprecated(PythonSignature):
def deprecated(self) -> bool:
return True
def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
def signature_str(self, *, skip_outputs: bool = False) -> str:
return (
PythonSignature.signature_str(
self, skip_outputs=skip_outputs, symint=symint
)
PythonSignature.signature_str(self, skip_outputs=skip_outputs)
+ "|deprecated"
)
@ -639,9 +633,7 @@ def has_tensor_options(f: NativeFunction) -> bool:
# 'simple_type' was introduced by the old codegen, which is slightly
# different from the python schema type, e.g.: doesn't have '?' suffix
# for optional Tensor/TensorList; doesn't have '[size]' suffix for list type.
def argument_type_str(
t: Type, *, simple_type: bool = False, symint: bool = True
) -> str:
def argument_type_str(t: Type, *, simple_type: bool = False) -> str:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
return "Tensor"
@ -673,7 +665,7 @@ def argument_type_str(
if str(t.elem) == "Tensor":
# Is it desired to keep '?' for simple_type with new style dispatcher?
return "Tensor?"
elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
elem = argument_type_str(t.elem, simple_type=simple_type)
return f"{elem}?"
elif isinstance(t, ListType):
size = t.size if not simple_type else None
@ -683,12 +675,7 @@ def argument_type_str(
elif str(t.elem) == "int":
return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
elif str(t.elem) == "SymInt":
if symint:
return (
f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
)
else:
return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
return f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
elif str(t.elem) == "Tensor":
return f"TensorList[{size}]" if size is not None else "TensorList"
elif str(t.elem) == "Scalar":
@ -700,7 +687,7 @@ def argument_type_str(
return "const c10::List<c10::optional<Tensor>> &"
elif str(t.elem) == "Dimname":
return f"DimnameList[{size}]" if size is not None else "DimnameList"
elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
elem = argument_type_str(t.elem, simple_type=simple_type)
return f"ArrayRef<{elem}>"
raise RuntimeError(f"unrecognized type {repr(t)}")
@ -911,7 +898,7 @@ def argument_type_str_pyi(t: Type) -> str:
if t.name == BaseTy.int:
ret = "_int"
if t.name == BaseTy.SymInt:
ret = "Union[_int, SymInt]"
ret = "SymInt"
elif t.name == BaseTy.float:
ret = "_float"
elif t.name == BaseTy.str:
@ -1053,7 +1040,7 @@ def returns_str_pyi(signature: PythonSignature) -> str:
def dispatch_lambda_args(
ps: PythonSignature, f: NativeFunction, symint: bool = True
ps: PythonSignature, f: NativeFunction
) -> Tuple[DispatchLambdaArgument, ...]:
if isinstance(ps, PythonSignatureDeprecated):
schema = ps.deprecated_schema
@ -1064,7 +1051,6 @@ def dispatch_lambda_args(
cpp_args = cpp.arguments(
arguments=schema.arguments,
faithful=False,
symint=symint,
method=False,
cpp_no_default_args=f.cpp_no_default_args,
)
@ -1147,15 +1133,14 @@ def dispatch_lambda_return_str(f: NativeFunction) -> str:
returns_without_annotation = tuple(
map(lambda r: Return(r.name, r.type, None), f.func.returns)
)
return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
return_str = cpp.returns_type(returns_without_annotation).cpp_type()
if return_str not in SUPPORTED_RETURN_TYPES:
raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
return return_str
def cpp_dispatch_target(f: NativeFunction) -> str:
symint = f.func.has_symint()
name = cpp.name(f.func, symint_overload=symint)
name = cpp.name(f.func)
if Variant.method in f.variants:
return f"self.{name}"
if Variant.function in f.variants:
@ -1207,7 +1192,7 @@ def cpp_dispatch_exprs(
# For certain cases it is intentionally more restrictive than necessary,
# e.g.: it doesn't accept doublelist with definite size.
def arg_parser_unpack_method(
t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True
t: Type, default: Optional[str], default_init: Optional[str]
) -> str:
has_default_init = default_init is not None
if has_default_init and str(t) not in (
@ -1239,10 +1224,7 @@ def arg_parser_unpack_method(
elif t.name == BaseTy.int:
return "toInt64"
elif t.name == BaseTy.SymInt:
if symint:
return "toSymInt"
else:
return "toInt64"
return "toSymInt"
elif t.name == BaseTy.bool:
return "toBoolWithDefault" if has_default_init else "toBool"
elif t.name == BaseTy.float:
@ -1263,14 +1245,10 @@ def arg_parser_unpack_method(
return "toDimnameListOptional"
elif not has_default_init and default in (None, "None", "c10::nullopt"):
# If default is None: append 'Optional' to elem's unpacking method
return (
arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
)
return arg_parser_unpack_method(t.elem, None, None) + "Optional"
else:
# Otherwise, load as underlying type with default
return arg_parser_unpack_method(
t.elem, default, default_init, symint=symint
)
return arg_parser_unpack_method(t.elem, default, default_init)
elif isinstance(t, ListType):
if str(t.elem) == "Tensor":
@ -1291,10 +1269,7 @@ def arg_parser_unpack_method(
return "doublelist"
elif str(t.elem) == "SymInt":
# accept definite size
if symint:
return "symintlist"
else:
return "intlist"
return "symintlist"
elif str(t) == "Scalar[]":
return "scalarlist"
raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
@ -1303,11 +1278,11 @@ def arg_parser_unpack_method(
# Return RHS expression for python argument using PythonArgParser output.
# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
def arg_parser_output_expr(
arg_index: int, a: PythonArgument, *, symint: bool = True
arg_index: int, a: PythonArgument
) -> PythonArgParserOutputExpr:
has_default = a.default_init is not None
unpack_method = arg_parser_unpack_method(
t=a.type, default=a.default, default_init=a.default_init, symint=symint
t=a.type, default=a.default, default_init=a.default_init
)
default = f", {a.default_init}" if has_default else ""
expr = f"_r.{unpack_method}({arg_index}{default})"
@ -1322,12 +1297,12 @@ def arg_parser_output_expr(
# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
ps: PythonSignature, f: NativeFunction, *, symint: bool = True
ps: PythonSignature, f: NativeFunction
) -> Dict[str, PythonArgParserOutputExpr]:
return {
e.name: e
for i, a in enumerate(ps.arguments())
for e in (arg_parser_output_expr(i, a, symint=symint),)
for e in (arg_parser_output_expr(i, a),)
}
@ -1342,13 +1317,13 @@ TENSOR_OPTIONS_FIELDS = {
# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
def dispatch_lambda_exprs(
ps: PythonSignature, f: NativeFunction, *, symint: bool = True
ps: PythonSignature, f: NativeFunction
) -> DispatchLambdaArgumentExprs:
# This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing
# 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
# outputs.
arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
lambda_args = dispatch_lambda_args(ps, f, symint=symint)
arg_parser_outputs = arg_parser_output_exprs(ps, f)
lambda_args = dispatch_lambda_args(ps, f)
inits: List[str] = []
lambda_args_exprs: Dict[str, str] = {}

View File

@ -42,12 +42,7 @@ from torchgen.utils import assert_never
# some more nominal types
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
# If it's a value type, do the value type translation
# NB: structured kernels ALWAYS have symint off, since they involve actual
# kernels that require real ints. The one exception is the
# CompositeExplicitAutograd and the meta function (which could
# hypothetically be SymInt), but for simplicity we plan for these to just
# be handled in Python
r = cpp.valuetype_type(t, symint=False, binds=binds)
r = cpp.valuetype_type(t, binds=binds)
if r is not None:
return r

View File

@ -337,16 +337,10 @@ Check this module for more information.
)
return f"c10::asIntArrayRefSlow({symIntArrayRef_type})"
elif goal.type == BaseCType(symIntArrayRefT):
try:
r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
return f"c10::SymIntArrayRef::fromIntArrayRef({r})"
except UnsatError:
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
elif goal.type == BaseCType(SymIntT):
return direct_solve(NamedCType(goal.name, BaseCType(longT)))
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
elif goal.type == BaseCType(longT):
symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
return f"{symInt_type}.expect_int()"
return f"{symInt_type}.expectInt()"
elif goal.type == BaseCType(optionalIntArrayRefT):
return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
elif goal.type == BaseCType(optionalScalarRefT):

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypeVar, Union
from typing import Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union
from torchgen.model import (
Argument,
@ -417,11 +417,6 @@ class CppSignature:
# (i.e. with a potential TensorOptions argument and out arguments in the front)
faithful: bool
# Is this a symint C++ signature. For BC reasons, functions that take
# SymInts still present as int64_t in C++, and the SymInt variant is
# offered at a different overload name
symint: bool
# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: Set[str]
@ -438,17 +433,12 @@ class CppSignature:
return cpp.arguments(
self.func.arguments,
faithful=self.faithful,
symint=self.symint,
method=self.method,
cpp_no_default_args=self.cpp_no_default_args,
)
def name(self) -> str:
n = cpp.name(
self.func,
faithful_name_for_out_overloads=self.faithful,
symint_overload=self.symint,
)
n = cpp.name(self.func, faithful_name_for_out_overloads=self.faithful)
if self.fallback_binding:
n = f"__dispatch_{n}"
return n
@ -461,9 +451,7 @@ class CppSignature:
prefix: str = "",
is_redispatching_fn: bool = False,
) -> str:
returns_type = cpp.returns_type(
self.func.returns, symint=self.symint
).cpp_type()
returns_type = cpp.returns_type(self.func.returns).cpp_type()
cpp_args = [a.decl() for a in self.arguments()]
if is_redispatching_fn:
cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
@ -481,9 +469,7 @@ class CppSignature:
prefix: str = "",
is_redispatching_fn: bool = False,
) -> str:
returns_type = cpp.returns_type(
self.func.returns, symint=self.symint
).cpp_type()
returns_type = cpp.returns_type(self.func.returns).cpp_type()
cpp_args = [a.defn() for a in self.arguments()]
if is_redispatching_fn:
cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
@ -494,12 +480,12 @@ class CppSignature:
def ptr_type(self) -> str:
args_types_str = ", ".join(a.type for a in self.arguments())
return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_types_str})"
return f"{cpp.returns_type(self.func.returns).cpp_type()} (*)({args_types_str})"
# Return the C++ function type, e.g., something like int(bool)
def type(self) -> str:
args_types_str = ", ".join(a.type for a in self.arguments())
return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} ({args_types_str})"
return f"{cpp.returns_type(self.func.returns).cpp_type()} ({args_types_str})"
# Represents group of all CppSignatures associated with a
@ -511,8 +497,6 @@ class CppSignatureGroup:
func: FunctionSchema
signature: CppSignature
faithful_signature: Optional[CppSignature]
symint_signature: Optional[CppSignature]
symint_faithful_signature: Optional[CppSignature]
def most_faithful_signature(self) -> CppSignature:
if self.faithful_signature:
@ -524,10 +508,6 @@ class CppSignatureGroup:
yield self.signature
if self.faithful_signature:
yield self.faithful_signature
if self.symint_signature:
yield self.symint_signature
if self.symint_faithful_signature:
yield self.symint_faithful_signature
@staticmethod
def from_native_function(
@ -535,35 +515,23 @@ class CppSignatureGroup:
) -> "CppSignatureGroup":
func = f.func
def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
def make_sig(*, faithful: bool) -> CppSignature:
return CppSignature(
func=func,
faithful=faithful,
symint=symint,
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args,
)
def make_sigs(*, symint: bool) -> Tuple[CppSignature, Optional[CppSignature]]:
faithful_signature: Optional[CppSignature] = None
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = make_sig(faithful=True, symint=symint)
signature = make_sig(faithful=False, symint=symint)
return signature, faithful_signature
signature, faithful_signature = make_sigs(symint=False)
symint_signature: Optional[CppSignature] = None
symint_faithful_signature: Optional[CppSignature] = None
if func.has_symint():
symint_signature, symint_faithful_signature = make_sigs(symint=True)
faithful_signature: Optional[CppSignature] = None
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = make_sig(faithful=True)
signature = make_sig(faithful=False)
return CppSignatureGroup(
func=func,
signature=signature,
faithful_signature=faithful_signature,
symint_signature=symint_signature,
symint_faithful_signature=symint_faithful_signature,
)
@ -625,8 +593,6 @@ class NativeSignature:
# The schema this signature is derived from
func: FunctionSchema
symint: bool
prefix: str = ""
def name(self) -> str:
@ -636,24 +602,24 @@ class NativeSignature:
args_str = ", ".join(a.decl() for a in self.arguments())
if name is None:
name = self.name()
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
return f"{native.returns_type(self.func.returns).cpp_type()} {name}({args_str})"
def defn(self, name: Optional[str] = None) -> str:
args_str = ", ".join(a.defn() for a in self.arguments())
if name is None:
name = self.name()
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
return f"{native.returns_type(self.func.returns).cpp_type()} {name}({args_str})"
def ptr_type(self) -> str:
# don't include defaults in type signature!
args_str = ", ".join(a.defn() for a in self.arguments())
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})"
return f"{native.returns_type(self.func.returns).cpp_type()} (*)({args_str})"
def arguments(self) -> List[Binding]:
return native.arguments(self.func, symint=self.symint)
return native.arguments(self.func)
def returns_type(self) -> CType:
return native.returns_type(self.func.returns, symint=self.symint)
return native.returns_type(self.func.returns)
def dispatcher_exprs(self) -> List[Expr]:
return translate.translate(
@ -779,14 +745,9 @@ def kernel_signature(
# With external backends, we'd like to enforce that they write their kernels with schemas
# that match the Dispatcher API directly, if they can.
if backend_index.external:
# Dispatcher signature faithfully does SymInt, which is good for XLA,
# not so good for more conventional backends but we don't have any of
# those. If we do, that's time to add a new Signature that is a cross
# between DispatcherSignature and NativeSignature
assert backend_index.symint
return DispatcherSignature.from_schema(f.func, prefix=prefix)
else:
return NativeSignature(f.func, prefix=prefix, symint=backend_index.symint)
return NativeSignature(f.func, prefix)
# Functions only, no types

View File

@ -40,8 +40,7 @@ def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
# Dispatch stubs are always plain ints
r = cpp.valuetype_type(t, binds=binds, symint=False)
r = cpp.valuetype_type(t, binds=binds)
if r is not None:
return r
@ -65,7 +64,7 @@ def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
#
# NB: CUDA only
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
r = cpp.valuetype_type(t, binds=binds, symint=False)
r = cpp.valuetype_type(t, binds=binds)
if r is not None:
return r
@ -94,7 +93,7 @@ def ufunctor_apply_type(
# is done in the computation type. compute_t is opmath_t in CUDA and scalar_t
# in CPU
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
r = cpp.valuetype_type(t, binds=binds, symint=False)
r = cpp.valuetype_type(t, binds=binds)
if r is not None:
return r

View File

@ -136,10 +136,7 @@ def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]:
def argumenttype_ivalue_convert(
t: Type, arg_name: str, *, mutable: bool = False
) -> Tuple[str, CType, List[str], List[str]]:
# Unboxing is for mobile, which doesn't care about SymInts
ctype = cpp.argumenttype_type(
t=t, mutable=mutable, binds=arg_name, symint=False
).type
ctype = cpp.argumenttype_type(t=t, mutable=mutable, binds=arg_name).type
if isinstance(t, BaseType):
out_name = f"{arg_name}_base"

View File

@ -27,10 +27,7 @@ from torchgen.dest.lazy_ts_lowering import ts_lowering_body
from torchgen.model import (
Argument,
BackendIndex,
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsGroup,
)
@ -43,7 +40,6 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
a lazy Node constructor.
"""
# TODO: Matching on CType seems wrong; should be matching on Type
if isValueType(arg.lazy_type):
if isinstance(arg.lazy_type, BaseCType):
if arg.is_wrapped_scalar:
@ -52,7 +48,7 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
return f"lazy_{arg.name}_tensorlist"
elif arg.is_symint_or_list:
cpp_type = arg.lazy_type.cpp_type()
return f"GetSymIntValue({arg.name})"
return f"{cpp_type}(dynamic_cast<torch::lazy::SymIntNodeImpl*>({arg.name}.toSymIntNodeImpl().get())->node_, 0)"
return f"lazy_{arg.name}->GetIrValue()"
elif isinstance(arg.lazy_type, OptionalCType):
if arg.is_wrapped_scalar:
@ -67,15 +63,7 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
)
else:
# NB: this is here because right now we aren't treating SymInt[] as a
# value type; when we do this needs to move above
# NB: we cannot test arg.lazy_type as we've already specified it is an
# int64_t and so we cannot distinguish between SymInt and int64_t
if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
BaseTy.SymInt
):
return f"GetSymIntArrayRefValue({arg.name})"
elif isinstance(arg.lazy_type, VectorCType) and isinstance(
if isinstance(arg.lazy_type, VectorCType) and isinstance(
arg.lazy_type.elem, BaseCType
):
return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
@ -512,13 +500,9 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
dispatch_ns = "compositeexplicitautogradnonfunctional"
else:
dispatch_ns = "meta"
aten_name = schema.aten_name
# TODO: this is trolling
if func.func.has_symint():
aten_name += "_symint"
shape_str = f"""\
{meta_conversion_str}
auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
auto out_meta = at::{dispatch_ns}::{schema.aten_name}({', '.join(meta_call_args)});
{meta_out}"""
else:
shape_sig = ComputeShapeSignature(metadata.kernel, func)

View File

@ -287,8 +287,8 @@ class RegisterDispatchKey:
self, f: NativeFunction
) -> Union[NativeSignature, DispatcherSignature]:
# The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
return DispatcherSignature.from_schema(
f.func, prefix=f"wrapper_{f.func.name.overload_name}_"
return kernel_signature(
f, self.backend_index, prefix=f"wrapper_{f.func.name.overload_name}_"
)
def gen_out_inplace_wrapper(
@ -407,11 +407,10 @@ class RegisterDispatchKey:
f, method=False, fallback_binding=False
)
# TODO: dedupe this with the structured codegen
if self.target is Target.NAMESPACED_DECLARATION:
result = ""
for cpp_sig in cpp_sig_group.signatures():
result += f"TORCH_API {cpp_sig.decl()};\n"
result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
if cpp_sig_group.faithful_signature is not None:
result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
return result
elif self.target is Target.NAMESPACED_DEFINITION:
@ -422,11 +421,10 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
}}
"""
result = ""
for cpp_sig in cpp_sig_group.signatures():
result += generate_defn(cpp_sig)
result = generate_defn(cpp_sig_group.signature)
if cpp_sig_group.faithful_signature is not None:
result += generate_defn(cpp_sig_group.faithful_signature)
return result
elif self.target is Target.ANONYMOUS_DEFINITION:
# short circuit for inplace_meta
if inplace_meta:
@ -453,14 +451,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
else:
impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"
kernel_sig = kernel_signature(f, self.backend_index)
args_exprs_str = ", ".join(
e.expr
for e in translate(
sig.arguments(), kernel_sig.arguments(), method=False
)
)
args_exprs_str = ", ".join(a.name for a in args)
device_check = " // No device check\n"
# Backends that require device guards presumably also require device checks.
@ -750,14 +741,12 @@ resize_out(out, sizes, strides, options);
)
# Signature of the wrapper function we'll register to the dispatcher
sig = NativeSignature(
f.func, prefix="wrapper_", symint=self.backend_index.symint
)
sig = NativeSignature(f.func, prefix="wrapper_")
if self.target is Target.NAMESPACED_DECLARATION:
result = ""
for cpp_sig in cpp_sig_group.signatures():
result += f"TORCH_API {cpp_sig.decl()};\n"
result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
if cpp_sig_group.faithful_signature is not None:
result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
return result
elif self.target is Target.NAMESPACED_DEFINITION:
@ -769,9 +758,9 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
}}
"""
result = ""
for cpp_sig in cpp_sig_group.signatures():
result += generate_defn(cpp_sig)
result = generate_defn(cpp_sig_group.signature)
if cpp_sig_group.faithful_signature is not None:
result += generate_defn(cpp_sig_group.faithful_signature)
return result
elif self.target is Target.ANONYMOUS_DEFINITION:
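
The difference in this hunk is purely how the namespaced declarations and definitions are emitted: one side iterates a signatures() view of the signature group, the other spells out signature plus the optional faithful_signature. A toy sketch showing that the two styles produce the same text when the group holds exactly those two signatures (SigGroup below is a stand-in, not the real CppSignatureGroup):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class SigGroup:
        signature: str
        faithful_signature: Optional[str] = None

        def signatures(self) -> List[str]:
            sigs = [self.signature]
            if self.faithful_signature is not None:
                sigs.append(self.faithful_signature)
            return sigs

    def emit_iterating(group: SigGroup) -> str:
        return "".join(f"TORCH_API {sig};\n" for sig in group.signatures())

    def emit_explicit(group: SigGroup) -> str:
        result = f"TORCH_API {group.signature};\n"
        if group.faithful_signature is not None:
            result += f"TORCH_API {group.faithful_signature};\n"
        return result

    g = SigGroup("<default decl>", "<faithful decl>")
    assert emit_iterating(g) == emit_explicit(g)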


@ -37,6 +37,7 @@ from torchgen.gen_functionalization_type import (
gen_functionalization_definition,
gen_functionalization_registration,
gen_functionalization_view_inverse_declaration,
gen_symint_view_copy_kernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
@ -159,9 +160,6 @@ def parse_native_yaml_struct(
use_out_as_primary=True,
external=False,
device_guard=False,
# I'm actually not sure about this; undefined could be hit on
# empty TensorList, hypothetically that could have sizes in it
symint=False,
index={},
)
)
@ -176,16 +174,6 @@ def parse_native_yaml_struct(
# Only cuda-like devices in tree require device guards
device_guard=is_cuda_dispatch_key(k),
index=v,
# Which dispatch keys natively support symint
# Note: DispatchKey.CompositeExplicitAutograd has to match out
# composites; I think there's some factoring problem here
symint=k
in [
DispatchKey.Meta,
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeExplicitAutograd,
DispatchKey.CompositeExplicitAutogradNonFunctional,
],
)
return ParsedYaml(rs, indices)
@ -874,8 +862,7 @@ class ComputeBackendSelect:
return None
name = native.name(f.func)
# BackendSelect can go to Meta, so it must preserve symints
native_sig = NativeSignature(f.func, symint=True)
native_sig = NativeSignature(f.func)
native_tensor_args = [
a
@ -979,10 +966,7 @@ def dynamic_type(t: Type) -> str:
# also include Tensor[]
if str(t) == "Tensor":
return "at::Tensor"
# This is a legacy concept, so never report SymInt
return cpp.argumenttype_type(
t, mutable=False, binds="__placeholder__", symint=False
).cpp_type()
return cpp.argumenttype_type(t, mutable=False, binds="__placeholder__").cpp_type()
def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
@ -1047,8 +1031,7 @@ def compute_returns_yaml(
ret = {
"dynamic_type": dynamic_type(r.type),
"name": name,
# legacy, report ints
"type": cpp.return_type(r, symint=False).cpp_type(),
"type": cpp.return_type(r).cpp_type(),
}
if r.name:
@ -1108,8 +1091,7 @@ def compute_argument_yaml(
"dynamic_type": dynamic_type(a.type),
"is_nullable": a.type.is_nullable(),
"name": a.name,
# legacy, report ints
"type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
"type": cpp.argument_type(a, binds="__placeholder__").cpp_type(),
}
if a.default is not None:
arg["default"] = pythonify_default(cpp.default_expr(a.default, a.type))
@ -1175,13 +1157,11 @@ def compute_declaration_yaml(f: NativeFunction) -> object:
method=False,
cpp_no_default_args=set(),
faithful=False,
symint=False,
has_tensor_options=False,
)
]
# legacy, report ints
cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
cpp_returns = cpp.returns_type(f.func.returns).cpp_type()
schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"
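
For reference, a tiny worked example of the schema_order_cpp_signature string assembled just above; the concrete argument types are illustrative:

    cpp_returns = "at::Tensor"
    cpp_schema_order_types = ["const at::Tensor &", "at::IntArrayRef", "bool"]
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"
    assert schema_order_cpp_signature == "at::Tensor (const at::Tensor &, at::IntArrayRef, bool)"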
is_factory_method = (
@ -2410,6 +2390,29 @@ def gen_source_files(
)
},
)
view_copy_with_symint_pairs: List[Tuple[NativeFunction, NativeFunction]] = []
for g1 in view_groups:
for g2 in view_groups:
if g1.view_copy is None or g2.view_copy is None:
continue
# TODO: make this more first class in the data model
g1_base_name = str(g1.view_copy.func.name.name)
g2_base_name = str(g2.view_copy.func.name.name)
same_base_op = (
g1_base_name == g2_base_name
and g1.view_copy.func.arguments.symints_to_ints()
== g2.view_copy.func.arguments.symints_to_ints()
)
op1_not_symint = "SymInt" not in str(g1.view_copy.func.name.overload_name)
op2_symint = "SymInt" in str(g2.view_copy.func.name.overload_name)
if same_base_op and op1_not_symint and op2_symint:
view_copy_with_symint_pairs.append(
(
g1.view_copy,
g2.view_copy,
)
)
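
A compact sketch of the pairing rule in this loop, with NativeFunction objects replaced by plain tuples: two operators pair up when the base name and the int-ified signature match and only the second carries a SymInt overload name.

    from typing import List, Tuple

    # (base_name, overload_name, signature after symints_to_ints); a toy model.
    Op = Tuple[str, str, str]

    def sketch_pair_symint_overloads(ops: List[Op]) -> List[Tuple[Op, Op]]:
        pairs = []
        for op1 in ops:
            for op2 in ops:
                same_base_op = op1[0] == op2[0] and op1[2] == op2[2]
                op1_not_symint = "SymInt" not in op1[1]
                op2_symint = "SymInt" in op2[1]
                if same_base_op and op1_not_symint and op2_symint:
                    pairs.append((op1, op2))
        return pairs

    ops = [
        ("view_copy", "", "(Tensor self, int[] size)"),
        ("view_copy", "SymInt", "(Tensor self, int[] size)"),
    ]
    assert len(sketch_pair_symint_overloads(ops)) == 1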
# Note [view_copy NativeFunctions]
# Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
@ -2450,6 +2453,12 @@ def gen_source_files(
"CompositeViewCopyKernel_Definitions": list(
mapMaybe(gen_composite_view_copy_kernel, view_groups)
),
"SymIntViewCopyKernel_Definitions": list(
mapMaybe(
lambda pair: gen_symint_view_copy_kernel(pair[0], pair[1]),
view_copy_with_symint_pairs,
)
),
"GeneratedCompositeFunctional_Definitions": list(
mapMaybe(
gen_composite_functional_kernel,


@ -140,7 +140,6 @@ Only the following keys are supported: {", ".join(valid_keys)}'
dispatch_key=dispatch_key,
use_out_as_primary=use_out_as_primary,
external=True,
symint=True, # TODO: make this configurable
device_guard=use_device_guard,
index=metadata,
)


@ -78,18 +78,21 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
if g.view_copy is None:
return None
# For view_copy.SymInt overloads,
# See gen_symint_view_copy_kernel.
if g.view_copy.func.name.overload_name == "SymInt":
return None
# We can make view_copy work in more cases by using reshape()
# when a normal view call would ordinarily fail.
# This also makes LTC more efficient, because they don't need to include
# clone() calls in their graph (which is normally needed by reshape).
if str(g.view_copy.func.name) == "view_copy":
return """\
at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) {
// TODO: don't cast to int array ref
auto int_size = c10::asIntArrayRefSlow(size);
DimVector shape = infer_size_dv(int_size, self.numel());
at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
DimVector shape = infer_size_dv(size, self.numel());
if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
return self.reshape(int_size);
return self.reshape(size);
} else {
auto output = at::_ops::view::call(self, size);
return output.clone();
@ -97,8 +100,7 @@ at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) {
}
"""
# view_copy is a native signature, since we're generating an at::native:: kernel
# Functionalization always operates on symints though
view_copy_sig = NativeSignature(g.view_copy.func, symint=True)
view_copy_sig = NativeSignature(g.view_copy.func)
# view is a dispatcher signature, since we're calling into the at::_ops API
view_sig = DispatcherSignature(g.view.func)
@ -136,6 +138,34 @@ at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) {
"""
# For symint view copy kernels, we want to generate them to call into
# their concrete view_copy counterparts.
@with_native_function_and
def gen_symint_view_copy_kernel(
view_copy: NativeFunction, view_copy_symint: NativeFunction
) -> str:
# view_copy.symint is a native signature, since we're generating an at::native:: kernel
view_copy_symint_sig = NativeSignature(view_copy_symint.func)
# view_copy is a dispatcher signature, since we're calling into the at::_ops API
view_copy_sig = DispatcherSignature(view_copy.func)
exprs = ", ".join(
[
e.expr
for e in translate(
view_copy_symint_sig.arguments(), view_copy_sig.arguments()
)
]
)
return f"""
{view_copy_symint_sig.defn()} {{
return at::_ops::{view_copy.func.name.unambiguous_name()}::call({exprs});
}}
"""
def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
assert len(rets) == len(names)
if len(rets) == 0:


@ -816,6 +816,9 @@ class NativeFunction:
backend_metadata,
)
def symints_to_ints(self) -> "NativeFunction":
return dataclasses.replace(self, func=self.func.symints_to_ints())
def validate_unstructured(self) -> None:
# TODO: probably better to accumulate these errors and report them all
# at once
@ -878,6 +881,8 @@ class NativeFunction:
"foreach kernels fall back to slow path when tensor are on different devices, "
"device_check not allowed to be enabled"
)
named_symint = "SymInt" in self.func.name.overload_name
assert named_symint == self.func.has_symint()
# NB: if your function accidentally has rand/dropout/... in its name
# but is not actually random, feel free to amend this to special case
@ -1112,8 +1117,6 @@ class BackendIndex:
external: bool
# Other backend-specific information that is on a per-operator basis
index: Dict["OperatorName", BackendMetadata]
# Whether or not this backend handles symbolic ints or not
symint: bool
@staticmethod
def grow_index(
@ -1232,6 +1235,9 @@ class FunctionSchema:
decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
def symints_to_ints(self) -> "FunctionSchema":
return dataclasses.replace(self, arguments=self.arguments.symints_to_ints())
@staticmethod
def parse(func: str) -> "FunctionSchema":
# We should probably get a proper parser here
@ -1354,6 +1360,10 @@ class FunctionSchema:
def is_functional_fn(self) -> bool:
return "functional" in self.name.overload_name
def is_symint_fn(self) -> bool:
# TODO: make this more robust
return "SymInt" in self.name.overload_name
def is_out_fn(self) -> bool:
# Note [is_out_fn]
#
@ -1694,6 +1704,9 @@ class Type:
def is_list_like(self) -> Optional["ListType"]:
raise NotImplementedError
def symint_to_int(self) -> "Type":
raise NotImplementedError
# Base types are simple, atomic types with no further structure
BaseTy = Enum(
@ -1734,12 +1747,14 @@ class BaseType(Type):
def is_nullable(self) -> bool:
return False
def symint_to_int(self) -> "BaseType":
if self.name == BaseTy.SymInt:
return BaseType(BaseTy.int)
return self
def is_list_like(self) -> Optional["ListType"]:
return None
def is_symint_like(self) -> bool:
return self.name == BaseTy.SymInt
# Optional types may be specified, or may also be validly given None
@dataclass(frozen=True)
@ -1752,12 +1767,12 @@ class OptionalType(Type):
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
return self.elem.is_base_ty_like(base_ty)
def is_symint_like(self) -> bool:
return self.elem.is_symint_like()
def is_nullable(self) -> bool:
return True
def symint_to_int(self) -> "Type":
return dataclasses.replace(self, elem=self.elem.symint_to_int())
def is_list_like(self) -> Optional["ListType"]:
return self.elem.is_list_like()
@ -1776,15 +1791,15 @@ class CustomClassType(Type):
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
return False
def is_symint_like(self) -> bool:
return False
def is_nullable(self) -> bool:
"""
Assume a custom class is not nullable.
"""
return False
def symint_to_int(self) -> "Type":
return self
def is_list_like(self) -> Optional["ListType"]:
return None
@ -1808,12 +1823,12 @@ class ListType(Type):
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
return self.elem.is_base_ty_like(base_ty)
def is_symint_like(self) -> bool:
return self.elem.is_symint_like()
def is_nullable(self) -> bool:
return self.elem.is_nullable()
def symint_to_int(self) -> "ListType":
return ListType(self.elem.symint_to_int(), self.size)
def is_list_like(self) -> Optional["ListType"]:
return self
@ -1887,6 +1902,9 @@ class Argument:
def is_write(self) -> bool:
return self.annotation is not None and self.annotation.is_write
def symint_to_int(self) -> "Argument":
return dataclasses.replace(self, type=self.type.symint_to_int())
def __str__(self) -> str:
type = f"{self.type}"
if self.annotation:
@ -2076,6 +2094,37 @@ class Arguments:
if a.annotation is not None and a.annotation.is_write
]
def symints_to_ints(self) -> "Arguments":
arguments = self
if arguments.self_arg:
arguments = dataclasses.replace(
arguments,
pre_self_positional=tuple(
x.symint_to_int() for x in arguments.pre_self_positional
),
)
if self.tensor_options:
arguments = dataclasses.replace(
arguments,
post_tensor_options_kwarg_only=tuple(
x.symint_to_int() for x in arguments.post_tensor_options_kwarg_only
),
)
arguments = dataclasses.replace(
arguments,
post_self_positional=tuple(
x.symint_to_int() for x in arguments.post_self_positional
),
pre_tensor_options_kwarg_only=tuple(
x.symint_to_int() for x in arguments.pre_tensor_options_kwarg_only
),
)
return arguments
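
A toy sketch of the symint_to_int rewrite these methods implement, using schema-style strings instead of the Type class hierarchy: SymInt becomes int, and optional or list wrappers recurse into their element type.

    def sketch_symint_to_int(type_str: str) -> str:
        # Unwrap optional and list suffixes, rewrite the base, then re-wrap.
        if type_str.endswith("?"):
            return sketch_symint_to_int(type_str[:-1]) + "?"
        if type_str.endswith("[]"):
            return sketch_symint_to_int(type_str[:-2]) + "[]"
        return "int" if type_str == "SymInt" else type_str

    assert sketch_symint_to_int("SymInt[]") == "int[]"
    assert sketch_symint_to_int("SymInt?") == "int?"
    assert sketch_symint_to_int("Tensor") == "Tensor"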
def has_tensor_arg(self) -> bool:
return any(a.type.is_tensor_like() for a in self.flat_non_out)


@ -98,9 +98,7 @@ def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bo
return False
if isinstance(g, NativeFunctionsViewGroup):
# TODO: stop doing type tests by converting to C++ and then testing
# the string, just test the dang thing directly
if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
if "at::Tensor" != cpp.returns_type(func.returns).cpp_type():
# Returns a non-Tensor value.
logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
return False
@ -124,8 +122,7 @@ def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bo
or not str(func.name).endswith(".out")
):
return False
# TODO: stop type testing by converting to C++
if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
if "at::Tensor &" != cpp.returns_type(func.returns).cpp_type():
logger.info(f"NON_TENSOR RET TYPE: {str(func)}")
return False
if has_alias(func.arguments.non_out):