Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "Don't introduce new overload for SymInt (#83628)"
This reverts commit 9790d90e4b.
Reverted https://github.com/pytorch/pytorch/pull/83628 on behalf of https://github.com/malfet due to Breaks internal builds, see D39076487
parent 38e5e4a85f
commit c7edcd6968
Changed file: .github/ci_commit_pins/xla.txt (vendored), 2 lines changed
.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
a668569f7f9b7ecd946cf2551d30d482799d597d
9b2f7929c2dae841888a836449c25b04c8cf4045

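This revert reinstates the separate .SymInt overloads (expand.SymInt, view.SymInt, empty.SymInt, new_empty.SymInt, zeros.SymInt, narrow_copy.SymInt, expand_copy.SymInt, view_copy.SymInt) and their kernels throughout the hunks below. The recurring kernel shape is a SymInt overload that materializes the symbolic sizes and falls back to the existing int[] kernel. A minimal sketch of that pattern, using a hypothetical my_op in place of the real operators:

// Sketch only; "my_op" stands in for expand/view/new_empty/etc. in the hunks below.
#include <ATen/ATen.h>
#include <c10/core/SymIntArrayRef.h>

// Existing concrete-size kernel, unchanged by the revert.
at::Tensor my_op(const at::Tensor& self, at::IntArrayRef size) {
  return self.view(size);  // placeholder body
}

// Kernel for the restored "my_op.SymInt" overload: asIntArrayRefSlow converts
// the symbolic sizes back to plain integers (it cannot handle truly symbolic
// values), then delegates to the int[] kernel above.
at::Tensor my_op_symint(const at::Tensor& self, c10::SymIntArrayRef size) {
  return my_op(self, c10::asIntArrayRefSlow(size));
}
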
@@ -186,8 +186,7 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
}

Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
  // TODO: properly support this
  return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
  return self.expand(asIntArrayRefSlow(psize), implicit);
}

std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {

@@ -470,8 +469,7 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
}

Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
  // TODO: properly support this
  return view_batching_rule(self, asIntArrayRefSlow(size));
  return self.view(asIntArrayRefSlow(size));
}

Tensor view_as_complex_batching_rule(const Tensor& self) {

@@ -1011,7 +1009,6 @@ Tensor new_empty_symint_batching_rule(
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory) {
  // TODO: properly support this
  return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
}

@@ -1112,7 +1109,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
  m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
  m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
  m.impl("diagonal", diagonal_batching_rule);
  m.impl("expand", expand_symint_batching_rule);
  m.impl("expand", expand_batching_rule);
  m.impl("expand.SymInt", expand_symint_batching_rule);
  m.impl("expand_as", native::expand_as); // composite wrt autograd
  m.impl("movedim.intlist", movedim_batching_rule);
  m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd

@@ -1140,7 +1138,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
  m.impl("unbind.int", unbind_batching_rule);
  m.impl("unfold", unfold_batching_rule);
  m.impl("unsqueeze", unsqueeze_batching_rule);
  m.impl("view", view_symint_batching_rule);
  m.impl("view", view_batching_rule);
  m.impl("view.SymInt", view_symint_batching_rule);
  m.impl("view_as", native::view_as); // composite wrt autograd

  // clamp operations

@@ -1278,7 +1277,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
  m.impl("diagonal_backward", diagonal_backward_batching_rule);

  // Tensor.new_* operators
  m.impl("new_empty", new_empty_symint_batching_rule);
  m.impl("new_empty", new_empty_batching_rule);
  m.impl("new_empty.SymInt", new_empty_symint_batching_rule);
  m.impl("new_empty_strided", new_empty_strided_batching_rule);
  m.impl("new_zeros", new_zeros_batching_rule);

@@ -137,8 +137,12 @@ Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tenso
  return base.diagonal_scatter(mutated_view, offset, dim1, dim2);
}

Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, bool implicit) {
  return at::sum_to(mutated_view, base.sym_sizes(),/*always_return_non_view=*/!reapply_views);
Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, bool implicit) {
  return at::sum_to(mutated_view, base.sizes(),/*always_return_non_view=*/!reapply_views);
}

Tensor FunctionalInverses::expand_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size, bool implicit) {
  return at::sum_to(mutated_view, c10::asIntArrayRefSlow(base.sym_sizes()),/*always_return_non_view=*/!reapply_views);
}

Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef dims) {

@@ -287,7 +291,15 @@ Tensor FunctionalInverses::unbind_copy_int_inverse(const Tensor& base, const Ten
  return base.select_scatter(mutated_view, dim, mutated_view_idx);
}

Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size) {
Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size) {
  if (reapply_views) {
    return mutated_view.view(base.sizes());
  } else {
    return at::view_copy(mutated_view, base.sizes());
  }
}

Tensor FunctionalInverses::view_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size) {
  if (reapply_views) {
    return mutated_view.view_symint(base.sym_sizes());
  } else {

@@ -295,7 +307,6 @@ Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& m
  }
}


Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) {
  if (reapply_views) {
    return mutated_view.view(base.scalar_type());

@@ -179,6 +179,7 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
  m.impl("exp.out", CppFunction::makeFallthrough());
  m.impl("exp_", CppFunction::makeFallthrough());
  m.impl("expand", CppFunction::makeFallthrough());
  m.impl("expand.SymInt", CppFunction::makeFallthrough());
  m.impl("expm1", CppFunction::makeFallthrough());
  m.impl("expm1.out", CppFunction::makeFallthrough());
  m.impl("expm1_", CppFunction::makeFallthrough());

@@ -353,14 +353,7 @@ namespace impl {
template<bool AllowDeprecatedTypes>
struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
  static std::vector<c10::SymInt> call(IValue& v) {
    if (v.isIntList()) {
      std::vector<c10::SymInt> r;
      auto src = v.toIntList();
      std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
      return r;
    } else {
      return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
    }
    return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
  }
};
template<class T, bool AllowDeprecatedTypes>

@@ -35,8 +35,7 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)

namespace {
  void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) {
    // TODO: figure out if we can just directly save real schema at def time
    c10::optional<std::string> schema_difference = findSchemaDifferences(from_def.cloneWithRealTypes(), inferred);
    c10::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
    if (schema_difference.has_value()) {
      TORCH_CHECK(false,
        "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"

@@ -231,6 +231,8 @@ TypePtr DynamicType::fallback() const {
      return BoolType::get();
    case Tag::Int:
      return IntType::get();
    case Tag::SymInt:
      return SymIntType::get();
    case Tag::Float:
      return FloatType::get();
    case Tag::Complex:

@@ -324,6 +326,8 @@ DynamicType::Ptr IValue::TagType<c10::DynamicType>::get(const c10::IValue& v) {
      return DynamicTypeTrait<ComplexType>::getBaseType();
    case Tag::Int:
      return DynamicTypeTrait<IntType>::getBaseType();
    case Tag::SymInt:
      return DynamicTypeTrait<SymIntType>::getBaseType();
    case Tag::Bool:
      return DynamicTypeTrait<BoolType>::getBaseType();
    case Tag::String:

@@ -16,6 +16,7 @@ constexpr DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30);

constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1);
constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3);
constexpr DynamicTypeBits kDynamicSymIntTypeBit = DYNAMIC_TYPE_BIT(23);
constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4);
constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5);
constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7);

@@ -28,6 +29,7 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
  _(Bool, DYNAMIC_TYPE_BIT(2), 1)                              \
  _(Int, kDynamicIntTypeBit, 1)                                \
  _(Float, kDynamicFloatTypeBit, 1)                            \
  _(SymInt, kDynamicSymIntTypeBit, 1)                          \
  _(Complex, kDynamicComplexTypeBit, 1)                        \
  _(Number,                                                    \
    (kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit), \

@@ -61,7 +63,6 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
#define FORALL_DYNAMIC_TYPES_FAKE(_) \
  _(ScalarType, kDynamicIntTypeBit, 1)  \
  _(Layout, kDynamicIntTypeBit, 1)      \
  _(SymInt, kDynamicIntTypeBit, 1)      \
  _(MemoryFormat, kDynamicIntTypeBit, 1)

#define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;

@@ -17,22 +17,6 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type)
  }
}

FunctionSchema FunctionSchema::cloneWithRealTypes() const {
  auto cloneWithRealTypes = [](const Argument& a) {
    return a.cloneWithType(a.real_type());
  };
  std::vector<Argument> new_arguments, new_returns;
  std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
  std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), cloneWithRealTypes);
  return FunctionSchema(
    name(),
    overload_name(),
    std::move(new_arguments),
    std::move(new_returns),
    is_vararg(),
    is_varret());
}

bool FunctionSchema::canAliasTypeSetsAlias(const c10::optional<AliasTypeSet> &lhs, const c10::optional<AliasTypeSet> &rhs) const {
  if (!lhs || !rhs) {
    return false;

@@ -44,7 +44,7 @@ struct Argument {
      c10::optional<AliasInfo> alias_info = c10::nullopt)
      : name_(std::move(name)),
        type_(fake_type ? std::move(fake_type) : TensorType::get()),
        real_type_(real_type ? std::move(real_type) : type_),
        real_type_(real_type ? std::move(real_type) : TensorType::get()),
        N_(std::move(N)),
        default_value_(std::move(default_value)),
        alias_info_(alias_info ? std::make_unique<AliasInfo>(std::move(*alias_info)) : nullptr),

@@ -88,8 +88,6 @@ struct Argument {
  const TypePtr& type() const {
    return type_;
  }
  // if type() is non-null, this is guaranteed to be non-null (if no real
  // type was provided, this takes on type()'s value)
  const TypePtr& real_type() const {
    return real_type_;
  }

@@ -474,8 +472,6 @@ struct TORCH_API FunctionSchema {
  FunctionSchema cloneWithRemappedTypes(
      const std::function<TypePtr(TypePtr)> type_map) const;

  FunctionSchema cloneWithRealTypes() const;

  // Check that inputs have the correct types and appends any missing default
  // values.
  template <typename T = c10::PlatformType>

@@ -1789,12 +1789,30 @@ struct getTypePtr_<SymInt> final {
  }
};
template <>
struct getTypePtr_<c10::ScalarType> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};
template <>
struct getTypePtr_<c10::Device> final {
  static decltype(auto) call() {
    return DeviceObjType::get();
  }
};
template <>
struct getTypePtr_<c10::Layout> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};
template <>
struct getTypePtr_<c10::MemoryFormat> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};
template <>
struct getTypePtr_<bool> final {
  static decltype(auto) call() {
    return BoolType::get();

@@ -2115,27 +2133,6 @@ private:
  LayoutType() : EnumerationType() {}
};

namespace detail {
template <>
struct getTypePtr_<c10::ScalarType> final {
  static decltype(auto) call() {
    return ScalarTypeType::get();
  }
};
template <>
struct getTypePtr_<c10::Layout> final {
  static decltype(auto) call() {
    return LayoutType::get();
  }
};
template <>
struct getTypePtr_<c10::MemoryFormat> final {
  static decltype(auto) call() {
    return MemoryFormatType::get();
  }
};
} // namespace detail

// the common supertype of all lists,
// List[T] <: AnyList for all T
struct AnyListType;

@@ -54,6 +54,7 @@ namespace at {
#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \
  m.impl("empty_like", torch::CppFunction::makeFallthrough()); \
  m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \
  m.impl("empty.SymInt", torch::CppFunction::makeFallthrough()); \
  m.impl("empty.out", torch::CppFunction::makeFallthrough()); \
  m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \
  m.impl("full_like", torch::CppFunction::makeFallthrough()); \

@@ -13,6 +13,18 @@ namespace at {
namespace native {

Tensor empty_meta(
  IntArrayRef size,
  c10::optional<ScalarType> dtype_opt,
  c10::optional<Layout> layout_opt,
  c10::optional<Device> device_opt,
  c10::optional<bool> pin_memory_opt,
  c10::optional<c10::MemoryFormat> memory_format_opt
) {
  return at::detail::empty_meta(
      size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

Tensor empty_symint_meta(
  SymIntArrayRef size,
  c10::optional<ScalarType> dtype_opt,
  c10::optional<Layout> layout_opt,

@@ -20,7 +20,7 @@ Tensor _bincount_cpu_template(
    AT_ERROR("minlength should be >= 0");
  }
  if (self.dim() == 1 && self.numel() == 0) {
    return at::zeros({minlength}, kLong);
    return native::zeros({minlength}, kLong);
  }
  if (self.dim() != 1 || *self.min().data_ptr<input_t>() < 0) {
    AT_ERROR("bincount only supports 1-d non-negative integral inputs.");

@@ -38,7 +38,7 @@ Tensor _bincount_cpu_template(

  const input_t* self_p = self.data_ptr<input_t>();
  if (has_weights) {
    output = at::zeros(
    output = native::zeros(
        {nbins},
        optTypeMetaToScalarType(weights.options().dtype_opt()),
        weights.options().layout_opt(),

@@ -50,7 +50,7 @@ Tensor _bincount_cpu_template(
      output_p[self_p[i]] += weights_p[i];
    }
  } else {
    output = at::zeros({nbins}, kLong);
    output = native::zeros({nbins}, kLong);
    int64_t* output_p = output.data_ptr<int64_t>();
    for (const auto i : c10::irange(self_size)) {
      output_p[self_p[i]] += 1L;

@@ -186,7 +186,12 @@ Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::opt
  return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

Tensor empty_names(
Tensor empty_symint_cpu(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
    c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
  return at::native::empty_cpu(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

Tensor empty(
    IntArrayRef size,
    c10::optional<DimnameList> names,
    c10::optional<ScalarType> dtype,

@@ -214,12 +219,9 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<Sca
  return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
}

Tensor& empty_out(SymIntArrayRef sym_size,
Tensor& empty_out(IntArrayRef size,
    c10::optional<c10::MemoryFormat> optional_memory_format,
    Tensor& result) {
  // TODO: support empty_out properly (I was forced to change this immediately
  // with empty so that empty/empty.out had the same type signature)
  auto size = c10::asIntArrayRefSlow(sym_size);
  // Preferably, this argument would not be accepted by _out, but the code
  // generator requires the out and non-out overloads to match exactly
  TORCH_CHECK(

@@ -387,6 +389,17 @@ Tensor empty_like_quantized(
}

Tensor new_empty(
    const Tensor& self,
    IntArrayRef size,
    c10::optional<ScalarType> dtype_opt,
    c10::optional<Layout> layout_opt,
    c10::optional<Device> device_opt,
    c10::optional<bool> pin_memory_opt
    ) {
  return self.new_empty_symint(c10::SymIntArrayRef::fromIntArrayRef(size), dtype_opt, layout_opt, device_opt, pin_memory_opt);
}

Tensor new_empty_symint(
    const Tensor& self,
    SymIntArrayRef size,
    c10::optional<ScalarType> dtype_opt,

@@ -1077,7 +1090,15 @@ Tensor triu_indices_cpu(

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tensor zeros(SymIntArrayRef size,
Tensor zeros(IntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory) {
  return at::zeros_symint(c10::SymIntArrayRef::fromIntArrayRef(size), dtype, layout, device, pin_memory);
}

Tensor zeros_symint(SymIntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,

@@ -1102,17 +1123,8 @@ Tensor _efficientzerotensor(IntArrayRef size,
  return out;
}

Tensor& zeros_sparse_out(IntArrayRef size, Tensor& result) {
  result.sparse_resize_and_clear_(size, size.size(), 0.);
  return result;
}

Tensor& zeros_out(SymIntArrayRef sym_size, Tensor& result) {
  auto size = c10::asIntArrayRefSlow(sym_size);
Tensor& zeros_out(IntArrayRef size, Tensor& result) {
  if (result.is_sparse()) {
    // TODO: I think this branch should be dead, but we don't have an easy
    // way to cover all sparse kernels with zeros_sparse_out, so retain this
    // for now
    result.sparse_resize_and_clear_(size, size.size(), 0.);
    return result;
  } else {

@@ -300,7 +300,7 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) {
  new_values_size[0] = new_indices_size[1];

  Tensor new_values = values.expand(broadcast_dense_sizes).repeat_interleave(nnz_factor, 0);
  Tensor new_indices = indices.new_empty(new_indices_size);
  Tensor new_indices = at::native::new_empty(indices, new_indices_size);
  if (broadcast_sizes.size()>0) {
    // ones(broadcast_sizes).nonzero() is equivalent to
    // product(map(arange, broadcast_sizes)) but avoids creating

@@ -542,14 +542,14 @@ static Tensor cat_sparse_impl(TensorList tensors, int64_t dim) {
      zeros_sizes[0] = t._values().size(0);
      zeros_sizes[values_dim] = cumulative_size;
      cumulative_size += t._values().size(values_dim);
      auto z1 = at::zeros(
      auto z1 = native::zeros(
          zeros_sizes,
          optTypeMetaToScalarType(t._values().options().dtype_opt()),
          t._values().options().layout_opt(),
          t._values().options().device_opt(),
          t._values().options().pinned_memory_opt());
      zeros_sizes[values_dim] = total_size - cumulative_size;
      auto z2 = at::zeros(
      auto z2 = native::zeros(
          zeros_sizes,
          optTypeMetaToScalarType(t._values().options().dtype_opt()),
          t._values().options().layout_opt(),

@@ -843,9 +843,12 @@ Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim
  return result;
}

Tensor expand(const Tensor& self, c10::SymIntArrayRef sym_size, bool /*unused*/) {
  // TODO: properly support SymInt expand
  auto size = asIntArrayRefSlow(sym_size);
Tensor expand_symint(const Tensor& self, c10::SymIntArrayRef packed_size, bool implicit) {
  auto size = asIntArrayRefSlow(packed_size);
  return self.expand(size, implicit);
}

Tensor expand(const Tensor& self, IntArrayRef size, bool /*unused*/) {
  TORCH_CHECK(size.size() >= (size_t)self.dim(),
      "expand(", self.toString(), "{", self.sizes(), "}, size=", size,
      "): the number of sizes provided (", size.size(), ") ",

@@ -924,9 +927,12 @@ const Tensor &as_strided_(const Tensor& self, IntArrayRef size, IntArrayRef stri
  return self;
}

Tensor narrow_copy_dense(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
  // TODO: properly support SymInt narrow_copy
  return self.narrow(dim, start.expect_int(), length.expect_int()).clone(at::MemoryFormat::Contiguous);
Tensor narrow_copy_symint(const Tensor& self, int64_t dim, int64_t start, SymInt sym_length) {
  return self.narrow_copy(dim, start, sym_length.expect_int());
}

Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
  return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
}

Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){

@@ -2776,7 +2782,7 @@ Tensor unsqueeze_sparse(Tensor const &self, int64_t dim) {
  if (dim <= sparse_dim) {
    auto new_indices = at::cat(
        {indices.narrow(0, 0, dim),
         at::zeros(
         native::zeros(
             {1, indices.size(1)},
             kLong,
             indices.options().layout_opt(),

@@ -3112,15 +3118,14 @@ Tensor adjoint(const Tensor &self) {
  return _adjoint(self, /*transpose=*/false, "adjoint()");
}

Tensor view_meta(const Tensor& self,
    at::SymIntArrayRef size) {
  // TODO: Properly support SymInt view
  return view_impl(self, c10::asIntArrayRefSlow(size));
Tensor view(const Tensor& self,
    IntArrayRef size) {
  return view_impl(self, size);
}

Tensor view(const Tensor& self,
    at::IntArrayRef size) {
  return view_impl(self, size);
Tensor view_symint(const Tensor& self,
    c10::SymIntArrayRef size) {
  return self.view(c10::asIntArrayRefSlow(size));
}

Tensor alias(const Tensor& self) {

@@ -3500,8 +3505,8 @@ at::Tensor& expand_copy_SymInt_out(const at::Tensor & self, c10::SymIntArrayRef
}


at::Tensor& expand_copy_out(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) {
  auto tmp = self.expand_symint(size, implicit);
at::Tensor& expand_copy_out(const at::Tensor & self, at::IntArrayRef size, bool implicit, at::Tensor & out) {
  auto tmp = self.expand(size, implicit);
  out.copy_(tmp);
  return out;
}

@@ -3656,8 +3661,8 @@ void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList o
}


at::Tensor& view_copy_out(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) {
  auto tmp = self.view_symint(size);
at::Tensor& view_copy_out(const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) {
  auto tmp = self.view(size);
  out.copy_(tmp);
  return out;
}

@@ -15,7 +15,7 @@
#include <ATen/ops/bincount_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/histc_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_native.h>
#endif

namespace at {

@@ -271,7 +271,7 @@ bool CUDA_tensor_histogram(
  detail::TensorInfo<output_t, IndexType> pInfo(nullptr, 0, {}, {});
  Tensor partial_output;
  if (memType == CUDAHistogramMemoryType::MULTI_BLOCK) {
    partial_output = at::zeros(
    partial_output = native::zeros(
        {grid.x, nbins},
        optTypeMetaToScalarType(a.options().dtype_opt()),
        a.options().layout_opt(),

@@ -313,7 +313,7 @@ Tensor _bincount_cuda_template(
    AT_ERROR("minlength should be >= 0");
  }
  if (self.dim() == 1 && self.numel() == 0) {
    return at::zeros(
    return native::zeros(
        {minlength},
        kLong,
        c10::nullopt /* layout */,

@@ -342,7 +342,7 @@ Tensor _bincount_cuda_template(
  // alloc output counter on GPU
  Tensor output;
  if (has_weights) {
    output = at::zeros(
    output = native::zeros(
        {nbins},
        optTypeMetaToScalarType(weights.options().dtype_opt()),
        weights.options().layout_opt(),

@@ -351,7 +351,7 @@ Tensor _bincount_cuda_template(
    cuda::CUDA_tensor_histogram<weights_t, input_t, true>(
        output, self, weights, nbins, minvalue, maxvalue);
  } else {
    output = at::zeros(
    output = native::zeros(
        {nbins},
        kLong,
        c10::nullopt /* layout */,

@@ -373,7 +373,7 @@ Tensor _histc_cuda_template(
  if (nbins <= 0) {
    AT_ERROR("bins must be > 0");
  }
  Tensor output = at::zeros(
  Tensor output = native::zeros(
      {nbins},
      self.scalar_type(),
      c10::nullopt /* layout */,

@@ -55,6 +55,10 @@ Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::op
  return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

Tensor empty_symint_cuda(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
  return at::native::empty_cuda(asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

Tensor _efficientzerotensor_cuda(IntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,

@@ -436,7 +436,7 @@ Tensor cudnn_convolution_relu(
  bool allow_tf32 = ctx.allowTF32CuDNN();
  auto _bias = bias_t.has_value()
          ? bias_t.value()
          : at::zeros(
          : at::native::zeros(
                {output_t.size(1)},
                optTypeMetaToScalarType(output_t.options().dtype_opt()),
                output_t.options().layout_opt(),

@@ -514,7 +514,7 @@ Tensor cudnn_convolution_add_relu(
  auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
  auto _bias = bias_t.has_value()
          ? bias_t.value()
          : at::zeros(
          : at::native::zeros(
                {output_t.size(1)},
                optTypeMetaToScalarType(output_t.options().dtype_opt()),
                output_t.options().layout_opt(),

@@ -1570,7 +1570,7 @@ Tensor miopen_convolution_add_relu(
  auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
  auto _bias = bias.has_value()
          ? bias.value()
          : at::zeros(
          : at::native::zeros(
                {contig_output.size(1)},
                optTypeMetaToScalarType(contig_output.options().dtype_opt()),
                contig_output.options().layout_opt(),

@@ -1614,7 +1614,7 @@ Tensor miopen_convolution_relu(

    auto _bias = bias.has_value()
            ? bias.value()
            : at::zeros(
            : at::native::zeros(
                  {output_t.size(1)},
                  optTypeMetaToScalarType(output_t.options().dtype_opt()),
                  output_t.options().layout_opt(),

@@ -1661,7 +1661,7 @@ Tensor miopen_convolution_relu(

    auto _bias = bias.has_value()
            ? bias.value()
            : at::zeros(
            : at::native::zeros(
                  {contig_output.size(1)},
                  optTypeMetaToScalarType(contig_output.options().dtype_opt()),
                  contig_output.options().layout_opt(),

@@ -2,6 +2,10 @@

namespace at { namespace native {

Tensor empty_symint_mkldnn(c10::SymIntArrayRef sizes, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {
  return at::native::empty_mkldnn(c10::asIntArrayRefSlow(sizes), dtype, layout, device, pin_memory, optional_memory_format);
}

#if AT_MKLDNN_ENABLED()

Tensor empty_mkldnn(IntArrayRef sizes, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {

@@ -71,6 +71,17 @@ Tensor empty_mps(
  return at::detail::empty_mps(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

Tensor empty_symint_mps(
    c10::SymIntArrayRef size,
    c10::optional<ScalarType> dtype_opt,
    c10::optional<Layout> layout_opt,
    c10::optional<Device> device_opt,
    c10::optional<bool> pin_memory_opt,
    c10::optional<c10::MemoryFormat> memory_format_opt) {

  return at::native::empty_mps(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

Tensor empty_strided_mps(
    IntArrayRef size,
    IntArrayRef stride,

@@ -2046,10 +2046,10 @@
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeExplicitAutograd: empty_names
    CompositeExplicitAutograd: empty
  autogen: empty.names_out

- func: empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
- func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  dispatch:
    CPU: empty_cpu
    CUDA: empty_cuda

@@ -2060,14 +2060,39 @@
    SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
    QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized

# all calls to empty() in python used to go through the symint overload
# even if all arguments were concerete integers.
# adding symint overloads of kernels for every dispatch key allowed us
# to skip redispatching to `empty.memory_format` and hit backend kernels directly
# we recently updated signature parsing to dispath `empty()` calls in python
# to `empty.SymInt` iff there's is a symint node argument
# hopefully, we could simplify this entry soon
- func: empty.SymInt(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  dispatch:
    CPU: empty_symint_cpu
    CUDA: empty_symint_cuda
    MPS: empty_symint_mps
    Meta: empty_symint_meta
    MkldnnCPU: empty_symint_mkldnn
    SparseCPU, SparseCUDA, SparseMeta: empty_symint_sparse
    SparseCsrCPU, SparseCsrCUDA: empty_symint_sparse_compressed
    QuantizedCPU, QuantizedCUDA: empty_symint_unknown_quantized
  autogen: empty.SymInt_out

# We do not make new_empty a composite that calls into new_empty_strided, as the strided version
# is significantly more difficult to implement by different backends
- func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  variants: method
  dispatch:
    CompositeExplicitAutograd: new_empty
  autogen: new_empty.out

- func: new_empty.SymInt(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  variants: method
  dispatch:
    CompositeExplicitAutograd: new_empty_symint
  autogen: new_empty.SymInt_out

- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  variants: method
  dispatch:

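The comment block above spells out why the restored empty.SymInt entry lists a kernel for every dispatch key: Python-side empty() calls historically went through the SymInt overload even for concrete integer sizes, and per-backend symint kernels let those calls hit the backend directly instead of redispatching to empty.memory_format. For orientation, a minimal sketch of what such a pair of kernels and their registration could look like for a hypothetical out-of-tree backend (the in-tree CPU/CUDA/MPS/Meta variants are in the C++ hunks earlier in this diff, and the ORT and PrivateUse1 test backends later in the diff register theirs the same way):

#include <ATen/ATen.h>
#include <ATen/EmptyTensor.h>
#include <torch/library.h>

// Hypothetical backend: ordinary concrete-size kernel (CPU allocation is used
// here only as a stand-in for real backend storage).
at::Tensor my_backend_empty(
    at::IntArrayRef size,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<at::MemoryFormat> memory_format) {
  return at::detail::empty_cpu(size, dtype, layout, device, pin_memory, memory_format);
}

// empty.SymInt kernel: materialize the SymInt sizes and forward, the same
// shape as the empty_symint_cpu/cuda/mps kernels added above.
at::Tensor my_backend_empty_symint(
    c10::SymIntArrayRef size,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<at::MemoryFormat> memory_format) {
  return my_backend_empty(
      c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, memory_format);
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("empty.memory_format", my_backend_empty);
  m.impl("empty.SymInt", my_backend_empty_symint);
}
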
@@ -2145,7 +2170,7 @@
    QuantizedCPU, QuantizedCUDA: empty_quantized
  autogen: empty_quantized.out

- func: empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
- func: empty.out(int[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  device_guard: False

@@ -2269,7 +2294,14 @@
    SparseCPU, SparseCUDA: expm1_sparse_out
    SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr_out

- func: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
- func: expand.SymInt(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
  variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeExplicitAutograd: expand_symint

- func: expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)
  variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
  device_check: NoCheck
  device_guard: False

@@ -3665,7 +3697,7 @@
  dispatch:
    CompositeExplicitAutograd: mvlgamma_

- func: narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor
- func: narrow_copy(Tensor self, int dim, int start, int length) -> Tensor
  variants: function, method
  dispatch:
    CPU: narrow_copy_dense_cpu

@@ -3673,7 +3705,13 @@
    CompositeExplicitAutogradNonFunctional: narrow_copy_dense
  tags: view_copy

- func: narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)
- func: narrow_copy.SymInt(Tensor self, int dim, int start, SymInt length) -> Tensor
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: narrow_copy_symint
  autogen: narrow_copy.SymInt_out

- func: narrow_copy.out(Tensor self, int dim, int start, int length, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: narrow_copy_dense_cpu_out

@@ -5542,14 +5580,19 @@
    CUDA: _efficientzerotensor_cuda
  autogen: _efficientzerotensor.out

- func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CompositeExplicitAutograd: zeros

- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
- func: zeros.SymInt(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CompositeExplicitAutograd: zeros_symint
  autogen: zeros.SymInt_out

- func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: zeros_out
    SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out
    SparseCPU, SparseCUDA, SparseMeta: zeros_out

- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  dispatch:

@@ -6876,13 +6919,20 @@
    CPU: masked_softmax_backward_cpu
  autogen: _masked_softmax_backward.out

- func: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
- func: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a)
  variants: method
  device_check: NoCheck
  device_guard: False
  dispatch:
    Meta: view_meta
    ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
    CompositeExplicitAutograd: view_symint
    MkldnnCPU: mkldnn_view_symint

- func: view(Tensor(a) self, int[] size) -> Tensor(a)
  variants: method
  device_check: NoCheck
  device_guard: False
  dispatch:
    ZeroTensor, CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, MPS: view
    MkldnnCPU: mkldnn_view

# Warning: If you want to change the name or overload name of this

@@ -12714,12 +12764,18 @@
    CompositeExplicitAutogradNonFunctional: diagonal_copy
  tags: view_copy

- func: expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
- func: expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: expand_copy
  tags: view_copy

- func: expand_copy.SymInt(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutograd: expand_copy_SymInt
  tags: view_copy

- func: permute_copy(Tensor self, int[] dims) -> Tensor
  variants: function
  dispatch:

@@ -12848,7 +12904,7 @@
    CompositeExplicitAutogradNonFunctional: unbind_copy_int
  tags: view_copy

- func: view_copy(Tensor self, SymInt[] size) -> Tensor
- func: view_copy(Tensor self, int[] size) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: view_copy

@@ -12908,6 +12964,14 @@
    CompositeExplicitAutograd: _neg_view_copy_out


- func: view_copy.SymInt(Tensor self, SymInt[] size) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutograd: view_copy_SymInt
  tags: view_copy
  autogen: view_copy.SymInt_out


- func: as_strided_copy.out(Tensor self, int[] size, int[] stride, int? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  dispatch:

@@ -12926,7 +12990,13 @@
    CompositeExplicitAutograd: diagonal_copy_out


- func: expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
- func: expand_copy.SymInt_out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
  variants: function
  dispatch:
    CompositeExplicitAutograd: expand_copy_SymInt_out


- func: expand_copy.out(Tensor self, int[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
  variants: function
  dispatch:
    CompositeExplicitAutograd: expand_copy_out

@@ -13046,7 +13116,7 @@
    CompositeExplicitAutograd: unbind_copy_int_out


- func: view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
- func: view_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  dispatch:
    CompositeExplicitAutograd: view_copy_out

@@ -66,6 +66,16 @@ Tensor empty_per_channel_affine_quantized(
      quantizer);
}

Tensor empty_symint_unknown_quantized(
    c10::SymIntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  return at::native::empty_unknown_quantized(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format);
}

Tensor empty_unknown_quantized(
    IntArrayRef size,
    c10::optional<ScalarType> dtype,

@@ -488,6 +488,16 @@ SPARSE_COMPRESSED_TENSOR(csc, kSparseCsc)
SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr)
SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc)

Tensor empty_symint_sparse_compressed(
    c10::SymIntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<MemoryFormat> optional_memory_format) {
  return at::native::empty_sparse_compressed(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format);
}

Tensor empty_sparse_compressed(
    IntArrayRef size,
    c10::optional<ScalarType> dtype,

@@ -207,6 +207,17 @@ Tensor empty_sparse(
      size.size(), 0, size, dtype, layout, device, pin_memory);
}

/** Empty init **/
Tensor empty_symint_sparse(
    c10::SymIntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<MemoryFormat> optional_memory_format) {
  return at::native::empty_sparse(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format);
}

/* Shape init */
Tensor sparse_coo_tensor(IntArrayRef size,
    c10::optional<ScalarType> dtype,

@@ -1524,7 +1524,7 @@ SparseTensor& _sspaddmm_out_cpu(
  int64_t t_nnz = t._nnz();
  int64_t r_nnz = nnz * dim_k + t_nnz;
  Tensor newi = at::empty({2, r_nnz}, kLong);
  Tensor newv = at::zeros(
  Tensor newv = native::zeros(
      {r_nnz},
      optTypeMetaToScalarType(values.options().dtype_opt()),
      values.options().layout_opt(),

@@ -141,6 +141,7 @@ full_codegen:
  - upsample_nearest2d
  - upsample_nearest2d_backward
  - zero
  - narrow_copy.SymInt
  - alias_copy
  - as_strided_copy
  - diagonal_copy

@@ -174,6 +175,7 @@ supported:
  - _copy_from
  - _copy_from_and_resize
  - empty.memory_format
  - empty.SymInt
  - empty_strided
  - fill_.Scalar
  - normal_

@@ -29,13 +29,12 @@ Tensor _empty_affine_quantized(
}

Tensor empty_memory_format(
    const SymIntArrayRef sym_sizes,
    const IntArrayRef sizes,
    const c10::optional<ScalarType> dtype,
    const c10::optional<c10::Layout> layout,
    const c10::optional<Device> device,
    const c10::optional<bool> pin_memory,
    const optional<MemoryFormat> memory_format) {
  auto sizes = c10::asIntArrayRefSlow(sym_sizes);
  return convert(vTensor{
      api::context(),
      sizes,

@@ -56,12 +55,7 @@ Tensor empty_strided(
    const optional<Device> device,
    const optional<bool> pin_memory) {
  return empty_memory_format(
      c10::SymIntArrayRef::fromIntArrayRef(sizes),
      dtype,
      layout,
      device,
      pin_memory,
      c10::MemoryFormat::Contiguous);
      sizes, dtype, layout, device, pin_memory, c10::MemoryFormat::Contiguous);
}

#ifdef USE_VULKAN_API

@@ -42,8 +42,7 @@ Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) {
  return convert(v_output);
}

inline Tensor view(const Tensor& self_arg, const SymIntArrayRef sym_shape) {
  auto shape = c10::asIntArrayRefSlow(sym_shape);
inline Tensor view(const Tensor& self_arg, const IntArrayRef shape) {
  return view_internal(self_arg, shape);
}

@@ -55,6 +55,8 @@ void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {

${CompositeViewCopyKernel_Definitions}

${SymIntViewCopyKernel_Definitions}

${GeneratedCompositeFunctional_Definitions}

${GeneratedCompositeOut_Definitions}

@@ -28,7 +28,7 @@ T getSampleValue();

template <>
at::Tensor getSampleValue() {
  return at::zeros({2, 2}).to(at::kCPU);
  return at::native::zeros({2, 2}).to(at::kCPU);
}

template <>

@@ -105,7 +105,7 @@ void assertOwn(

template<>
Tensor getSampleValue() {
  return at::zeros({2, 2}).to(at::kCPU);
  return at::native::zeros({2, 2}).to(at::kCPU);
}

template<>

@@ -15,7 +15,7 @@ using namespace at;

static int test_int;

Tensor empty_override(SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout,
Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout,
                      c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> optional_memory_format) {
  test_int = 1;
  auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(

@@ -44,7 +44,7 @@ Tensor empty_strided_override(
  c10::optional<c10::Device> device,
  c10::optional<bool> pin_memory) {

  return empty_override(SymIntArrayRef::fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt);
  return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt);
}

TORCH_LIBRARY_IMPL(aten, ORT, m) {

@@ -119,7 +119,7 @@ TEST(MathKernelTest, NarrowCopy) {
  for (const auto dim : c10::irange(3)) {
    const int64_t start = 1, length = 4;
    auto y_ref = x.narrow(dim, start, length);
    auto y_test = at::native::narrow_copy_dense(x, dim, c10::SymInt(start), c10::SymInt(length));
    auto y_test = at::native::narrow_copy_dense(x, dim, start, length);
    ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0);
  }
}

@@ -9,25 +9,6 @@

namespace at { namespace functorch {

template <typename A, A a, typename C>
struct NewBlahBatchRuleHelperSymInt;

template <typename F, F Func, typename A, typename B, typename... T>
struct NewBlahBatchRuleHelperSymInt<F, Func, typelist<A, B, T...>> {
  static std::tuple<Tensor,optional<int64_t>> apply(
      const Tensor& tensor,
      optional<int64_t> batch_dim,
      SymIntArrayRef shape,
      T... extra_args) {
    const auto bdim_size = tensor.sym_size(batch_dim.value());
    c10::SmallVector<c10::SymInt> new_shape;
    new_shape.reserve(shape.size() + 1);
    new_shape.emplace_back(bdim_size);
    new_shape.insert(new_shape.end(), shape.begin(), shape.end());
    return std::make_tuple(Func(tensor, new_shape, std::forward<T>(extra_args)...), 0);
  }
};

template <typename A, A a, typename C>
struct NewBlahBatchRuleHelper;

@@ -56,12 +37,6 @@ struct NewBlahBatchRuleHelper<F, Func, typelist<A, B, T...>> {
    &fn,\
    c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

#define NEW_BLAH_BATCH_RULE_SYMINT(fn) SINGLE_ARG(\
    NewBlahBatchRuleHelperSymInt<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

std::tuple<Tensor,optional<int64_t>> _new_zeros_with_same_feature_meta_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim,
    const Tensor& other, optional<int64_t> other_bdim,

@@ -107,6 +82,17 @@ bool _has_same_storage_numel_batch_rule(const Tensor& a, const Tensor& b) {
  return true;
}

Tensor new_empty_symint_decomp(
    const Tensor& self,
    SymIntArrayRef size,
    c10::optional<ScalarType> dtype_opt,
    c10::optional<Layout> layout_opt,
    c10::optional<Device> device_opt,
    c10::optional<bool> pin_memory_opt
    ) {
  return self.new_empty(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt);
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("_has_same_storage_numel", _has_same_storage_numel_batch_rule);
  VMAP_SUPPORT(ones_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(ones_like)));

@@ -115,7 +101,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT(randn_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(randn_like)));
  VMAP_SUPPORT(rand_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(rand_like)));
  VMAP_SUPPORT(full_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(full_like)));
  VMAP_SUPPORT(new_empty, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_empty)));
  VMAP_SUPPORT(new_empty, NEW_BLAH_BATCH_RULE(ATEN_FN(new_empty)));
  m.impl("new_empty.SymInt", new_empty_symint_decomp);
  VMAP_SUPPORT(new_zeros, NEW_BLAH_BATCH_RULE(ATEN_FN(new_zeros)));
  VMAP_SUPPORT(new_ones, NEW_BLAH_BATCH_RULE(ATEN_FN(new_ones)));
  VMAP_SUPPORT(new_full, NEW_BLAH_BATCH_RULE(ATEN_FN(new_full)));

@@ -427,15 +427,15 @@ std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
}

std::tuple<Tensor, optional<int64_t>> view_batching_rule(
    const Tensor &self, optional<int64_t> self_bdim, SymIntArrayRef sym_size)
    const Tensor &self, optional<int64_t> self_bdim, IntArrayRef size)
{
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  c10::SmallVector<c10::SymInt> size_(sym_size.size() + 1);
  VmapDimVector size_(size.size() + 1);
  // copy batch size
  size_[0] = self_.size(0);
  std::copy(sym_size.cbegin(), sym_size.cend(), size_.begin() + 1);
  return std::make_tuple(self_.view_symint(size_), 0);
  std::copy(size.cbegin(), size.cend(), size_.begin() + 1);
  return std::make_tuple(self_.view(size_), 0);
}

Tensor view_symint_decomposition(const Tensor& self,

@@ -446,7 +446,7 @@ Tensor view_symint_decomposition(const Tensor& self,

template <typename F, F Func>
std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
    const Tensor &self, optional<int64_t> self_bdim, SymIntArrayRef size, bool implicit)
    const Tensor &self, optional<int64_t> self_bdim, IntArrayRef size, bool implicit)
{
  auto self_dim = self.dim();
  TORCH_CHECK(static_cast<uint64_t>(self_dim - 1) <= size.size(),

@@ -457,7 +457,7 @@ std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
  auto self_sizes = self_.sizes();
  auto batch_size = self_sizes[0];

  c10::SmallVector<c10::SymInt> size_(size.size() + 1);
  c10::SmallBuffer<int64_t, 5> size_(size.size() + 1);
  size_[0] = batch_size;
  std::copy(size.cbegin(), size.cend(), size_.begin() + 1);

@@ -471,12 +471,12 @@ std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
  // then expand.
  auto extra_dims = size.size() - (self_dim - 1);
  c10::SmallVector<c10::SymInt> view_shape(size_.size(), /*init_value*/1);
  VmapDimVector view_shape(size_.size(), /*init_value*/1);
  view_shape[0] = batch_size;
  std::copy(self_sizes.cbegin() + 1, self_sizes.cend(),
            view_shape.begin() + 1 + extra_dims);

  return std::make_tuple(Func(self_.view_symint(view_shape), size_, implicit), 0);
  return std::make_tuple(Func(self_.view(view_shape), size_, implicit), 0);
}

std::tuple<Tensor, optional<int64_t>> unfold_batch_rule(

@@ -549,6 +549,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT2(slice, Tensor, slice_batch_rule);
  VMAP_SUPPORT2(transpose, int, transpose_int_batch_rule);
  VMAP_SUPPORT(diag_embed, diag_embed_batch_rule);
  m.impl("expand.SymInt", expand_symint_decomp_hack);
  m.impl("view.SymInt", view_symint_decomposition);
}

}}

@@ -84,6 +84,35 @@ static inline at::DeviceType DefaultDevice() {

} // namespace

#ifndef C10_MOBILE
TEST(LazyDynamicOpsTest, NarrowCopy) {
  auto x = torch::rand({5, 10, 10}).to(kLazy);
  const size_t Y_DIM = 3;
  const size_t X_DIM_INDEX = 2;
  auto y = torch::rand({Y_DIM}).to(kLazy);
  auto ly = torch::lazy::TryGetLtcTensor(y);
  auto dim_node = MakeNode<SizeNode>(ly->GetIrValue(), 0);
  auto lmn = c10::make_intrusive<torch::lazy::SymIntNodeImpl>(dim_node);
  auto z = x.narrow_copy_symint(X_DIM_INDEX, 0, lmn->toSymInt());
  AllClose(z.cpu(), x.cpu().narrow_copy(X_DIM_INDEX, 0, Y_DIM));
}

TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
  FLAGS_ltc_enable_symbolic_shapes = true;
  auto xc = torch::rand({10});
  auto x = xc.to(kLazy);
  const size_t Y_DIM = 3;
  const size_t X_DIM_INDEX = 0;
  auto y = torch::rand({Y_DIM}).to(kLazy);
  auto z = x.narrow_copy_symint(X_DIM_INDEX, 0, y.sym_sizes()[0]);
  auto zc = xc.narrow_copy(X_DIM_INDEX, 0, Y_DIM);
  ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc
  // shape inference assumes narrow_copy can copy the whole tensor
  AllClose(z.cpu(), zc);
  FLAGS_ltc_enable_symbolic_shapes = false;
}
#endif

TEST_F(LazyOpsTest, TestScalarTensor) {
  torch::Tensor scalar_tensor = torch::scalar_tensor(
      1., torch::TensorOptions(torch::kFloat).device(DefaultDevice()));

@@ -85,7 +85,8 @@ at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", &custom_add_Tensor);
  m.impl("empty.memory_format", &custom_empty_symint);
  m.impl("empty.memory_format", &custom_empty_memory_format);
  m.impl("empty.SymInt", &custom_empty_symint);
  m.impl("fill_.Scalar", &custom_fill__scalar);
  m.impl("_copy_from", &custom__copy_from);
}

@@ -20,10 +20,15 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
  return Tensor(std::move(tensor_impl));
}

Tensor empty_override(SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
                      c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {
  test_int = 0;
  return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), c10::asIntArrayRefSlow(size));
  return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size);
}

Tensor empty_symint_override(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
                             c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
  return empty_override(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

Tensor& add_out_override(const Tensor & a, const Tensor & b , const Scalar& c, Tensor & out) {

@@ -53,6 +58,7 @@ std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
}

TORCH_LIBRARY_IMPL(aten, ORT, m) {
  m.impl("empty.SymInt", empty_symint_override);
  m.impl("empty.memory_format", empty_override);
  m.impl("add.out", add_out_override);
  m.impl("convolution_overrideable", fake_convolution);

@ -131,18 +131,6 @@ ALLOW_LIST = [
("aten::sum.SymInt", datetime.date(2022, 11, 30)),
("aten::mps_linear", datetime.date(9999, 1, 1)),
("aten::_mps_linear", datetime.date(9999, 1, 1)),
("aten::view_copy.SymInt", datetime.date(2022, 11, 30)),
("aten::view_copy.SymInt_out", datetime.date(2022, 11, 30)),
("aten::expand_copy.SymInt", datetime.date(2022, 11, 30)),
("aten::expand_copy.SymInt_out", datetime.date(2022, 11, 30)),
("aten::expand.SymInt", datetime.date(2022, 11, 30)),
("aten::narrow_copy.SymInt", datetime.date(2022, 11, 30)),
("aten::narrow_copy.SymInt_out", datetime.date(2022, 11, 30)),
("aten::view.SymInt", datetime.date(2022, 11, 30)),
("aten::new_empty.SymInt", datetime.date(2022, 11, 30)),
("aten::new_empty.SymInt_out", datetime.date(2022, 11, 30)),
("aten::zeros.SymInt", datetime.date(2022, 11, 30)),
("aten::zeros.SymInt_out", datetime.date(2022, 11, 30)),
# TODO: FIXME: prims shouldn't be checked
("prims::.*", datetime.date(9999, 1, 1)),
("aten::_amp_foreach_non_finite_check_and_unscale.out", datetime.date(2022, 9, 1)),
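An illustrative sketch of how entries like these are typically consumed by the compatibility checker (the helper below is hypothetical, not the actual test code): a schema whose name matches a listed pattern is exempt from the backward-compatibility check until the listed date.

import datetime
import re

ALLOW_LIST = [
    ("aten::expand.SymInt", datetime.date(2022, 11, 30)),
    ("prims::.*", datetime.date(9999, 1, 1)),
]

def is_exempt(schema_name: str, today: datetime.date) -> bool:
    # Pattern match on the operator name plus an expiry date.
    return any(
        re.match(pattern, schema_name) and today < expires
        for pattern, expires in ALLOW_LIST
    )

print(is_exempt("aten::expand.SymInt", datetime.date(2022, 9, 1)))   # True
print(is_exempt("aten::expand.SymInt", datetime.date(2023, 1, 1)))   # False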
@ -391,7 +391,8 @@ class TestDecomp(TestCase):
if func not in decomposition_table or func in [
torch.ops.aten.detach.default,
# non-deterministic ops
torch.ops.aten.new_empty.default
torch.ops.aten.new_empty.default,
torch.ops.aten.new_empty.SymInt
] or any_unsupported(args, kwargs):
return func(*args, **kwargs)
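A short aside on why the new_empty variants sit in this "non-deterministic ops" skip list (simple illustration, not from the test): their output is uninitialized memory, so comparing a decomposition's result against the reference run is not meaningful.

import torch

a = torch.empty(4)
b = torch.empty(4)
# Both tensors hold whatever bytes happened to be in the allocation, so this
# comparison is typically False and never something a test should rely on.
print(torch.equal(a, b))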
@ -54,7 +54,7 @@ def cat_meta(tensors, dim=0):
return tensors[0].new_empty(new_shape)


@register_meta([aten.narrow_copy.default])
@register_meta([aten.narrow_copy.SymInt])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
shape = []
for i, x in enumerate(a.shape):

@ -65,7 +65,7 @@ def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
return a.new_empty(tuple(shape))


@register_meta([aten.expand.default])
@register_meta([aten.expand.SymInt])
def expand_symint_meta(a, size, implicit=False):
return a.new_empty(size)
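The meta kernel above only has to produce an output with the right shape; a self-contained analogue of that shape rule (hypothetical helper, for illustration only):

def narrow_copy_out_shape(shape, dim, start, length):
    # Same rule as the meta function: only `dim` changes size, to `length`.
    out = list(shape)
    out[dim] = length
    return tuple(out)

print(narrow_copy_out_shape((5, 10, 10), dim=2, start=0, length=3))  # (5, 10, 3)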
@ -293,11 +293,11 @@ class TestPySymInt(TestCase):

shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.shape[0])

shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
torch.ops.aten.expand.SymInt(x, [x.shape[0], x.shape[1], x.shape[2]])

def test_fx_trace_intlist(self):
class CustomModule(torch.nn.Module):
@ -727,6 +727,7 @@ meta_dispatch_skips = {
aten.linalg_pinv.atol_rtol_tensor: {f32, f64},
aten.linalg_pinv.atol_rtol_tensor_out: {f32, f64},
aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
aten.empty.SymInt: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
}

meta_dispatch_device_expected_failures = defaultdict(dict)
@ -15949,7 +15949,6 @@ class TestNNDeviceType(NNTestCase):
torch.cuda.synchronize()
issue_24823_2()

@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/945")
@dtypes(torch.float, torch.double)
@largeTensorTest(lambda self, device, dtype:
# Compute sum of the large tensor sizes:
@ -317,21 +317,24 @@ class TestProfilerTree(TestCase):
ProfilerTree.format(p.profiler, 12),
"""\
aten::zeros
aten::empty
aten::zero_
Top level Annotation
aten::empty
aten::zeros
aten::empty
aten::zero_
Top level Annotation
aten::empty
aten::zeros
aten::zeros
aten::empty
aten::zero_
First Annotation
aten::empty
aten::ones
aten::empty
aten::fill_
aten::zeros
aten::empty
aten::zero_
aten::zeros
aten::empty
aten::zero_
Second Annotation
aten::empty
aten::add

@ -340,8 +343,9 @@ class TestProfilerTree(TestCase):
aten::empty_strided
aten::copy_
aten::zeros
aten::empty
aten::zero_
aten::zeros
aten::empty
aten::zero_
Third Annotation
aten::empty
aten::ones_like

@ -712,7 +716,6 @@ class TestProfilerTree(TestCase):
torch/profiler/profiler.py(...): stop
...""")

@unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
@ProfilerTree.test
def test_profiler_experimental_tree_cuda(self):

@ -810,7 +813,6 @@ class TestProfilerTree(TestCase):
allow_failure=ALLOW_CUDA_FAILURE,
)

@unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
@ProfilerTree.test
def test_profiler_experimental_tree_cuda_with_stream(self):
@ -762,7 +762,7 @@ class TestSymbolicTracing(TestCase):
def forward(self, a_1):
sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
mul = sym_size * 2; sym_size = None
empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
empty = torch.ops.aten.empty.SymInt([mul], device = device(type='cpu'), pin_memory = False); mul = None
sym_size_1 = torch.ops.aten.sym_size(empty, 0)
return empty""")
@ -622,9 +622,12 @@
self: grad * (result + 1)
result: auto_element_wise

# TODO: this derivative is not SymInt safe, need sum_to support
- name: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
self: at::sum_to(grad, self.sym_sizes())
- name: expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)
self: at::sum_to(grad, self.sizes())
result: auto_linear

- name: expand.SymInt(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
self: at::sum_to(grad, c10::asIntArrayRefSlow(self.sym_sizes()))
result: auto_linear

- name: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)

@ -1732,11 +1735,16 @@
# linear
result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim)

# TODO: this derivative is not SymInt safe, need reshape_symint
- name: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
- name: view(Tensor(a) self, int[] size) -> Tensor(a)
self: grad.reshape(self.sizes())
result: auto_linear

- name: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a)
# TODO: add proper double backward for view.SymInt
# by SymIntizing `reshape`
self: grad.reshape(c10::asIntArrayRefSlow(self.sym_sizes()))
result: auto_linear

- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
output_differentiability: [False]
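A numeric sanity check for the expand derivative above (plain PyTorch, not part of the diff): the backward of expand is sum_to(grad, self.sizes()), i.e. the upstream gradient is summed over the broadcast dimensions.

import torch

x = torch.randn(3, requires_grad=True)
y = x.expand(4, 3)              # broadcast along a new leading dimension
y.backward(torch.ones(4, 3))
print(x.grad)                   # every entry is 4.0: ones summed over dim 0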
@ -249,7 +249,6 @@ def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]:
for r in cpp.argument(
a,
method=False,
symint=True,
cpp_no_default_args=set(),
faithful=False,
has_tensor_options=False,

@ -495,7 +494,7 @@ def gen_formals(f: NativeFunction) -> str:
# See Note [Plumbing Keys Through The Dispatcher] for details.
["c10::DispatchKeySet ks"]
+ [
f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}'
f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
for a in f.func.schema_order_arguments()
]
)

@ -515,7 +514,7 @@ def inplace_or_view_method_definition(
):
return None
return METHOD_DEFINITION.substitute(
return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(),
return_type=cpp.returns_type(f.func.returns).cpp_type(),
type_wrapper_name=type_wrapper_name(f),
formals=gen_formals(f),
type_definition_body=emit_inplace_or_view_body(fn),
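A toy illustration of what the symint=True arguments being removed here controlled (hypothetical mapping table; the real logic lives in torchgen/api/cpp.py, changed later in this diff): whether schema-level SymInt types are emitted as the symbolic C++ types or lowered to plain integer types.

def cpp_type_for(schema_type: str, symint: bool) -> str:
    # Rough sketch of the type mapping the flag toggles.
    mapping = {
        ("SymInt", True): "c10::SymInt",
        ("SymInt", False): "int64_t",
        ("SymInt[]", True): "c10::SymIntArrayRef",
        ("SymInt[]", False): "at::IntArrayRef",
    }
    return mapping[(schema_type, symint)]

print(cpp_type_for("SymInt[]", True))   # c10::SymIntArrayRef
print(cpp_type_for("SymInt[]", False))  # at::IntArrayRef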
@ -238,8 +238,6 @@ def gen(
|
|||
tags_yaml_path: str,
|
||||
deprecated_yaml_path: str,
|
||||
template_path: str,
|
||||
*,
|
||||
symint: bool = True,
|
||||
) -> None:
|
||||
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
|
||||
native_functions = parse_native_yaml(
|
||||
|
|
@ -255,7 +253,6 @@ def gen(
|
|||
None,
|
||||
"python_variable_methods.cpp",
|
||||
method=True,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
# NOTE: num_shards here must be synced with gatherTorchFunctions in
|
||||
|
|
@ -269,7 +266,6 @@ def gen(
|
|||
"python_torch_functions.cpp",
|
||||
method=False,
|
||||
num_shards=3,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
create_python_bindings(
|
||||
|
|
@ -279,7 +275,6 @@ def gen(
|
|||
"torch.nn",
|
||||
"python_nn_functions.cpp",
|
||||
method=False,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
create_python_bindings(
|
||||
|
|
@ -289,7 +284,6 @@ def gen(
|
|||
"torch.fft",
|
||||
"python_fft_functions.cpp",
|
||||
method=False,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
create_python_bindings(
|
||||
|
|
@ -299,7 +293,6 @@ def gen(
|
|||
"torch.linalg",
|
||||
"python_linalg_functions.cpp",
|
||||
method=False,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
create_python_bindings(
|
||||
|
|
@ -309,7 +302,6 @@ def gen(
|
|||
"torch.sparse",
|
||||
"python_sparse_functions.cpp",
|
||||
method=False,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
create_python_bindings(
|
||||
|
|
@ -319,7 +311,6 @@ def gen(
|
|||
"torch.special",
|
||||
"python_special_functions.cpp",
|
||||
method=False,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
# Currently, we only use `functions` to generate `return_types` bindings.
|
||||
|
|
@ -363,7 +354,6 @@ def create_python_bindings(
|
|||
filename: str,
|
||||
*,
|
||||
method: bool,
|
||||
symint: bool = True,
|
||||
) -> None:
|
||||
"""Generates Python bindings to ATen functions"""
|
||||
py_methods: List[str] = []
|
||||
|
|
@ -375,9 +365,7 @@ def create_python_bindings(
|
|||
|
||||
for name in sorted(grouped.keys(), key=lambda x: str(x)):
|
||||
overloads = grouped[name]
|
||||
py_methods.append(
|
||||
method_impl(name, module, overloads, method=method, symint=symint)
|
||||
)
|
||||
py_methods.append(method_impl(name, module, overloads, method=method))
|
||||
py_method_defs.append(method_def(name, module, overloads, method=method))
|
||||
py_forwards.extend(forward_decls(name, overloads, method=method))
|
||||
ops_headers.append(f"#include <ATen/ops/{name.base}.h>")
|
||||
|
|
@ -440,7 +428,6 @@ def create_python_bindings_sharded(
|
|||
*,
|
||||
method: bool,
|
||||
num_shards: int,
|
||||
symint: bool = True,
|
||||
) -> None:
|
||||
"""Generates Python bindings to ATen functions"""
|
||||
grouped = group_filter_overloads(pairs, pred)
|
||||
|
|
@ -457,9 +444,7 @@ def create_python_bindings_sharded(
|
|||
return {
|
||||
"ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
|
||||
"py_forwards": list(forward_decls(name, fn_pairs, method=method)),
|
||||
"py_methods": [
|
||||
method_impl(name, module, fn_pairs, method=method, symint=symint)
|
||||
],
|
||||
"py_methods": [method_impl(name, module, fn_pairs, method=method)],
|
||||
"py_method_defs": [method_def(name, module, fn_pairs, method=method)],
|
||||
}
|
||||
|
||||
|
|
@ -788,7 +773,6 @@ def method_impl(
|
|||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
*,
|
||||
method: bool,
|
||||
symint: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a python binding for all overloads of an op.
|
||||
|
|
@ -807,18 +791,14 @@ def method_impl(
|
|||
|
||||
traceable = "true" if all(should_trace(o.function) for o in overloads) else "false"
|
||||
|
||||
grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(
|
||||
overloads, symint=symint
|
||||
)
|
||||
grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads)
|
||||
is_singleton = len(grouped_overloads) == 1
|
||||
signatures: List[str] = []
|
||||
dispatch: List[str] = []
|
||||
for overload_index, overload in enumerate(grouped_overloads):
|
||||
signature = overload.signature.signature_str(symint=symint)
|
||||
signature = overload.signature.signature_str()
|
||||
signatures.append(f"{cpp_string(str(signature))},")
|
||||
dispatch_body = emit_dispatch_case(
|
||||
overload, namedtuple_typenames, symint=symint
|
||||
)
|
||||
dispatch_body = emit_dispatch_case(overload, namedtuple_typenames)
|
||||
dispatch.append(
|
||||
PY_VARIABLE_CASE.substitute(
|
||||
overload_index=overload_index, body=dispatch_body
|
||||
|
|
@ -902,8 +882,6 @@ if (_r.isNone(${out_idx})) {
|
|||
def emit_dispatch_case(
|
||||
overload: PythonSignatureGroup,
|
||||
namedtuple_typenames: Dict[str, str],
|
||||
*,
|
||||
symint: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Emit dispatch code for a single parsed signature. This corresponds to either
|
||||
|
|
@ -916,19 +894,18 @@ def emit_dispatch_case(
|
|||
return PY_VARIABLE_OUT.substitute(
|
||||
out_idx=overload.signature.output_idx(),
|
||||
call_dispatch=emit_single_dispatch(
|
||||
overload.signature, overload.base, namedtuple_typenames, symint=symint
|
||||
overload.signature, overload.base, namedtuple_typenames
|
||||
),
|
||||
call_dispatch_out=emit_single_dispatch(
|
||||
overload.signature,
|
||||
overload.outplace,
|
||||
namedtuple_typenames,
|
||||
symint=symint,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# no-output version only
|
||||
return emit_single_dispatch(
|
||||
overload.signature, overload.base, namedtuple_typenames, symint=symint
|
||||
overload.signature, overload.base, namedtuple_typenames
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1010,14 +987,14 @@ def method_def(
|
|||
|
||||
|
||||
def group_overloads(
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
) -> Sequence[PythonSignatureGroup]:
|
||||
bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
|
||||
outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
|
||||
|
||||
# first group by signature ignoring out arguments
|
||||
for overload in overloads:
|
||||
sig = overload.signature.signature_str(skip_outputs=True, symint=symint)
|
||||
sig = overload.signature.signature_str(skip_outputs=True)
|
||||
if overload.function.func.is_out_fn():
|
||||
if sig in outplaces:
|
||||
raise RuntimeError(
|
||||
|
|
@ -1044,11 +1021,9 @@ def group_overloads(
|
|||
and not overload.signature.deprecated
|
||||
):
|
||||
candidates.append(
|
||||
overload.signature.signature_str(
|
||||
skip_outputs=True, symint=symint
|
||||
)
|
||||
overload.signature.signature_str(skip_outputs=True)
|
||||
)
|
||||
out_sig = out.signature.signature_str(symint=symint)
|
||||
out_sig = out.signature.signature_str()
|
||||
raise RuntimeError(
|
||||
f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
|
||||
f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
|
||||
|
|
@ -1063,7 +1038,7 @@ def group_overloads(
|
|||
)
|
||||
for sig, base in bases.items()
|
||||
]
|
||||
return sort_overloads(grouped, symint=symint)
|
||||
return sort_overloads(grouped)
|
||||
|
||||
|
||||
# This function declares a partial order on declarations, and sorts them according
|
||||
|
|
@ -1112,7 +1087,7 @@ def group_overloads(
|
|||
|
||||
|
||||
def sort_overloads(
|
||||
grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True
|
||||
grouped_overloads: Sequence[PythonSignatureGroup],
|
||||
) -> Sequence[PythonSignatureGroup]:
|
||||
# NB: Smaller here means lower priority
|
||||
|
||||
|
|
@ -1157,7 +1132,7 @@ def sort_overloads(
|
|||
|
||||
# First sort by signature
|
||||
grouped_overloads = sorted(
|
||||
grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint)
|
||||
grouped_overloads, key=lambda x: x.signature.signature_str()
|
||||
)
|
||||
|
||||
# Construct the relation graph
|
||||
|
|
@ -1195,11 +1170,7 @@ def sort_overloads(
|
|||
|
||||
|
||||
def emit_single_dispatch(
|
||||
ps: PythonSignature,
|
||||
f: NativeFunction,
|
||||
namedtuple_typenames: Dict[str, str],
|
||||
*,
|
||||
symint: bool = True,
|
||||
ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str]
|
||||
) -> str:
|
||||
"""
|
||||
Emit dispatch code for a single native function.
|
||||
|
|
@ -1218,10 +1189,7 @@ def emit_single_dispatch(
|
|||
# dispatch lambda signature
|
||||
name = cpp.name(f.func)
|
||||
lambda_formals = ", ".join(
|
||||
map(
|
||||
lambda a: f"{a.type_str} {a.name}",
|
||||
dispatch_lambda_args(ps, f, symint=symint),
|
||||
)
|
||||
map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f))
|
||||
)
|
||||
lambda_return = dispatch_lambda_return_str(f)
|
||||
|
||||
|
|
@ -1230,8 +1198,8 @@ def emit_single_dispatch(
|
|||
dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))
|
||||
|
||||
# from arg parser outputs to dispatch lambda arguments
|
||||
parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
|
||||
lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint)
|
||||
parser_outputs = arg_parser_output_exprs(ps, f)
|
||||
lambda_arg_exprs = dispatch_lambda_exprs(ps, f)
|
||||
inits = "\n".join(lambda_arg_exprs.inits)
|
||||
lambda_args = ", ".join(lambda_arg_exprs.exprs)
|
||||
|
||||
|
|
|
|||
|
|
@ -383,7 +383,7 @@ def declare_returned_variables(f: NativeFunction) -> str:
|
|||
return ""
|
||||
if len(f.func.returns) == 1:
|
||||
return ""
|
||||
types = [cpp.return_type(r, symint=True) for r in f.func.returns]
|
||||
types = map(cpp.return_type, f.func.returns)
|
||||
names = cpp.return_names(f)
|
||||
return "\n".join(f"{type.cpp_type()} {name};" for type, name in zip(types, names))
|
||||
|
||||
|
|
@ -483,13 +483,13 @@ def method_definition(f: NativeFunction) -> str:
|
|||
# See Note [Plumbing Keys Through The Dispatcher] for details.
|
||||
["c10::DispatchKeySet ks"]
|
||||
+ [
|
||||
f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}'
|
||||
f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
|
||||
for a in f.func.schema_order_arguments()
|
||||
]
|
||||
)
|
||||
|
||||
return METHOD_DEFINITION.substitute(
|
||||
return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(),
|
||||
return_type=cpp.returns_type(f.func.returns).cpp_type(),
|
||||
type_wrapper_name=type_wrapper_name(f),
|
||||
formals=formals,
|
||||
type_definition_body=emit_trace_body(f),
|
||||
|
|
|
|||
|
|
@ -76,39 +76,31 @@ def process_function(f: NativeFunction) -> Optional[str]:
|
|||
if Variant.function not in f.variants or not is_factory:
|
||||
return None
|
||||
|
||||
cpp_sigs = CppSignatureGroup.from_native_function(f, method=False)
|
||||
sigs = [cpp_sigs.signature]
|
||||
if cpp_sigs.symint_signature is not None:
|
||||
sigs.append(cpp_sigs.symint_signature)
|
||||
r = ""
|
||||
for sig in sigs:
|
||||
formals: List[str] = []
|
||||
exprs: List[str] = []
|
||||
requires_grad = "false"
|
||||
for arg in sig.arguments():
|
||||
qualified_type = fully_qualified_type(arg.type)
|
||||
if arg.default:
|
||||
formals.append(f"{qualified_type} {arg.name} = {arg.default}")
|
||||
else:
|
||||
formals.append(f"{qualified_type} {arg.name}")
|
||||
sig = CppSignatureGroup.from_native_function(f, method=False).signature
|
||||
formals: List[str] = []
|
||||
exprs: List[str] = []
|
||||
requires_grad = "false"
|
||||
for arg in sig.arguments():
|
||||
qualified_type = fully_qualified_type(arg.type)
|
||||
if arg.default:
|
||||
formals.append(f"{qualified_type} {arg.name} = {arg.default}")
|
||||
else:
|
||||
formals.append(f"{qualified_type} {arg.name}")
|
||||
|
||||
if isinstance(arg.argument, TensorOptionsArguments):
|
||||
# note: we remove the requires_grad setting from the TensorOptions because
|
||||
# it is ignored anyways (and we actually have an assertion that it isn't set
|
||||
# which would fail otherwise). We handle requires_grad explicitly here
|
||||
# instead of passing it through to the kernel.
|
||||
exprs.append(
|
||||
f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)"
|
||||
)
|
||||
# Manually set the requires_grad bit on the result tensor.
|
||||
requires_grad = f"{arg.name}.requires_grad()"
|
||||
else:
|
||||
exprs.append(arg.name)
|
||||
if isinstance(arg.argument, TensorOptionsArguments):
|
||||
# note: we remove the requires_grad setting from the TensorOptions because
|
||||
# it is ignored anyways (and we actually have an assertion that it isn't set
|
||||
# which would fail otherwise). We handle requires_grad explicitly here
|
||||
# instead of passing it through to the kernel.
|
||||
exprs.append(f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)")
|
||||
# Manually set the requires_grad bit on the result tensor.
|
||||
requires_grad = f"{arg.name}.requires_grad()"
|
||||
else:
|
||||
exprs.append(arg.name)
|
||||
|
||||
r += f"""\
|
||||
inline at::Tensor {sig.name()}({', '.join(formals)}) {{
|
||||
return f"""\
|
||||
inline at::Tensor {name}({', '.join(formals)}) {{
|
||||
at::AutoDispatchBelowADInplaceOrView guard;
|
||||
return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
|
||||
return autograd::make_variable(at::{name}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
|
||||
}}
|
||||
"""
|
||||
return r
|
||||
|
|
|
|||
|
|
@ -849,9 +849,7 @@ def gen_variable_type_func(
|
|||
if not fn.info:
|
||||
key = "Default"
|
||||
type_definition = METHOD_DEFINITION.substitute(
|
||||
return_type=cpp.returns_type(
|
||||
f.func.returns, symint=True
|
||||
).cpp_type(),
|
||||
return_type=cpp.returns_type(f.func.returns).cpp_type(),
|
||||
type_wrapper_name=type_wrapper_name(f, key),
|
||||
type_definition_body=emit_body(fn, key),
|
||||
formals=formals,
|
||||
|
|
@ -862,9 +860,7 @@ def gen_variable_type_func(
|
|||
else:
|
||||
for key, _ in fn.info.items():
|
||||
type_definition = METHOD_DEFINITION.substitute(
|
||||
return_type=cpp.returns_type(
|
||||
f.func.returns, symint=True
|
||||
).cpp_type(),
|
||||
return_type=cpp.returns_type(f.func.returns).cpp_type(),
|
||||
type_wrapper_name=type_wrapper_name(f, key),
|
||||
type_definition_body=emit_body(fn, key),
|
||||
formals=formals,
|
||||
|
|
@ -917,7 +913,7 @@ def emit_body(
|
|||
# TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove.
|
||||
# NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are
|
||||
# not handled properly as they are irrelevant for this codegen.
|
||||
cpp_type = cpp.argument_type(a, binds=a.name, symint=True).cpp_type()
|
||||
cpp_type = cpp.argument_type(a, binds=a.name).cpp_type()
|
||||
|
||||
if not is_differentiable(a.name, a.type, info):
|
||||
return None
|
||||
|
|
@ -1208,7 +1204,6 @@ def emit_body(
|
|||
api_name=cpp.name(
|
||||
f.func,
|
||||
faithful_name_for_out_overloads=True,
|
||||
symint_overload=f.func.has_symint(),
|
||||
),
|
||||
unpacked_args=[dispatch_key_set] + list(unpacked_args),
|
||||
)
|
||||
|
|
@ -1290,7 +1285,7 @@ def emit_body(
|
|||
for i, (ret, ret_name) in enumerate(
|
||||
zip(f.func.returns, cpp.return_names(f))
|
||||
):
|
||||
noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref()
|
||||
noref_cpp_type = cpp.return_type(ret).remove_const_ref()
|
||||
if noref_cpp_type == BaseCType(tensorT):
|
||||
if aliased_arg_name is not None:
|
||||
assert (
|
||||
|
|
|
|||
|
|
@ -184,9 +184,7 @@ def create_derivative(
|
|||
]
|
||||
|
||||
return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))
|
||||
return_types = tuple(
|
||||
cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns
|
||||
)
|
||||
return_types = tuple(cpp.return_type(r).remove_const_ref() for r in f.func.returns)
|
||||
|
||||
named_returns = [
|
||||
NamedCType(name, type) for name, type in zip(return_names, return_types)
|
||||
|
|
@ -377,7 +375,7 @@ def postprocess_forward_derivatives(
|
|||
new_args.append(arg_name)
|
||||
|
||||
# TODO we are trolling
|
||||
if f.func.has_symint():
|
||||
if f.func.is_symint_fn():
|
||||
defn_name += "_symint"
|
||||
|
||||
# Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
|
||||
|
|
|
|||
|
|
@ -314,7 +314,6 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
|
|||
dispatch_key=k,
|
||||
use_out_as_primary=True,
|
||||
external=False,
|
||||
symint=False,
|
||||
device_guard=False,
|
||||
index=backend_indices[k],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -612,7 +612,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
torch._C._remove_meta_from_tls_dispatch_include()
|
||||
|
||||
if has_symbolic_sizes:
|
||||
constructors = [aten.empty.memory_format]
|
||||
constructors = [aten.empty.SymInt]
|
||||
if func not in constructors:
|
||||
raise RuntimeError(
|
||||
f"{func} - couldn't find symbolic meta function/decomposition"
|
||||
|
|
|
|||
|
|
@ -843,7 +843,8 @@ RegisterOperators reg_expand_copy({
|
|||
"alias ops, should be restored after fusion pass!");
|
||||
IValue self, size, implicit;
|
||||
pop(stack, self, size, implicit);
|
||||
push(stack, self.toTensor().expand(size.toIntVector()));
|
||||
push(
|
||||
stack, at::native::expand(self.toTensor(), size.toIntVector()));
|
||||
};
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
|
|
|
|||
|
|
@ -411,8 +411,7 @@ SchemaTypeParser::parseFakeAndRealType() {
|
|||
real_value = parseBaseType();
|
||||
if (real_value->kind() == ScalarTypeType::Kind ||
|
||||
real_value->kind() == MemoryFormatType::Kind ||
|
||||
real_value->kind() == LayoutType::Kind ||
|
||||
real_value->kind() == SymIntType::Kind) {
|
||||
real_value->kind() == LayoutType::Kind) {
|
||||
fake_value = c10::TypeFactory::get<IntType>();
|
||||
} else {
|
||||
fake_value = real_value;
|
||||
|
|
|
|||
|
|
@ -80,57 +80,31 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
|||
auto c_obj = py::cast<std::complex<double>>(obj.ptr());
|
||||
return static_cast<c10::complex<double>>(c_obj);
|
||||
}
|
||||
case TypeKind::IntType:
|
||||
// TODO: Properly fake this type
|
||||
if (THPQScheme_Check(obj.ptr())) {
|
||||
auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
|
||||
return static_cast<uint8_t>(qscheme->qscheme);
|
||||
}
|
||||
// For backwards compatibility
|
||||
if (THPDtype_Check(obj.ptr())) {
|
||||
auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
|
||||
return static_cast<int64_t>(dtype->scalar_type);
|
||||
}
|
||||
if (THPQScheme_Check(obj.ptr())) {
|
||||
auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
|
||||
return static_cast<uint8_t>(qscheme->qscheme);
|
||||
}
|
||||
if (THPLayout_Check(obj.ptr())) {
|
||||
auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
|
||||
return static_cast<int8_t>(layout->layout);
|
||||
}
|
||||
if (THPMemoryFormat_Check(obj.ptr())) {
|
||||
auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
|
||||
return static_cast<int8_t>(memory_format->memory_format);
|
||||
}
|
||||
return py::cast<int64_t>(obj);
|
||||
case TypeKind::LayoutType: {
|
||||
if (THPLayout_Check(obj.ptr())) {
|
||||
auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
|
||||
return static_cast<int8_t>(layout->layout);
|
||||
}
|
||||
// For backwards compatibility
|
||||
return py::cast<int64_t>(obj);
|
||||
}
|
||||
case TypeKind::ScalarTypeType: {
|
||||
if (THPDtype_Check(obj.ptr())) {
|
||||
auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
|
||||
return static_cast<int64_t>(dtype->scalar_type);
|
||||
}
|
||||
// For backwards compatibility
|
||||
return py::cast<int64_t>(obj);
|
||||
}
|
||||
case TypeKind::MemoryFormatType: {
|
||||
if (THPMemoryFormat_Check(obj.ptr())) {
|
||||
auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
|
||||
return static_cast<int8_t>(memory_format->memory_format);
|
||||
}
|
||||
// For backwards compatibility
|
||||
return py::cast<int64_t>(obj);
|
||||
}
|
||||
case TypeKind::SymIntType:
|
||||
if (torch::is_symint_node(obj.ptr())) {
|
||||
return py::cast<c10::SymInt>(obj);
|
||||
return py::cast<c10::SymInt>(obj);
|
||||
case TypeKind::IntType:
|
||||
// NB: Typically, these switches are completely dead, because
|
||||
// Argument::type() will always report IntType for these types.
|
||||
// So this is a bit overly permissive: we'll accept a dtype
|
||||
// passed to an int argument, for example.
|
||||
case TypeKind::LayoutType:
|
||||
case TypeKind::ScalarTypeType:
|
||||
case TypeKind::MemoryFormatType:
|
||||
if (THPDtype_Check(obj.ptr())) {
|
||||
auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
|
||||
return static_cast<int64_t>(dtype->scalar_type);
|
||||
}
|
||||
if (THPQScheme_Check(obj.ptr())) {
|
||||
auto qscheme = reinterpret_cast<THPQScheme*>(obj.ptr());
|
||||
return static_cast<uint8_t>(qscheme->qscheme);
|
||||
}
|
||||
if (THPLayout_Check(obj.ptr())) {
|
||||
auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
|
||||
return static_cast<int8_t>(layout->layout);
|
||||
}
|
||||
if (THPMemoryFormat_Check(obj.ptr())) {
|
||||
auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
|
||||
return static_cast<int8_t>(memory_format->memory_format);
|
||||
}
|
||||
return py::cast<int64_t>(obj);
|
||||
case TypeKind::NoneType:
|
||||
|
|
|
|||
|
|
@ -646,7 +646,7 @@ inline IValue argumentToIValue(
|
|||
py::handle object) {
|
||||
const auto& argument = schema.arguments().at(argumentPosition);
|
||||
try {
|
||||
return toIValue(object, argument.real_type(), argument.N());
|
||||
return toIValue(object, argument.type(), argument.N());
|
||||
} catch (const py::cast_error& error) {
|
||||
throw schema_match_error(c10::str(
|
||||
schema.formatTypeMismatchMsg(
|
||||
|
|
|
|||
|
|
@ -39,8 +39,6 @@
|
|||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <ATen/CompositeExplicitAutogradFunctions.h>
|
||||
|
||||
C10_DEFINE_bool(
|
||||
static_runtime_enable_fast_math,
|
||||
true,
|
||||
|
|
@ -77,7 +75,7 @@ void repeat_out(at::Tensor& result, const Tensor& self, IntArrayRef repeats) {
|
|||
return;
|
||||
}
|
||||
|
||||
Tensor xtensor = at::compositeexplicitautograd::expand(self, padded_size);
|
||||
Tensor xtensor = at::native::expand(self, padded_size);
|
||||
Tensor urtensor = at::native::alias(result);
|
||||
for (const auto i : c10::irange(xtensor.dim())) {
|
||||
// can't unfold with step 0, so make sure step is at least 1
|
||||
|
|
@ -2528,13 +2526,12 @@ REGISTER_OPERATOR_FUNCTOR(aten::zeros, aten_zeros, [](Node* n) -> SROperator {
|
|||
const auto dtype = p_node->Input(1).toOptional<c10::ScalarType>();
|
||||
const auto layout = p_node->Input(2).toOptional<c10::Layout>();
|
||||
if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) {
|
||||
p_node->Output(0) = at::compositeexplicitautograd::zeros(
|
||||
size, dtype, layout, c10::nullopt, c10::nullopt);
|
||||
p_node->Output(0) = at::native::zeros(size, dtype, layout);
|
||||
return;
|
||||
}
|
||||
auto& out_t = p_node->Output(0).toTensor();
|
||||
fastResizeToZero(out_t);
|
||||
at::compositeexplicitautograd::zeros_out(out_t, size);
|
||||
at::native::zeros_out(size, out_t);
|
||||
};
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
#include <torch/csrc/lazy/backend/backend_interface.h>
|
||||
#include <torch/csrc/lazy/core/config.h>
|
||||
#include <torch/csrc/lazy/core/ir.h>
|
||||
#include <torch/csrc/lazy/core/tensor.h>
|
||||
#include <torch/csrc/lazy/core/trie.h>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -267,21 +266,5 @@ static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {
|
|||
return getIrBuilder()->MakeSizeDiv(a, b);
|
||||
}
|
||||
|
||||
inline Value GetSymIntValue(c10::SymInt a) {
|
||||
return Value(
|
||||
dynamic_cast<torch::lazy::SymIntNodeImpl*>(a.toSymIntNodeImpl().get())
|
||||
->node_,
|
||||
0);
|
||||
}
|
||||
|
||||
// TODO: this should return Value
|
||||
inline std::vector<int64_t> GetSymIntArrayRefValue(c10::SymIntArrayRef arr) {
|
||||
std::vector<int64_t> r;
|
||||
for (const auto& a : arr) {
|
||||
r.emplace_back(a.expect_int());
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -269,15 +269,30 @@ at::Tensor LazyNativeFunctions::_to_copy(
|
|||
}
|
||||
};
|
||||
|
||||
at::Tensor LazyNativeFunctions::empty(
|
||||
at::SymIntArrayRef sym_size,
|
||||
at::Tensor LazyNativeFunctions::empty_symint(
|
||||
c10::SymIntArrayRef size,
|
||||
c10::optional<at::ScalarType> dtype,
|
||||
c10::optional<at::Layout> layout,
|
||||
c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<at::MemoryFormat> memory_format) {
|
||||
// TODO: support SymIntNodes as well
|
||||
return empty(
|
||||
c10::asIntArrayRefSlow(size),
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory,
|
||||
memory_format);
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::empty(
|
||||
at::IntArrayRef size,
|
||||
c10::optional<at::ScalarType> dtype,
|
||||
c10::optional<at::Layout> layout,
|
||||
c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<at::MemoryFormat> memory_format) {
|
||||
// TODO: support this directly
|
||||
auto size = c10::asIntArrayRefSlow(sym_size);
|
||||
const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType();
|
||||
at::TensorOptions options = at::TensorOptions()
|
||||
.device(c10::Device(device_type))
|
||||
|
|
@ -307,13 +322,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
|
|||
c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
at::Tensor t = empty(
|
||||
c10::SymIntArrayRef::fromIntArrayRef(size),
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory,
|
||||
c10::nullopt);
|
||||
at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt);
|
||||
return t.as_strided(size, stride, /*storage_offset=*/0);
|
||||
}
|
||||
|
||||
|
|
@ -409,8 +418,7 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
|
|||
const at::Tensor& self,
|
||||
at::IntArrayRef size) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
return LazyNativeFunctions::view_copy(
|
||||
self, c10::SymIntArrayRef::fromIntArrayRef(size));
|
||||
return LazyNativeFunctions::view_copy(self, size);
|
||||
}
|
||||
|
||||
// This is needed by the torch.tensor constructor.
|
||||
|
|
@ -452,8 +460,8 @@ at::Tensor LazyNativeFunctions::new_empty_strided(
|
|||
at::Tensor LazyNativeFunctions::narrow_copy(
|
||||
const at::Tensor& self,
|
||||
int64_t dim,
|
||||
c10::SymInt start,
|
||||
c10::SymInt length) {
|
||||
int64_t start,
|
||||
int64_t length) {
|
||||
return at::functionalization::functionalize_aten_op<ATEN_OP(
|
||||
narrow_copy)>::call(self, dim, start, length);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -538,9 +538,7 @@ def gen_differentiable_outputs(
|
|||
info = fn.info[key] if fn.info else None
|
||||
outputs: List[DifferentiableOutput] = [
|
||||
DifferentiableOutput(
|
||||
name=name,
|
||||
type=ret.type,
|
||||
cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
|
||||
name=name, type=ret.type, cpp_type=cpp.return_type(ret).cpp_type()
|
||||
)
|
||||
for name, ret in zip(cpp.return_names(f), f.func.returns)
|
||||
]
|
||||
|
|
|
|||
|
|
@ -64,14 +64,9 @@ from torchgen.utils import assert_never
|
|||
# collisions, but functions are fair game to collide
|
||||
|
||||
|
||||
def name(
|
||||
func: FunctionSchema,
|
||||
*,
|
||||
faithful_name_for_out_overloads: bool = False,
|
||||
symint_overload: bool = False,
|
||||
) -> str:
|
||||
def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
|
||||
name = str(func.name.name)
|
||||
if symint_overload:
|
||||
if func.is_symint_fn():
|
||||
name += "_symint"
|
||||
if func.is_out_fn():
|
||||
if faithful_name_for_out_overloads:
|
||||
|
|
@ -86,20 +81,11 @@ def name(
|
|||
# types look the same no matter if they are argument types or return
|
||||
# types. Returns None if the type in question is not a value type.
|
||||
def valuetype_type(
|
||||
t: Type,
|
||||
*,
|
||||
binds: ArgName,
|
||||
remove_non_owning_ref_types: bool = False,
|
||||
symint: bool = False,
|
||||
t: Type, *, binds: ArgName, remove_non_owning_ref_types: bool = False
|
||||
) -> Optional[NamedCType]:
|
||||
if isinstance(t, BaseType):
|
||||
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
|
||||
return None
|
||||
elif str(t) == "SymInt":
|
||||
if symint:
|
||||
return NamedCType(binds, BaseCType(SymIntT))
|
||||
else:
|
||||
return NamedCType(binds, BaseCType(longT))
|
||||
if remove_non_owning_ref_types:
|
||||
if t.name == BaseTy.str:
|
||||
raise AssertionError(
|
||||
|
|
@ -108,7 +94,7 @@ def valuetype_type(
|
|||
# All other BaseType currently map directly to BaseCppTypes.
|
||||
return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
|
||||
elif isinstance(t, OptionalType):
|
||||
elem = valuetype_type(t.elem, binds=binds, symint=symint)
|
||||
elem = valuetype_type(t.elem, binds=binds)
|
||||
if elem is None:
|
||||
return None
|
||||
return NamedCType(binds, OptionalCType(elem.type))
|
||||
|
|
@ -127,19 +113,11 @@ def valuetype_type(
|
|||
# For example, we'll return std::vector<int> instead of IntArrayRef.
|
||||
# See Note [translation from C++ reference to value types]
|
||||
def argumenttype_type(
|
||||
t: Type,
|
||||
*,
|
||||
mutable: bool,
|
||||
binds: ArgName,
|
||||
remove_non_owning_ref_types: bool = False,
|
||||
symint: bool = False,
|
||||
t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
|
||||
) -> NamedCType:
|
||||
# If it's a value type, do the value type translation
|
||||
r = valuetype_type(
|
||||
t,
|
||||
binds=binds,
|
||||
symint=symint,
|
||||
remove_non_owning_ref_types=remove_non_owning_ref_types,
|
||||
t, binds=binds, remove_non_owning_ref_types=remove_non_owning_ref_types
|
||||
)
|
||||
if r is not None:
|
||||
return r
|
||||
|
|
@ -168,7 +146,7 @@ def argumenttype_type(
|
|||
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
|
||||
elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
|
||||
return NamedCType(binds, BaseCType(optionalIntArrayRefT))
|
||||
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
|
||||
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
|
||||
return NamedCType(binds, OptionalCType(elem.type))
|
||||
elif isinstance(t, ListType):
|
||||
# TODO: remove these special cases, ArrayRef fallthrough works fine
|
||||
|
|
@ -179,16 +157,10 @@ def argumenttype_type(
|
|||
return NamedCType(binds, BaseCType(intArrayRefT))
|
||||
if str(t.elem) == "SymInt":
|
||||
if remove_non_owning_ref_types:
|
||||
if symint:
|
||||
return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
|
||||
else:
|
||||
return NamedCType(binds, VectorCType(BaseCType(longT)))
|
||||
return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
|
||||
else:
|
||||
if symint:
|
||||
return NamedCType(binds, BaseCType(symIntArrayRefT))
|
||||
else:
|
||||
return NamedCType(binds, BaseCType(intArrayRefT))
|
||||
if str(t.elem) == "Tensor":
|
||||
return NamedCType(binds, BaseCType(symIntArrayRefT))
|
||||
elif str(t.elem) == "Tensor":
|
||||
return NamedCType(binds, BaseCType(tensorListT))
|
||||
elif str(t.elem) == "Scalar":
|
||||
return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
|
||||
|
|
@ -198,15 +170,15 @@ def argumenttype_type(
|
|||
return NamedCType(
|
||||
binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
|
||||
)
|
||||
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
|
||||
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
|
||||
return NamedCType(binds, ArrayRefCType(elem.type))
|
||||
else:
|
||||
raise AssertionError(f"unrecognized type {repr(t)}")
|
||||
|
||||
|
||||
# Translate a JIT argument into its C++ type
|
||||
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
|
||||
return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)
|
||||
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
|
||||
return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
|
||||
|
||||
|
||||
# Translation of a (non-multi) return type from JIT to C++
|
||||
|
|
@ -214,9 +186,9 @@ def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> Named
|
|||
# This is mostly because of the mismatch between return types and return names.
|
||||
# e.g. a function with a return type of 'void' has 0 return names,
|
||||
# and a function with a return type of 'std::tuple' has >1 return name.
|
||||
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
|
||||
def returntype_type(t: Type, *, mutable: bool) -> CType:
|
||||
# placeholder is ignored
|
||||
r = valuetype_type(t, binds="__placeholder__", symint=symint)
|
||||
r = valuetype_type(t, binds="__placeholder__")
|
||||
if r is not None:
|
||||
return r.type
|
||||
|
||||
|
|
@ -239,7 +211,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
|
|||
assert (
|
||||
not mutable
|
||||
), "Native functions should never return a mutable tensor list. They should return void."
|
||||
elem = returntype_type(t.elem, mutable=False, symint=symint)
|
||||
elem = returntype_type(t.elem, mutable=False)
|
||||
assert t.size is None, f"fixed size list returns not supported: {t}"
|
||||
return VectorCType(elem)
|
||||
|
||||
|
|
@ -247,18 +219,18 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
|
|||
|
||||
|
||||
# Translation of a single return to its C++ type
|
||||
def return_type(r: Return, *, symint: bool = False) -> CType:
|
||||
return returntype_type(r.type, mutable=r.is_write, symint=symint)
|
||||
def return_type(r: Return) -> CType:
|
||||
return returntype_type(r.type, mutable=r.is_write)
|
||||
|
||||
|
||||
# Translation of a full (possibly multi) return from JIT to its C++ type
|
||||
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
|
||||
def returns_type(rs: Sequence[Return]) -> CType:
|
||||
if len(rs) == 0:
|
||||
return BaseCType(voidT)
|
||||
elif len(rs) == 1:
|
||||
return return_type(rs[0], symint=symint)
|
||||
return return_type(rs[0])
|
||||
else:
|
||||
return TupleCType([return_type(r, symint=symint) for r in rs])
|
||||
return TupleCType([return_type(r) for r in rs])
|
||||
|
||||
|
||||
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
|
||||
|
|
@ -353,7 +325,6 @@ def argument(
|
|||
cpp_no_default_args: Set[str],
|
||||
method: bool,
|
||||
faithful: bool,
|
||||
symint: bool = False,
|
||||
has_tensor_options: bool,
|
||||
) -> List[Binding]:
|
||||
def sub_argument(
|
||||
|
|
@ -364,7 +335,6 @@ def argument(
|
|||
cpp_no_default_args=cpp_no_default_args,
|
||||
method=method,
|
||||
faithful=faithful,
|
||||
symint=symint,
|
||||
has_tensor_options=has_tensor_options,
|
||||
)
|
||||
|
||||
|
|
@ -379,7 +349,7 @@ def argument(
|
|||
default = default_expr(a.default, a.type)
|
||||
return [
|
||||
Binding(
|
||||
nctype=argument_type(a, binds=binds, symint=symint),
|
||||
nctype=argument_type(a, binds=binds),
|
||||
name=a.name,
|
||||
default=default,
|
||||
argument=a,
|
||||
|
|
@ -420,12 +390,7 @@ def argument(
|
|||
|
||||
|
||||
def arguments(
|
||||
arguments: Arguments,
|
||||
*,
|
||||
faithful: bool,
|
||||
symint: bool = False,
|
||||
method: bool,
|
||||
cpp_no_default_args: Set[str],
|
||||
arguments: Arguments, *, faithful: bool, method: bool, cpp_no_default_args: Set[str]
|
||||
) -> List[Binding]:
|
||||
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
|
||||
if faithful:
|
||||
|
|
@ -440,7 +405,6 @@ def arguments(
|
|||
for r in argument(
|
||||
a,
|
||||
faithful=faithful,
|
||||
symint=symint,
|
||||
method=method,
|
||||
has_tensor_options=arguments.tensor_options is not None,
|
||||
cpp_no_default_args=cpp_no_default_args,
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ def argumenttype_type(
|
|||
t,
|
||||
mutable=mutable,
|
||||
binds=binds,
|
||||
symint=True,
|
||||
remove_non_owning_ref_types=remove_non_owning_ref_types,
|
||||
)
|
||||
|
||||
|
|
@ -63,7 +62,7 @@ def argument_type(
|
|||
|
||||
def returns_type(rs: Sequence[Return]) -> CType:
|
||||
# At present, there is no difference. But there could be!
|
||||
return cpp.returns_type(rs, symint=True)
|
||||
return cpp.returns_type(rs)
|
||||
|
||||
|
||||
def jit_arguments(func: FunctionSchema) -> List[Argument]:
|
||||
|
|
|
|||
|
|
@ -37,14 +37,6 @@ from torchgen.model import (
|
|||
_valueT = None
|
||||
|
||||
|
||||
# A ValueT is an IR type which represents the computation of a Tensor. In other
|
||||
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
|
||||
# tensor internally tracks a ValueT representing the IR node that would have
|
||||
# actually produced the value of this tensor for real.
|
||||
#
|
||||
# This is configurable because different lazy tensor backends (LTC vs XLA) will
|
||||
# have different IR representations. (Though, arguably, after unification they
|
||||
# shouldn't!)
|
||||
def getValueT() -> BaseCppType:
|
||||
global _valueT
|
||||
if not _valueT:
|
||||
|
|
@ -121,27 +113,12 @@ def process_ir_type(
|
|||
elif str(typ.elem) == "Tensor":
|
||||
# this is a TensorList which comes in from GetTensorList as a Value
|
||||
return BaseCType(tensorListValueT)
|
||||
elif typ.elem == BaseType(BaseTy.SymInt):
|
||||
# TODO: return a value type. The problem here is analogous to
|
||||
# the problem with tensorListValueT: if you have SymInt[] you
|
||||
# cannot conveniently save the list of Value directly, as nodes
|
||||
# expect to save values as a vector for ALL arguments. So you
|
||||
# need a separate IR node that represents all of the size nodes
|
||||
# assembled into a list. I'm not an LTC dev so I don't want to
|
||||
# figure it out right now. Y'all figure it out...
|
||||
return VectorCType(BaseCType(longT))
|
||||
|
||||
else:
|
||||
return VectorCType(process_ir_type(typ.elem, properties))
|
||||
else:
|
||||
raise AssertionError(f"unrecognized type {repr(typ)}")
|
||||
|
||||
|
||||
# TODO: Determining this based off of CType is bad; this should be computed
|
||||
# from Type directly; then the same logic as process_ir_type can be used
|
||||
#
|
||||
# Invariant: passed typ should be an *owning* CType (e.g., we will report
|
||||
# that ArrayRef<Value> is NOT a value type)
|
||||
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
|
||||
"""
|
||||
Given a type, determine if it is a Value-like type. This is equivalent to
|
||||
|
|
@ -156,9 +133,6 @@ def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) ->
|
|||
or (typ.type == scalarT and not treat_scalars_as_constants)
|
||||
or typ.type == SymIntT
|
||||
)
|
||||
elif typ == VectorCType(BaseCType(SymIntT)):
|
||||
# TODO: report True for this
|
||||
return False
|
||||
elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
|
||||
return isValueType(typ.elem, properties)
|
||||
return False
|
||||
|
|
@ -183,7 +157,6 @@ def isWrappedScalarType(typ: Type) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
# TODO: dedupe with Type.is_generator_like
|
||||
def isGeneratorType(typ: Type) -> bool:
|
||||
if isinstance(typ, BaseType):
|
||||
return typ.name == BaseTy.Generator
|
||||
|
|
@ -192,15 +165,12 @@ def isGeneratorType(typ: Type) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
# This class caches a few derived properties computed from an Argument
|
||||
# and LazyIrProperties
|
||||
class LazyArgument:
|
||||
name: str
|
||||
orig_type: Type
|
||||
lazy_type_: Optional[CType]
|
||||
is_wrapped_scalar: bool
|
||||
is_generator: bool
|
||||
# TODO: this is lies, it is false for symint list
|
||||
is_symint_or_list: bool
|
||||
|
||||
# true if this argument is or contains a lazy IR value
|
||||
|
|
@ -222,11 +192,7 @@ class LazyArgument:
|
|||
else:
|
||||
self.lazy_type_ = process_ir_type(arg.type, properties)
|
||||
self.is_wrapped_scalar = isWrappedScalarType(arg.type)
|
||||
self.is_symint_or_list = (
|
||||
isSymIntType(arg.type)
|
||||
# TODO: lists of symints are not currently treated as value types
|
||||
# or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
|
||||
)
|
||||
self.is_symint_or_list = isSymIntType(arg.type)
|
||||
|
||||
self.is_lazy_value = not self.is_generator and isValueType(
|
||||
self.lazy_type, properties
|
||||
|
|
@ -302,8 +268,6 @@ class LazyIrProperties:
|
|||
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
|
||||
# but carries type information from a native FunctionSchema modified for use with IR nodes,
|
||||
# and preserving original argument names.
|
||||
#
|
||||
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
|
||||
class LazyIrSchema:
|
||||
# The name of the operator this function schema describes.
|
||||
name: "OperatorName"
|
||||
|
|
|
|||
|
|
@ -37,9 +37,6 @@ from torchgen.utils import assert_never
|
|||
# native:: kernels. The intention is to make native API and dispatcher API
|
||||
# line up as closely as possible, since this results in the least overhead
|
||||
# (no translation is needed from dispatcher API to native API).
|
||||
#
|
||||
# NB: this is symint aware, you will get the non-SymInt variant for some
|
||||
# dispatch entries and SymInt for others.
|
||||
|
||||
|
||||
def name(func: FunctionSchema) -> str:
|
||||
|
|
@ -52,9 +49,7 @@ def name(func: FunctionSchema) -> str:
|
|||
return name
|
||||
|
||||
|
||||
def argumenttype_type(
|
||||
t: Type, *, mutable: bool, binds: ArgName, symint: bool
|
||||
) -> NamedCType:
|
||||
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
|
||||
if str(t) == "Tensor?":
|
||||
tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
|
||||
if mutable and not local.use_const_ref_for_mutable_tensors():
|
||||
|
|
@ -69,22 +64,19 @@ def argumenttype_type(
|
|||
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
|
||||
elif str(t) == "Scalar?":
|
||||
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
|
||||
return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
|
||||
return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
|
||||
|
||||
|
||||
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
|
||||
return cpp.returns_type(rs, symint=symint)
|
||||
def returns_type(rs: Sequence[Return]) -> CType:
|
||||
return cpp.returns_type(rs)
|
||||
|
||||
|
||||
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
|
||||
return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)
|
||||
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
|
||||
return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
|
||||
|
||||
|
||||
def argument(
|
||||
a: Union[Argument, SelfArgument, TensorOptionsArguments],
|
||||
*,
|
||||
is_out: bool,
|
||||
symint: bool,
|
||||
a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool
|
||||
) -> List[Binding]:
|
||||
# Ideally, we NEVER default native functions. However, there are a number
|
||||
# of functions that call native:: directly and rely on the defaulting
|
||||
|
|
@ -98,7 +90,7 @@ def argument(
|
|||
default = cpp.default_expr(a.default, a.type)
|
||||
return [
|
||||
Binding(
|
||||
nctype=argument_type(a, binds=a.name, symint=symint),
|
||||
nctype=argument_type(a, binds=a.name),
|
||||
name=a.name,
|
||||
default=default,
|
||||
argument=a,
|
||||
|
|
@ -106,7 +98,7 @@ def argument(
|
|||
]
|
||||
elif isinstance(a, SelfArgument):
|
||||
# Erase SelfArgument from the distinction
|
||||
return argument(a.argument, is_out=is_out, symint=symint)
|
||||
return argument(a.argument, is_out=is_out)
|
||||
elif isinstance(a, TensorOptionsArguments):
|
||||
default = None
|
||||
if should_default:
|
||||
|
|
@ -144,10 +136,8 @@ def argument(
|
|||
assert_never(a)
|
||||
|
||||
|
||||
def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
|
||||
def arguments(func: FunctionSchema) -> List[Binding]:
|
||||
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
|
||||
args.extend(func.arguments.non_out)
|
||||
args.extend(func.arguments.out)
|
||||
return [
|
||||
r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn())
|
||||
]
|
||||
return [r for arg in args for r in argument(arg, is_out=func.is_out_fn())]
|
||||
|
|
|
|||
|
|
@ -216,12 +216,8 @@ class PythonArgument:
|
|||
|
||||
# Compute argument formal for python argument parsing.
|
||||
# Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
|
||||
def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
|
||||
type_str = (
|
||||
argument_type_str(self.type, symint=symint)
|
||||
.replace("const ", "")
|
||||
.replace(" &", "")
|
||||
)
|
||||
def argument_str(self, *, method: bool = False) -> str:
|
||||
type_str = argument_type_str(self.type).replace("const ", "").replace(" &", "")
|
||||
|
||||
name = self.name
|
||||
# s/self/input/ outside method bindings
|
||||
|
|
@ -388,10 +384,10 @@ class PythonSignature:
|
|||
#
|
||||
# For a translation to mypy-valid type signatures, see
|
||||
# signature_str_pyi().
|
||||
def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
|
||||
def signature_str(self, *, skip_outputs: bool = False) -> str:
|
||||
args = self.arguments(skip_outputs=skip_outputs)
|
||||
schema_formals: List[str] = list(
|
||||
map(lambda a: a.argument_str(method=self.method, symint=symint), args)
|
||||
map(lambda a: a.argument_str(method=self.method), args)
|
||||
)
|
||||
positional_argc = len(self.input_args)
|
||||
if len(schema_formals) > positional_argc:
|
||||
|
|
@ -430,7 +426,7 @@ class PythonSignature:
|
|||
vararg_type = args[0].type
|
||||
if (
|
||||
isinstance(vararg_type, ListType)
|
||||
and str(vararg_type.elem) in ["int", "SymInt"]
|
||||
and str(vararg_type.elem) == "int"
|
||||
and num_positionalargs == 1
|
||||
):
|
||||
have_vararg_version = True
|
||||
|
|
@ -468,11 +464,9 @@ class PythonSignatureDeprecated(PythonSignature):
|
|||
def deprecated(self) -> bool:
|
||||
return True
|
||||
|
||||
def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
|
||||
def signature_str(self, *, skip_outputs: bool = False) -> str:
|
||||
return (
|
||||
PythonSignature.signature_str(
|
||||
self, skip_outputs=skip_outputs, symint=symint
|
||||
)
|
||||
PythonSignature.signature_str(self, skip_outputs=skip_outputs)
|
||||
+ "|deprecated"
|
||||
)
|
||||
|
||||
|
|
@@ -639,9 +633,7 @@ def has_tensor_options(f: NativeFunction) -> bool:
# 'simple_type' was introduced by the old codegen, which is slightly
# different from the python schema type, e.g.: doesn't have '?' suffix
# for optional Tensor/TensorList; doesn't have '[size]' suffix for list type.
def argument_type_str(
t: Type, *, simple_type: bool = False, symint: bool = True
) -> str:
def argument_type_str(t: Type, *, simple_type: bool = False) -> str:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
return "Tensor"
@@ -673,7 +665,7 @@ def argument_type_str(
if str(t.elem) == "Tensor":
# Is it desired to keep '?' for simple_type with new style dispatcher?
return "Tensor?"
elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
elem = argument_type_str(t.elem, simple_type=simple_type)
return f"{elem}?"
elif isinstance(t, ListType):
size = t.size if not simple_type else None
@@ -683,12 +675,7 @@ def argument_type_str(
elif str(t.elem) == "int":
return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
elif str(t.elem) == "SymInt":
if symint:
return (
f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
)
else:
return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
return f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
elif str(t.elem) == "Tensor":
return f"TensorList[{size}]" if size is not None else "TensorList"
elif str(t.elem) == "Scalar":
@@ -700,7 +687,7 @@ def argument_type_str(
return "const c10::List<c10::optional<Tensor>> &"
elif str(t.elem) == "Dimname":
return f"DimnameList[{size}]" if size is not None else "DimnameList"
elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
elem = argument_type_str(t.elem, simple_type=simple_type)
return f"ArrayRef<{elem}>"

raise RuntimeError(f"unrecognized type {repr(t)}")
@@ -911,7 +898,7 @@ def argument_type_str_pyi(t: Type) -> str:
if t.name == BaseTy.int:
ret = "_int"
if t.name == BaseTy.SymInt:
ret = "Union[_int, SymInt]"
ret = "SymInt"
elif t.name == BaseTy.float:
ret = "_float"
elif t.name == BaseTy.str:
@@ -1053,7 +1040,7 @@ def returns_str_pyi(signature: PythonSignature) -> str:


def dispatch_lambda_args(
ps: PythonSignature, f: NativeFunction, symint: bool = True
ps: PythonSignature, f: NativeFunction
) -> Tuple[DispatchLambdaArgument, ...]:
if isinstance(ps, PythonSignatureDeprecated):
schema = ps.deprecated_schema
@@ -1064,7 +1051,6 @@ def dispatch_lambda_args(
cpp_args = cpp.arguments(
arguments=schema.arguments,
faithful=False,
symint=symint,
method=False,
cpp_no_default_args=f.cpp_no_default_args,
)
@@ -1147,15 +1133,14 @@ def dispatch_lambda_return_str(f: NativeFunction) -> str:
returns_without_annotation = tuple(
map(lambda r: Return(r.name, r.type, None), f.func.returns)
)
return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
return_str = cpp.returns_type(returns_without_annotation).cpp_type()
if return_str not in SUPPORTED_RETURN_TYPES:
raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
return return_str


def cpp_dispatch_target(f: NativeFunction) -> str:
symint = f.func.has_symint()
name = cpp.name(f.func, symint_overload=symint)
name = cpp.name(f.func)
if Variant.method in f.variants:
return f"self.{name}"
if Variant.function in f.variants:
@@ -1207,7 +1192,7 @@ def cpp_dispatch_exprs(
# For certain cases it is intentionally more restrictive than necessary,
# e.g.: it doesn't accepts doublelist with definite size.
def arg_parser_unpack_method(
t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True
t: Type, default: Optional[str], default_init: Optional[str]
) -> str:
has_default_init = default_init is not None
if has_default_init and str(t) not in (
@@ -1239,10 +1224,7 @@ def arg_parser_unpack_method(
elif t.name == BaseTy.int:
return "toInt64"
elif t.name == BaseTy.SymInt:
if symint:
return "toSymInt"
else:
return "toInt64"
return "toSymInt"
elif t.name == BaseTy.bool:
return "toBoolWithDefault" if has_default_init else "toBool"
elif t.name == BaseTy.float:
@@ -1263,14 +1245,10 @@ def arg_parser_unpack_method(
return "toDimnameListOptional"
elif not has_default_init and default in (None, "None", "c10::nullopt"):
# If default is None: append 'Optional' to elem's unpacking method
return (
arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
)
return arg_parser_unpack_method(t.elem, None, None) + "Optional"
else:
# Otherwise, load as underlying type with default
return arg_parser_unpack_method(
t.elem, default, default_init, symint=symint
)
return arg_parser_unpack_method(t.elem, default, default_init)

elif isinstance(t, ListType):
if str(t.elem) == "Tensor":
@@ -1291,10 +1269,7 @@ def arg_parser_unpack_method(
return "doublelist"
elif str(t.elem) == "SymInt":
# accept definite size
if symint:
return "symintlist"
else:
return "intlist"
return "symintlist"
elif str(t) == "Scalar[]":
return "scalarlist"
raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
@@ -1303,11 +1278,11 @@ def arg_parser_unpack_method(
# Return RHS expression for python argument using PythonArgParser output.
# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
def arg_parser_output_expr(
arg_index: int, a: PythonArgument, *, symint: bool = True
arg_index: int, a: PythonArgument
) -> PythonArgParserOutputExpr:
has_default = a.default_init is not None
unpack_method = arg_parser_unpack_method(
t=a.type, default=a.default, default_init=a.default_init, symint=symint
t=a.type, default=a.default, default_init=a.default_init
)
default = f", {a.default_init}" if has_default else ""
expr = f"_r.{unpack_method}({arg_index}{default})"
@@ -1322,12 +1297,12 @@ def arg_parser_output_expr(

# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
ps: PythonSignature, f: NativeFunction, *, symint: bool = True
ps: PythonSignature, f: NativeFunction
) -> Dict[str, PythonArgParserOutputExpr]:
return {
e.name: e
for i, a in enumerate(ps.arguments())
for e in (arg_parser_output_expr(i, a, symint=symint),)
for e in (arg_parser_output_expr(i, a),)
}


@@ -1342,13 +1317,13 @@ TENSOR_OPTIONS_FIELDS = {

# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
def dispatch_lambda_exprs(
ps: PythonSignature, f: NativeFunction, *, symint: bool = True
ps: PythonSignature, f: NativeFunction
) -> DispatchLambdaArgumentExprs:
# This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing
# 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
# outputs.
arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
lambda_args = dispatch_lambda_args(ps, f, symint=symint)
arg_parser_outputs = arg_parser_output_exprs(ps, f)
lambda_args = dispatch_lambda_args(ps, f)
inits: List[str] = []
lambda_args_exprs: Dict[str, str] = {}
@@ -42,12 +42,7 @@ from torchgen.utils import assert_never
# some more nominal types
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
# If it's a value type, do the value type translation
# NB: structured kernels ALWAYS have symint off, since they involve actual
# kernels that require real ints. The one exception is the
# CompositeExplicitAutograd and the meta function (which could
# hypothetically be SymInt), but for simplicity we plan for these to just
# be handled in Python
r = cpp.valuetype_type(t, symint=False, binds=binds)
r = cpp.valuetype_type(t, binds=binds)
if r is not None:
return r
@@ -337,16 +337,10 @@ Check this module for more information.
)
return f"c10::asIntArrayRefSlow({symIntArrayRef_type})"
elif goal.type == BaseCType(symIntArrayRefT):
try:
r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
return f"c10::SymIntArrayRef::fromIntArrayRef({r})"
except UnsatError:
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
elif goal.type == BaseCType(SymIntT):
return direct_solve(NamedCType(goal.name, BaseCType(longT)))
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
elif goal.type == BaseCType(longT):
symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
return f"{symInt_type}.expect_int()"
return f"{symInt_type}.expectInt()"
elif goal.type == BaseCType(optionalIntArrayRefT):
return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
elif goal.type == BaseCType(optionalScalarRefT):
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypeVar, Union
from typing import Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union

from torchgen.model import (
Argument,
@@ -417,11 +417,6 @@ class CppSignature:
# (i.e. with a potential TensorOptions argument and out arguments in the front)
faithful: bool

# Is this a symint C++ signature. For BC reasons, functions that take
# SymInts still present as int64_t in C++, and the SymInt variant is
# offered at a different overload name
symint: bool

# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: Set[str]
@@ -438,17 +433,12 @@ class CppSignature:
return cpp.arguments(
self.func.arguments,
faithful=self.faithful,
symint=self.symint,
method=self.method,
cpp_no_default_args=self.cpp_no_default_args,
)

def name(self) -> str:
n = cpp.name(
self.func,
faithful_name_for_out_overloads=self.faithful,
symint_overload=self.symint,
)
n = cpp.name(self.func, faithful_name_for_out_overloads=self.faithful)
if self.fallback_binding:
n = f"__dispatch_{n}"
return n
@@ -461,9 +451,7 @@ class CppSignature:
prefix: str = "",
is_redispatching_fn: bool = False,
) -> str:
returns_type = cpp.returns_type(
self.func.returns, symint=self.symint
).cpp_type()
returns_type = cpp.returns_type(self.func.returns).cpp_type()
cpp_args = [a.decl() for a in self.arguments()]
if is_redispatching_fn:
cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
@@ -481,9 +469,7 @@ class CppSignature:
prefix: str = "",
is_redispatching_fn: bool = False,
) -> str:
returns_type = cpp.returns_type(
self.func.returns, symint=self.symint
).cpp_type()
returns_type = cpp.returns_type(self.func.returns).cpp_type()
cpp_args = [a.defn() for a in self.arguments()]
if is_redispatching_fn:
cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
@@ -494,12 +480,12 @@ class CppSignature:

def ptr_type(self) -> str:
args_types_str = ", ".join(a.type for a in self.arguments())
return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_types_str})"
return f"{cpp.returns_type(self.func.returns).cpp_type()} (*)({args_types_str})"

# Return the C++ function type, e.g., something like int(bool)
def type(self) -> str:
args_types_str = ", ".join(a.type for a in self.arguments())
return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} ({args_types_str})"
return f"{cpp.returns_type(self.func.returns).cpp_type()} ({args_types_str})"


# Represents group of all CppSignatures associated with a
@@ -511,8 +497,6 @@ class CppSignatureGroup:
func: FunctionSchema
signature: CppSignature
faithful_signature: Optional[CppSignature]
symint_signature: Optional[CppSignature]
symint_faithful_signature: Optional[CppSignature]

def most_faithful_signature(self) -> CppSignature:
if self.faithful_signature:
@@ -524,10 +508,6 @@ class CppSignatureGroup:
yield self.signature
if self.faithful_signature:
yield self.faithful_signature
if self.symint_signature:
yield self.symint_signature
if self.symint_faithful_signature:
yield self.symint_faithful_signature

@staticmethod
def from_native_function(
@@ -535,35 +515,23 @@ class CppSignatureGroup:
) -> "CppSignatureGroup":
func = f.func

def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
def make_sig(*, faithful: bool) -> CppSignature:
return CppSignature(
func=func,
faithful=faithful,
symint=symint,
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args,
)

def make_sigs(*, symint: bool) -> Tuple[CppSignature, Optional[CppSignature]]:
faithful_signature: Optional[CppSignature] = None
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = make_sig(faithful=True, symint=symint)
signature = make_sig(faithful=False, symint=symint)
return signature, faithful_signature

signature, faithful_signature = make_sigs(symint=False)
symint_signature: Optional[CppSignature] = None
symint_faithful_signature: Optional[CppSignature] = None
if func.has_symint():
symint_signature, symint_faithful_signature = make_sigs(symint=True)

faithful_signature: Optional[CppSignature] = None
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = make_sig(faithful=True)
signature = make_sig(faithful=False)
return CppSignatureGroup(
func=func,
signature=signature,
faithful_signature=faithful_signature,
symint_signature=symint_signature,
symint_faithful_signature=symint_faithful_signature,
)
@@ -625,8 +593,6 @@ class NativeSignature:
# The schema this signature is derived from
func: FunctionSchema

symint: bool

prefix: str = ""

def name(self) -> str:
@@ -636,24 +602,24 @@ class NativeSignature:
args_str = ", ".join(a.decl() for a in self.arguments())
if name is None:
name = self.name()
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
return f"{native.returns_type(self.func.returns).cpp_type()} {name}({args_str})"

def defn(self, name: Optional[str] = None) -> str:
args_str = ", ".join(a.defn() for a in self.arguments())
if name is None:
name = self.name()
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
return f"{native.returns_type(self.func.returns).cpp_type()} {name}({args_str})"

def ptr_type(self) -> str:
# don't include defaults in type signature!
args_str = ", ".join(a.defn() for a in self.arguments())
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})"
return f"{native.returns_type(self.func.returns).cpp_type()} (*)({args_str})"

def arguments(self) -> List[Binding]:
return native.arguments(self.func, symint=self.symint)
return native.arguments(self.func)

def returns_type(self) -> CType:
return native.returns_type(self.func.returns, symint=self.symint)
return native.returns_type(self.func.returns)

def dispatcher_exprs(self) -> List[Expr]:
return translate.translate(
@@ -779,14 +745,9 @@ def kernel_signature(
# With external backends, we'd like to enforce that they write their kernels with schemas
# that match the Dispatcher API directly, if they can.
if backend_index.external:
# Dispatcher signature faithfully does SymInt, which is good for XLA,
# not so good for more conventional backends but we don't have any of
# those. If we do, that's time to add a new Signature that is a cross
# between DispatcherSignature and NativeSignature
assert backend_index.symint
return DispatcherSignature.from_schema(f.func, prefix=prefix)
else:
return NativeSignature(f.func, prefix=prefix, symint=backend_index.symint)
return NativeSignature(f.func, prefix)

# Functions only, no types
@@ -40,8 +40,7 @@ def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
# Dispatch stubs are always plain ints
r = cpp.valuetype_type(t, binds=binds, symint=False)
r = cpp.valuetype_type(t, binds=binds)
if r is not None:
return r
@@ -65,7 +64,7 @@ def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
#
# NB: CUDA only
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
r = cpp.valuetype_type(t, binds=binds, symint=False)
r = cpp.valuetype_type(t, binds=binds)
if r is not None:
return r
@@ -94,7 +93,7 @@ def ufunctor_apply_type(
# is done in the computation type. compute_t is opmath_t in CUDA and scalar_t
# in CPU
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
r = cpp.valuetype_type(t, binds=binds, symint=False)
r = cpp.valuetype_type(t, binds=binds)
if r is not None:
return r
@@ -136,10 +136,7 @@ def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]:
def argumenttype_ivalue_convert(
t: Type, arg_name: str, *, mutable: bool = False
) -> Tuple[str, CType, List[str], List[str]]:
# Unboxing is for mobile, which doesn't care about SymInts
ctype = cpp.argumenttype_type(
t=t, mutable=mutable, binds=arg_name, symint=False
).type
ctype = cpp.argumenttype_type(t=t, mutable=mutable, binds=arg_name).type

if isinstance(t, BaseType):
out_name = f"{arg_name}_base"
@@ -27,10 +27,7 @@ from torchgen.dest.lazy_ts_lowering import ts_lowering_body
from torchgen.model import (
Argument,
BackendIndex,
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsGroup,
)
@@ -43,7 +40,6 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
a lazy Node constructor.
"""

# TODO: Matching on CType seems wrong; should be matching on Type
if isValueType(arg.lazy_type):
if isinstance(arg.lazy_type, BaseCType):
if arg.is_wrapped_scalar:
@@ -52,7 +48,7 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
return f"lazy_{arg.name}_tensorlist"
elif arg.is_symint_or_list:
cpp_type = arg.lazy_type.cpp_type()
return f"GetSymIntValue({arg.name})"
return f"{cpp_type}(dynamic_cast<torch::lazy::SymIntNodeImpl*>({arg.name}.toSymIntNodeImpl().get())->node_, 0)"
return f"lazy_{arg.name}->GetIrValue()"
elif isinstance(arg.lazy_type, OptionalCType):
if arg.is_wrapped_scalar:
@@ -67,15 +63,7 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
)
else:
# NB: this is here because right now we aren't treating SymInt[] as a
# value type; when we do this needs to move above
# NB: we cannot test arg.lazy_type as we've already specified it is an
# int64_t and so we cannot distinguish between SymInt and int64_t
if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
BaseTy.SymInt
):
return f"GetSymIntArrayRefValue({arg.name})"
elif isinstance(arg.lazy_type, VectorCType) and isinstance(
if isinstance(arg.lazy_type, VectorCType) and isinstance(
arg.lazy_type.elem, BaseCType
):
return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
@@ -512,13 +500,9 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
dispatch_ns = "compositeexplicitautogradnonfunctional"
else:
dispatch_ns = "meta"
aten_name = schema.aten_name
# TODO: this is trolling
if func.func.has_symint():
aten_name += "_symint"
shape_str = f"""\
{meta_conversion_str}
auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
auto out_meta = at::{dispatch_ns}::{schema.aten_name}({', '.join(meta_call_args)});
{meta_out}"""
else:
shape_sig = ComputeShapeSignature(metadata.kernel, func)
@@ -287,8 +287,8 @@ class RegisterDispatchKey:
self, f: NativeFunction
) -> Union[NativeSignature, DispatcherSignature]:
# The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
return DispatcherSignature.from_schema(
f.func, prefix=f"wrapper_{f.func.name.overload_name}_"
return kernel_signature(
f, self.backend_index, prefix=f"wrapper_{f.func.name.overload_name}_"
)

def gen_out_inplace_wrapper(
@@ -407,11 +407,10 @@ class RegisterDispatchKey:
f, method=False, fallback_binding=False
)

# TODO: dedupe this with the structured codegen
if self.target is Target.NAMESPACED_DECLARATION:
result = ""
for cpp_sig in cpp_sig_group.signatures():
result += f"TORCH_API {cpp_sig.decl()};\n"
result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
if cpp_sig_group.faithful_signature is not None:
result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
return result
elif self.target is Target.NAMESPACED_DEFINITION:
@@ -422,11 +421,10 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
}}
"""

result = ""
for cpp_sig in cpp_sig_group.signatures():
result += generate_defn(cpp_sig)
result = generate_defn(cpp_sig_group.signature)
if cpp_sig_group.faithful_signature is not None:
result += generate_defn(cpp_sig_group.faithful_signature)
return result

elif self.target is Target.ANONYMOUS_DEFINITION:
# short circuit for inplace_meta
if inplace_meta:
@@ -453,14 +451,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
else:
impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

kernel_sig = kernel_signature(f, self.backend_index)

args_exprs_str = ", ".join(
e.expr
for e in translate(
sig.arguments(), kernel_sig.arguments(), method=False
)
)
args_exprs_str = ", ".join(a.name for a in args)

device_check = " // No device check\n"
# Backends that require device guards presumably also require device checks.
@@ -750,14 +741,12 @@ resize_out(out, sizes, strides, options);
)

# Signature of the wrapper function we'll register to the dispatcher
sig = NativeSignature(
f.func, prefix="wrapper_", symint=self.backend_index.symint
)
sig = NativeSignature(f.func, prefix="wrapper_")

if self.target is Target.NAMESPACED_DECLARATION:
result = ""
for cpp_sig in cpp_sig_group.signatures():
result += f"TORCH_API {cpp_sig.decl()};\n"
result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
if cpp_sig_group.faithful_signature is not None:
result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
return result

elif self.target is Target.NAMESPACED_DEFINITION:
@@ -769,9 +758,9 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
}}
"""

result = ""
for cpp_sig in cpp_sig_group.signatures():
result += generate_defn(cpp_sig)
result = generate_defn(cpp_sig_group.signature)
if cpp_sig_group.faithful_signature is not None:
result += generate_defn(cpp_sig_group.faithful_signature)
return result

elif self.target is Target.ANONYMOUS_DEFINITION:
@@ -37,6 +37,7 @@ from torchgen.gen_functionalization_type import (
gen_functionalization_definition,
gen_functionalization_registration,
gen_functionalization_view_inverse_declaration,
gen_symint_view_copy_kernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
@@ -159,9 +160,6 @@ def parse_native_yaml_struct(
use_out_as_primary=True,
external=False,
device_guard=False,
# I'm actually not sure about this; undefined could be hit on
# empty TensorList, hypothetically that could have sizes in it
symint=False,
index={},
)
)
@@ -176,16 +174,6 @@ def parse_native_yaml_struct(
# Only cuda-like devices in tree require device guards
device_guard=is_cuda_dispatch_key(k),
index=v,
# Which dispatch keys natively support symint
# Note: DispatchKey.CompositeExplicitAutograd has to match out
# composites; I think there's some factoring problem here
symint=k
in [
DispatchKey.Meta,
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeExplicitAutograd,
DispatchKey.CompositeExplicitAutogradNonFunctional,
],
)
return ParsedYaml(rs, indices)
@@ -874,8 +862,7 @@ class ComputeBackendSelect:
return None

name = native.name(f.func)
# BackendSelect can go to Meta, so it must preserve symints
native_sig = NativeSignature(f.func, symint=True)
native_sig = NativeSignature(f.func)

native_tensor_args = [
a
@@ -979,10 +966,7 @@ def dynamic_type(t: Type) -> str:
# also include Tensor[]
if str(t) == "Tensor":
return "at::Tensor"
# This is a legacy concept, so never report SymInt
return cpp.argumenttype_type(
t, mutable=False, binds="__placeholder__", symint=False
).cpp_type()
return cpp.argumenttype_type(t, mutable=False, binds="__placeholder__").cpp_type()


def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
@@ -1047,8 +1031,7 @@ def compute_returns_yaml(
ret = {
"dynamic_type": dynamic_type(r.type),
"name": name,
# legacy, report ints
"type": cpp.return_type(r, symint=False).cpp_type(),
"type": cpp.return_type(r).cpp_type(),
}

if r.name:
@@ -1108,8 +1091,7 @@ def compute_argument_yaml(
"dynamic_type": dynamic_type(a.type),
"is_nullable": a.type.is_nullable(),
"name": a.name,
# legacy, report ints
"type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
"type": cpp.argument_type(a, binds="__placeholder__").cpp_type(),
}
if a.default is not None:
arg["default"] = pythonify_default(cpp.default_expr(a.default, a.type))
@@ -1175,13 +1157,11 @@ def compute_declaration_yaml(f: NativeFunction) -> object:
method=False,
cpp_no_default_args=set(),
faithful=False,
symint=False,
has_tensor_options=False,
)
]

# legacy, report ints
cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
cpp_returns = cpp.returns_type(f.func.returns).cpp_type()
schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"

is_factory_method = (
@@ -2410,6 +2390,29 @@ def gen_source_files(
)
},
)
view_copy_with_symint_pairs: List[Tuple[NativeFunction, NativeFunction]] = []
for g1 in view_groups:
for g2 in view_groups:
if g1.view_copy is None or g2.view_copy is None:
continue
# TODO: make this more first class in the data model
g1_base_name = str(g1.view_copy.func.name.name)
g2_base_name = str(g2.view_copy.func.name.name)

same_base_op = (
g1_base_name == g2_base_name
and g1.view_copy.func.arguments.symints_to_ints()
== g2.view_copy.func.arguments.symints_to_ints()
)
op1_not_symint = "SymInt" not in str(g1.view_copy.func.name.overload_name)
op2_symint = "SymInt" in str(g2.view_copy.func.name.overload_name)
if same_base_op and op1_not_symint and op2_symint:
view_copy_with_symint_pairs.append(
(
g1.view_copy,
g2.view_copy,
)
)

# Note [view_copy NativeFunctions]
# Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
@@ -2450,6 +2453,12 @@ def gen_source_files(
"CompositeViewCopyKernel_Definitions": list(
mapMaybe(gen_composite_view_copy_kernel, view_groups)
),
"SymIntViewCopyKernel_Definitions": list(
mapMaybe(
lambda pair: gen_symint_view_copy_kernel(pair[0], pair[1]),
view_copy_with_symint_pairs,
)
),
"GeneratedCompositeFunctional_Definitions": list(
mapMaybe(
gen_composite_functional_kernel,
@@ -140,7 +140,6 @@ Only the following keys are supported: {", ".join(valid_keys)}'
dispatch_key=dispatch_key,
use_out_as_primary=use_out_as_primary,
external=True,
symint=True, # TODO: make this configurable
device_guard=use_device_guard,
index=metadata,
)
@@ -78,18 +78,21 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
if g.view_copy is None:
return None

# For view_copy.SymInt overloads,
# See gen_symint_view_copy_kernel.
if g.view_copy.func.name.overload_name == "SymInt":
return None

# We can make view_copy work in more cases by using reshape()
# when a normal view call would ordinarily fail.
# This also makes LTC more efficient, because they don't need to include
# clone() calls in their graph (which is normally needed by reshape).
if str(g.view_copy.func.name) == "view_copy":
return """\
at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) {
// TODO: don't cast to int array ref
auto int_size = c10::asIntArrayRefSlow(size);
DimVector shape = infer_size_dv(int_size, self.numel());
at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
DimVector shape = infer_size_dv(size, self.numel());
if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
return self.reshape(int_size);
return self.reshape(size);
} else {
auto output = at::_ops::view::call(self, size);
return output.clone();
@@ -97,8 +100,7 @@ at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) {
}
"""
# view_copy is a native signature, since we're generating an at::native:: kernel
# Functionalization always operates on symints though
view_copy_sig = NativeSignature(g.view_copy.func, symint=True)
view_copy_sig = NativeSignature(g.view_copy.func)

# view is a dispatcher signature, since we're calling into the at::_ops API
view_sig = DispatcherSignature(g.view.func)
@@ -136,6 +138,34 @@ at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) {
"""


# For symint view copy kernels, we want to generate them to call into
# their concrete view_copy counterparts.
@with_native_function_and
def gen_symint_view_copy_kernel(
view_copy: NativeFunction, view_copy_symint: NativeFunction
) -> str:
# view_copy.symint is a native signature, since we're generating an at::native:: kernel
view_copy_symint_sig = NativeSignature(view_copy_symint.func)

# view_copy is a dispatcher signature, since we're calling into the at::_ops API
view_copy_sig = DispatcherSignature(view_copy.func)

exprs = ", ".join(
[
e.expr
for e in translate(
view_copy_symint_sig.arguments(), view_copy_sig.arguments()
)
]
)

return f"""
{view_copy_symint_sig.defn()} {{
return at::_ops::{view_copy.func.name.unambiguous_name()}::call({exprs});
}}
"""


def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
assert len(rets) == len(names)
if len(rets) == 0:
@@ -816,6 +816,9 @@ class NativeFunction:
backend_metadata,
)

def symints_to_ints(self) -> "NativeFunction":
return dataclasses.replace(self, func=self.func.symints_to_ints())

def validate_unstructured(self) -> None:
# TODO: probably better to accumulate these errors and report them all
# at once
@@ -878,6 +881,8 @@ class NativeFunction:
"foreach kernels fall back to slow path when tensor are on different devices, "
"device_check not allowed to be enabled"
)
named_symint = "SymInt" in self.func.name.overload_name
assert named_symint == self.func.has_symint()

# NB: if your function accidentally has rand/dropout/... in its name
# but is not actually random, feel free to amend this to special case
@@ -1112,8 +1117,6 @@ class BackendIndex:
external: bool
# Other backend-specific information that is on a per-operator basis
index: Dict["OperatorName", BackendMetadata]
# Whether or not this backend handles symbolic ints or not
symint: bool

@staticmethod
def grow_index(
@@ -1232,6 +1235,9 @@ class FunctionSchema:

decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")

def symints_to_ints(self) -> "FunctionSchema":
return dataclasses.replace(self, arguments=self.arguments.symints_to_ints())

@staticmethod
def parse(func: str) -> "FunctionSchema":
# We should probably get a proper parser here
@@ -1354,6 +1360,10 @@ class FunctionSchema:
def is_functional_fn(self) -> bool:
return "functional" in self.name.overload_name

def is_symint_fn(self) -> bool:
# TODO: make this more robust
return "SymInt" in self.name.overload_name

def is_out_fn(self) -> bool:
# Note [is_out_fn]
#
@@ -1694,6 +1704,9 @@ class Type:
def is_list_like(self) -> Optional["ListType"]:
raise NotImplementedError

def symint_to_int(self) -> "Type":
raise NotImplementedError


# Base types are simple, atomic types with no further structure
BaseTy = Enum(
@@ -1734,12 +1747,14 @@ class BaseType(Type):
def is_nullable(self) -> bool:
return False

def symint_to_int(self) -> "BaseType":
if self.name == BaseTy.SymInt:
return BaseType(BaseTy.int)
return self

def is_list_like(self) -> Optional["ListType"]:
return None

def is_symint_like(self) -> bool:
return self.name == BaseTy.SymInt


# Optional types may be specified, or may also be validly given None
@dataclass(frozen=True)
@@ -1752,12 +1767,12 @@ class OptionalType(Type):
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
return self.elem.is_base_ty_like(base_ty)

def is_symint_like(self) -> bool:
return self.elem.is_symint_like()

def is_nullable(self) -> bool:
return True

def symint_to_int(self) -> "Type":
return dataclasses.replace(self, elem=self.elem.symint_to_int())

def is_list_like(self) -> Optional["ListType"]:
return self.elem.is_list_like()
@@ -1776,15 +1791,15 @@ class CustomClassType(Type):
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
return False

def is_symint_like(self) -> bool:
return False

def is_nullable(self) -> bool:
"""
Assume a custom class is not nullable.
"""
return False

def symint_to_int(self) -> "Type":
return self

def is_list_like(self) -> Optional["ListType"]:
return None
@@ -1808,12 +1823,12 @@ class ListType(Type):
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
return self.elem.is_base_ty_like(base_ty)

def is_symint_like(self) -> bool:
return self.elem.is_symint_like()

def is_nullable(self) -> bool:
return self.elem.is_nullable()

def symint_to_int(self) -> "ListType":
return ListType(self.elem.symint_to_int(), self.size)

def is_list_like(self) -> Optional["ListType"]:
return self
@@ -1887,6 +1902,9 @@ class Argument:
def is_write(self) -> bool:
return self.annotation is not None and self.annotation.is_write

def symint_to_int(self) -> "Argument":
return dataclasses.replace(self, type=self.type.symint_to_int())

def __str__(self) -> str:
type = f"{self.type}"
if self.annotation:
@@ -2076,6 +2094,37 @@ class Arguments:
if a.annotation is not None and a.annotation.is_write
]

def symints_to_ints(self) -> "Arguments":
arguments = self

if arguments.self_arg:
arguments = dataclasses.replace(
arguments,
pre_self_positional=tuple(
x.symint_to_int() for x in arguments.pre_self_positional
),
)

if self.tensor_options:
arguments = dataclasses.replace(
arguments,
post_tensor_options_kwarg_only=tuple(
x.symint_to_int() for x in arguments.post_tensor_options_kwarg_only
),
)

arguments = dataclasses.replace(
arguments,
post_self_positional=tuple(
x.symint_to_int() for x in arguments.post_self_positional
),
pre_tensor_options_kwarg_only=tuple(
x.symint_to_int() for x in arguments.pre_tensor_options_kwarg_only
),
)

return arguments

def has_tensor_arg(self) -> bool:
return any(a.type.is_tensor_like() for a in self.flat_non_out)
@@ -98,9 +98,7 @@ def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bo
return False

if isinstance(g, NativeFunctionsViewGroup):
# TODO: stop doing type tests by converting to C++ and then testing
# the string, just test the dang thing directly
if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
if "at::Tensor" != cpp.returns_type(func.returns).cpp_type():
# Returns a non-Tensor value.
logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
return False
@@ -124,8 +122,7 @@ def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bo
or not str(func.name).endswith(".out")
):
return False
# TODO: stop type testing by converting to C++
if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
if "at::Tensor &" != cpp.returns_type(func.returns).cpp_type():
logger.info(f"NON_TENSOR RET TYPE: {str(func)}")
return False
if has_alias(func.arguments.non_out):