mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Enable dim=None for torch.sum (#75845)"
This reverts commite79a51f7db. Reverted https://github.com/pytorch/pytorch/pull/75845 on behalf of https://github.com/malfet due to Breaks MacOS builds, seee79a51f7db
This commit is contained in:
parent
f9656817df
commit
ee6ebfc06b
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
|
|
@ -1 +1 @@
|
|||
35f759fdd7eb585679df7c1e6db4569b1aba5475
|
||||
de45c7c503f403be2c85066013b6a860f04f1152
|
||||
|
|
|
|||
|
|
@ -56,21 +56,18 @@ static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
|
|||
return dim == 0 || dim == -1;
|
||||
}
|
||||
|
||||
Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, optional<ScalarType> dtype) {
|
||||
if (opt_dims.has_value()) {
|
||||
auto dims = opt_dims.value();
|
||||
// PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
|
||||
// and instead returns a new scalar tensor (this also happens for dim=-1)
|
||||
// If the following happens:
|
||||
// >>> x = torch.randn(B0) # the per-examples are all scalars
|
||||
// >>> vmap(partial(torch.sum, dim=0), x)
|
||||
// then we replicate the behavior of sum(scalar_tensor, dim=0).
|
||||
if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
|
||||
return self.clone();
|
||||
}
|
||||
Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, optional<ScalarType> dtype) {
|
||||
// PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
|
||||
// and instead returns a new scalar tensor (this also happens for dim=-1)
|
||||
// If the following happens:
|
||||
// >>> x = torch.randn(B0) # the per-examples are all scalars
|
||||
// >>> vmap(partial(torch.sum, dim=0), x)
|
||||
// then we replicate the behavior of sum(scalar_tensor, dim=0).
|
||||
if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
|
||||
return self.clone();
|
||||
}
|
||||
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
|
||||
auto dims_physical = self_physical.getPhysicalDims(opt_dims);
|
||||
auto dims_physical = self_physical.getPhysicalDims(dims);
|
||||
auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
|
||||
return self_physical.getPhysicalToLogicalMap().apply(result);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,20 +55,13 @@ int64_t VmapPhysicalView::numLogicalDims() const {
|
|||
return /*physical*/tensor_.dim() - numBatchDims();
|
||||
}
|
||||
|
||||
VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
|
||||
VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const {
|
||||
auto logical_ndim = numLogicalDims();
|
||||
// NB: fmap doesn't have a SmallVector variant, so we don't use it here.
|
||||
VmapDimVector result;
|
||||
result.reserve(logical_ndim);
|
||||
if (opt_logical_dims.has_value()) {
|
||||
auto logical_dims = opt_logical_dims.value();
|
||||
for (auto dim : logical_dims) {
|
||||
result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
|
||||
}
|
||||
} else {
|
||||
for (int64_t dim = 0; dim < logical_ndim; dim++) {
|
||||
result.push_back(dim + numBatchDims());
|
||||
}
|
||||
for (auto dim : logical_dims) {
|
||||
result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ struct TORCH_API VmapPhysicalView {
|
|||
// This is because the size of levels tell us that the first two dimensions
|
||||
// of `tensor_` are batch dimensions, so a logical dim of `n` is actually
|
||||
// a physical dim of `n + 2`.
|
||||
VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
|
||||
VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
|
||||
int64_t getPhysicalDim(int64_t logical_dim) const;
|
||||
|
||||
// Returns a VmapPhysicalToLogicalMap object. This can be used for
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ namespace at {
|
|||
constexpr size_t dim_bitset_size = 64;
|
||||
|
||||
static inline std::bitset<dim_bitset_size> dim_list_to_bitset(
|
||||
OptionalIntArrayRef opt_dims,
|
||||
IntArrayRef dims,
|
||||
int64_t ndims) {
|
||||
TORCH_CHECK(
|
||||
ndims <= (int64_t)dim_bitset_size,
|
||||
|
|
@ -22,21 +22,11 @@ static inline std::bitset<dim_bitset_size> dim_list_to_bitset(
|
|||
dim_bitset_size,
|
||||
" dims are supported");
|
||||
std::bitset<dim_bitset_size> seen;
|
||||
if (opt_dims.has_value()) {
|
||||
auto dims = opt_dims.value();
|
||||
for (const auto i : c10::irange(dims.size())) {
|
||||
size_t dim = maybe_wrap_dim(dims[i], ndims);
|
||||
TORCH_CHECK(
|
||||
!seen[dim],
|
||||
"dim ",
|
||||
dim,
|
||||
" appears multiple times in the list of dims");
|
||||
seen[dim] = true;
|
||||
}
|
||||
} else {
|
||||
for (int64_t dim = 0; dim < ndims; dim++) {
|
||||
seen[dim] = true;
|
||||
}
|
||||
for (const auto i : c10::irange(dims.size())) {
|
||||
size_t dim = maybe_wrap_dim(dims[i], ndims);
|
||||
TORCH_CHECK(
|
||||
!seen[dim], "dim ", dim, " appears multiple times in the list of dims");
|
||||
seen[dim] = true;
|
||||
}
|
||||
return seen;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -455,7 +455,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
|
|||
// KERNEL(ADD_NS(norm), "norm.ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, IntArrayRef, bool, ScalarType), fp32_set_opt_dtype)
|
||||
// KERNEL(ADD_NS(norm), "norm.names_ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, DimnameList, bool, ScalarType), fp32_set_opt_dtype)
|
||||
KERNEL(ADD_NS(sum), "sum", Tensor (const Tensor &, c10::optional<ScalarType>), fp32_set_opt_dtype)
|
||||
KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, OptionalIntArrayRef, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
|
||||
KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
|
||||
KERNEL(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
|
||||
// fp32_append_dtype
|
||||
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
|
||||
|
|
|
|||
|
|
@ -52,6 +52,8 @@ namespace meta {
|
|||
|
||||
static ScalarType infer_dtype_from_optional(
|
||||
const Tensor& self,
|
||||
IntArrayRef dim,
|
||||
bool keepdim,
|
||||
const optional<ScalarType>& opt_dtype,
|
||||
const Tensor& result) {
|
||||
// 'opt_dtype' has the priority for both cases.
|
||||
|
|
@ -185,9 +187,9 @@ TORCH_META_FUNC(cumprod)
|
|||
}
|
||||
|
||||
TORCH_META_FUNC2(sum, dim_IntList)
|
||||
(const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
|
||||
auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
|
||||
resize_reduction(*this, self, opt_dim, keepdim, out_dtype);
|
||||
(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
|
||||
auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output());
|
||||
resize_reduction(*this, self, dim, keepdim, out_dtype);
|
||||
}
|
||||
|
||||
TORCH_META_FUNC2(prod, dim_int)
|
||||
|
|
@ -195,7 +197,7 @@ TORCH_META_FUNC2(prod, dim_int)
|
|||
int64_t dim,
|
||||
bool keepdim,
|
||||
c10::optional<ScalarType> dtype) {
|
||||
auto out_dtype = infer_dtype_from_optional(self, dtype, maybe_get_output());
|
||||
auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, dtype, maybe_get_output());
|
||||
resize_reduction(*this, self, dim, keepdim, out_dtype);
|
||||
}
|
||||
|
||||
|
|
@ -219,7 +221,7 @@ TORCH_META_FUNC2(mean, dim)
|
|||
"Got: ", dtype);
|
||||
}
|
||||
|
||||
auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
|
||||
auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output());
|
||||
resize_reduction(*this, self, dim, keepdim, out_dtype);
|
||||
}
|
||||
|
||||
|
|
@ -1059,11 +1061,11 @@ inline ScalarType get_dtype_from_result(Tensor& result, optional<ScalarType> dty
|
|||
|
||||
TORCH_IMPL_FUNC(sum_out)
|
||||
(const Tensor& self,
|
||||
OptionalIntArrayRef opt_dim,
|
||||
IntArrayRef dim,
|
||||
bool keepdim,
|
||||
optional<ScalarType> opt_dtype,
|
||||
const Tensor& result) {
|
||||
auto iter = meta::make_reduction_from_out_ty(self, result, opt_dim, keepdim, result.scalar_type());
|
||||
auto iter = meta::make_reduction_from_out_ty(self, result, dim, keepdim, result.scalar_type());
|
||||
if (iter.numel() == 0) {
|
||||
result.zero_();
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -110,27 +110,12 @@ static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dty
|
|||
|
||||
using DimMask = TensorIterator::DimMask;
|
||||
|
||||
static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
|
||||
if (opt_dims.has_value()) {
|
||||
return DimVector(opt_dims.value());
|
||||
} else {
|
||||
std::vector<int64_t> all_dims(ndim);
|
||||
std::iota(all_dims.begin(), all_dims.end(), 0);
|
||||
return DimVector(all_dims);
|
||||
}
|
||||
}
|
||||
|
||||
static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim) {
|
||||
static DimMask make_dim_mask(IntArrayRef dims, int64_t ndim) {
|
||||
DimMask mask;
|
||||
if (opt_dims.has_value()) {
|
||||
auto dims = opt_dims.value();
|
||||
if (dims.empty()) {
|
||||
mask = DimMask().flip();
|
||||
} else {
|
||||
mask = at::dim_list_to_bitset(dims, ndim);
|
||||
}
|
||||
} else {
|
||||
if (dims.empty()) {
|
||||
mask = DimMask().flip();
|
||||
} else {
|
||||
mask = at::dim_list_to_bitset(dims, ndim);
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
|
@ -335,10 +320,10 @@ static C10_UNUSED DimVector get_reduction_shape(
|
|||
static void resize_reduction(
|
||||
impl::MetaBase& meta,
|
||||
const Tensor& self,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
IntArrayRef dims,
|
||||
bool keepdim,
|
||||
ScalarType out_dtype) {
|
||||
DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
|
||||
DimVector dims_(dims);
|
||||
maybe_wrap_dims(dims_, self.dim());
|
||||
auto shape = get_reduction_shape(self, dims_, keepdim);
|
||||
meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
|
||||
|
|
@ -366,11 +351,11 @@ static void resize_reduction_with_indices(
|
|||
static TensorIterator make_reduction(
|
||||
const Tensor& self,
|
||||
const Tensor& result,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
IntArrayRef dims,
|
||||
bool keepdim,
|
||||
ScalarType in_dtype) {
|
||||
int64_t ndim = self.dim();
|
||||
auto mask = at::native::make_dim_mask(opt_dims, ndim);
|
||||
auto mask = at::native::make_dim_mask(dims, ndim);
|
||||
auto viewed_result =
|
||||
at::native::review_reduce_result(result, ndim, mask, keepdim);
|
||||
if (self.scalar_type() == in_dtype) {
|
||||
|
|
@ -404,7 +389,7 @@ static TensorIterator make_reduction(
|
|||
static C10_UNUSED TensorIterator make_reduction_from_out_ty(
|
||||
const Tensor& self,
|
||||
const Tensor& result,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
IntArrayRef dims,
|
||||
bool keepdim,
|
||||
ScalarType out_dtype) {
|
||||
// special case for type promotion in mixed precision, improves computational
|
||||
|
|
@ -416,7 +401,7 @@ static C10_UNUSED TensorIterator make_reduction_from_out_ty(
|
|||
(self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
|
||||
out_dtype == kFloat);
|
||||
auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
|
||||
return make_reduction(self, result, opt_dims, keepdim, in_dtype);
|
||||
return make_reduction(self, result, dims, keepdim, in_dtype);
|
||||
}
|
||||
|
||||
} // namespace meta
|
||||
|
|
|
|||
|
|
@ -4534,7 +4534,7 @@
|
|||
CompositeExplicitAutograd: sum
|
||||
SparseCsrCPU, SparseCsrCUDA: sum_csr
|
||||
|
||||
- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
|
||||
- func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
|
||||
structured_delegate: sum.IntList_out
|
||||
device_check: NoCheck # TensorIterator
|
||||
variants: function, method
|
||||
|
|
@ -4543,7 +4543,7 @@
|
|||
device_check: NoCheck # TensorIterator
|
||||
variants: function, method
|
||||
|
||||
- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
|
||||
- func: sum.IntList_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
|
||||
structured: True
|
||||
device_check: NoCheck # TensorIterator
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -74,9 +74,6 @@ class OptionalArrayRef final {
|
|||
Args&&... args)
|
||||
: wrapped_opt_array_ref(ip, il, args...) {}
|
||||
|
||||
constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
|
||||
: wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
|
||||
|
||||
// Destructor
|
||||
|
||||
~OptionalArrayRef() = default;
|
||||
|
|
|
|||
|
|
@ -5365,7 +5365,7 @@ a")
|
|||
def func2(x):
|
||||
return x.sum(dim=4)
|
||||
|
||||
# test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument
|
||||
# test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
|
||||
self.run_pass('constant_propagation', func.graph)
|
||||
self.run_pass('constant_propagation', func2.graph)
|
||||
g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
|
||||
|
|
|
|||
|
|
@ -1195,11 +1195,8 @@ class TestNamedTensor(TestCase):
|
|||
check_output(op(t, 1), ['N', 'L'])
|
||||
check_output(op(t, -1), ['N', 'C'])
|
||||
check_output(op(t, 'C'), ['N', 'L'])
|
||||
if op.__name__ in ['sum']:
|
||||
check_output(op(t, None), [])
|
||||
else:
|
||||
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
|
||||
op(t, None)
|
||||
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
|
||||
op(t, None)
|
||||
with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'):
|
||||
op(t, 'H')
|
||||
|
||||
|
|
|
|||
|
|
@ -1517,7 +1517,7 @@
|
|||
self: grad.expand(self.sizes())
|
||||
result: auto_linear
|
||||
|
||||
- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
|
||||
- name: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
|
||||
self: sum_backward(grad, self.sizes(), dim, keepdim)
|
||||
result: auto_linear
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@
|
|||
#include <ATen/native/IndexingUtils.h>
|
||||
#include <ATen/native/LinearAlgebraUtils.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/util/OptionalArrayRef.h>
|
||||
#include <c10/util/SmallBuffer.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
|
@ -39,7 +38,6 @@ namespace details {
|
|||
|
||||
using at::areAnyTensorSubclassLike;
|
||||
using at::IntArrayRef;
|
||||
using at::OptionalIntArrayRef;
|
||||
using at::Scalar;
|
||||
using at::Tensor;
|
||||
using at::TensorList;
|
||||
|
|
@ -537,11 +535,8 @@ Tensor deg2rad_backward(const Tensor& grad) {
|
|||
return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_PI_180)));
|
||||
}
|
||||
|
||||
Tensor unsqueeze_multiple(
|
||||
const Tensor& t,
|
||||
OptionalIntArrayRef opt_dim,
|
||||
size_t n_dims) {
|
||||
auto dims_to_unsqueeze = at::dim_list_to_bitset(opt_dim, n_dims);
|
||||
Tensor unsqueeze_multiple(const Tensor& t, IntArrayRef dim, size_t n_dims) {
|
||||
auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
|
||||
Tensor res = t;
|
||||
for (const auto i : c10::irange(n_dims)) {
|
||||
if (dims_to_unsqueeze[i]) {
|
||||
|
|
@ -554,13 +549,13 @@ Tensor unsqueeze_multiple(
|
|||
Tensor sum_backward(
|
||||
const Tensor& grad,
|
||||
IntArrayRef sizes,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
IntArrayRef dims,
|
||||
bool keepdim) {
|
||||
if (!keepdim && sizes.size() > 0) {
|
||||
if (opt_dims.has_value() && opt_dims.value().size() == 1) {
|
||||
return grad.unsqueeze(opt_dims.value()[0]).expand(sizes);
|
||||
if (dims.size() == 1) {
|
||||
return grad.unsqueeze(dims[0]).expand(sizes);
|
||||
} else {
|
||||
Tensor res = unsqueeze_multiple(grad, opt_dims, sizes.size());
|
||||
Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
|
||||
return res.expand(sizes);
|
||||
}
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -146,12 +146,12 @@ at::Tensor rad2deg_backward(const at::Tensor& grad);
|
|||
at::Tensor deg2rad_backward(const at::Tensor& grad);
|
||||
at::Tensor unsqueeze_multiple(
|
||||
const at::Tensor& t,
|
||||
at::OptionalIntArrayRef opt_dim,
|
||||
at::IntArrayRef dim,
|
||||
size_t n_dims);
|
||||
at::Tensor sum_backward(
|
||||
const at::Tensor& grad,
|
||||
at::IntArrayRef sizes,
|
||||
at::OptionalIntArrayRef opt_dims,
|
||||
at::IntArrayRef dims,
|
||||
bool keepdim);
|
||||
at::Tensor nansum_backward(
|
||||
const at::Tensor& grad,
|
||||
|
|
|
|||
|
|
@ -2478,7 +2478,7 @@ class IrParser {
|
|||
|
||||
{
|
||||
auto ptr_op = getOperatorForLiteral(
|
||||
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
|
||||
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
|
||||
REGISTER_PARSE_RULE(
|
||||
ptr_op,
|
||||
{
|
||||
|
|
@ -3855,7 +3855,7 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
|
|||
|
||||
static auto reduction_operator_schema =
|
||||
getOperatorForLiteral(
|
||||
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
|
||||
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
|
||||
->schema();
|
||||
if (node->matches(reduction_operator_schema)) {
|
||||
switch (offset) {
|
||||
|
|
|
|||
|
|
@ -1980,7 +1980,7 @@ class ShapePropagator : public PropertyPropBase {
|
|||
return true;
|
||||
} else if (
|
||||
node->matches(
|
||||
"aten::sum(Tensor self, int[]? dim, bool keepdim, *, int? dtype) -> Tensor",
|
||||
"aten::sum(Tensor self, int[] dim, bool keepdim, *, int? dtype) -> Tensor",
|
||||
/*const_inputs=*/{attr::dim, attr::keepdim})) {
|
||||
auto& tp = tensor_types.at(0);
|
||||
auto sizes = tp->sizes().concrete_sizes().value();
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ bool isSupported(Node* node) {
|
|||
|
||||
static const OperatorSet supported_reduction_set{
|
||||
"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
|
||||
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
|
||||
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
|
||||
"aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor",
|
||||
"aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2158,54 +2158,6 @@ def transpose(self: List[int],
|
|||
_4 = torch.append(out, self[idx])
|
||||
return out
|
||||
|
||||
)=====")
|
||||
+ std::string(R"=====(def sum_dim(self: List[int],
|
||||
opt_dims: Optional[List[int]],
|
||||
keep_dim: bool,
|
||||
dt: Any) -> List[int]:
|
||||
out = annotate(List[int], [])
|
||||
if opt_dims is None:
|
||||
dims:List[int] = []
|
||||
else:
|
||||
dims = opt_dims
|
||||
for idx in range(torch.len(self)):
|
||||
is_mean_dim = False
|
||||
for _0 in range(torch.len(dims)):
|
||||
reduce_dim = dims[_0]
|
||||
_1 = torch.len(self)
|
||||
if torch.le(_1, 0):
|
||||
dim_post_expr = 1
|
||||
else:
|
||||
dim_post_expr = _1
|
||||
min = torch.neg(dim_post_expr)
|
||||
max = torch.sub(dim_post_expr, 1)
|
||||
if torch.lt(reduce_dim, min):
|
||||
_2 = True
|
||||
else:
|
||||
_2 = torch.gt(reduce_dim, max)
|
||||
if torch.__not__(_2):
|
||||
pass
|
||||
else:
|
||||
ops.prim.RaiseException("AssertionError: ")
|
||||
if torch.lt(reduce_dim, 0):
|
||||
dim0 = torch.add(reduce_dim, dim_post_expr)
|
||||
dim = dim0
|
||||
else:
|
||||
dim = reduce_dim
|
||||
if torch.eq(idx, dim):
|
||||
is_mean_dim0 = True
|
||||
else:
|
||||
is_mean_dim0 = is_mean_dim
|
||||
is_mean_dim = is_mean_dim0
|
||||
if is_mean_dim:
|
||||
if keep_dim:
|
||||
_3 = torch.append(out, 1)
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
_4 = torch.append(out, self[idx])
|
||||
return out
|
||||
|
||||
)=====")
|
||||
+ std::string(R"=====(def max_dim(self: List[int],
|
||||
dim: int,
|
||||
|
|
@ -2797,7 +2749,7 @@ const OperatorMap<std::string>& GetShapeFunctionMappings() {
|
|||
{"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "expand"},
|
||||
{"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "expand_one_unused"},
|
||||
{"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
|
||||
{"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_dim"},
|
||||
{"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
|
||||
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "max_dim"},
|
||||
{"aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},
|
||||
{"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},
|
||||
|
|
|
|||
|
|
@ -1691,10 +1691,10 @@ REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
|
|||
};
|
||||
}
|
||||
if (n->matches(torch::schema(
|
||||
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
|
||||
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const at::Tensor& self = p_node->Input(0).toTensor();
|
||||
auto dim = p_node->Input(1).toDimVector();
|
||||
auto dim = p_node->Input(1).toIntList().vec();
|
||||
auto keepdim = p_node->Input(2).toBool();
|
||||
auto dtype = p_node->Input(3).toOptional<at::ScalarType>();
|
||||
if (p_node->Output(0).isNone()) {
|
||||
|
|
|
|||
|
|
@ -1767,7 +1767,7 @@ int nnc_lowerings_lazy_registration() {
|
|||
|
||||
RegisterNNCLoweringsFunction aten_sum(
|
||||
{"aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)",
|
||||
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
|
||||
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
|
||||
computeSum);
|
||||
|
||||
RegisterNNCLoweringsFunction aten_softmax(
|
||||
|
|
|
|||
|
|
@ -1004,7 +1004,7 @@ add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)",
|
|||
add_shape_compute_mapping("aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand)
|
||||
add_shape_compute_mapping("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", expand_one_unused)
|
||||
add_shape_compute_mapping("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
|
||||
add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
|
||||
add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
|
||||
add_shape_compute_mapping("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", max_dim)
|
||||
add_shape_compute_mapping("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)
|
||||
add_shape_compute_mapping("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)
|
||||
|
|
|
|||
|
|
@ -18703,6 +18703,9 @@ op_db: List[OpInfo] = [
|
|||
# FIXME: sum reduces all dimensions when dim=[]
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
|
||||
# FIXME: sum does not support passing None to dim
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
|
||||
# FIXME: improve precision
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
|
||||
dtypes=[torch.float16]),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user