[MPS] add support for sgn to MPS backend (#110829)

Fixes #86805

Adds support for sgn to MPS backend.
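As a quick illustration (mine, not from the PR's test suite): sgn(z) = z / |z| for nonzero complex z and 0 at z = 0, so on an MPS build this should now run natively instead of hitting the fallback:

```python
import torch

# sgn on a complex MPS tensor; "sgn.out" previously routed to the CPU fallback.
z = torch.tensor([3 + 4j, 0j, -1 + 1j], device="mps")
print(torch.sgn(z))
# expected: tensor([0.6000+0.8000j, 0.0000+0.0000j, -0.7071+0.7071j], device='mps:0')
```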

Notes:

1. @malfet self-assigned this while he was working on implementing polar, but from what I can tell, he didn't end up needing to implement sgn for that work.

2. @Berzeg implemented this last year, before view_as_complex was supported on MPS. Thanks to @malfet's recent contributions, however, his implementation now works. I've removed the part of it that handled non-complex dtypes (those can simply be forwarded to at::sign), matched the more recent pattern we've been using in UnaryOps.mm, and added a simple MPS implementation of _efficientzerotensor so that the backward function works (see the sketch after these notes).
3. @Berzeg deserves a good bit of credit for this, so let me know if there's a way to credit him without jamming up the PR (he seems to have been AWOL since he last worked on this).
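A minimal check of the backward path mentioned in note 2, assuming a real-valued input (the derivative of sgn is zero wherever it is defined, which autograd materializes as a zero tensor):

```python
import torch

# Backward through sgn on MPS: this path is what needed _efficientzerotensor_mps.
x = torch.randn(4, device="mps", requires_grad=True)
torch.sgn(x).sum().backward()
print(x.grad)  # all zeros: sgn is flat wherever it is differentiable
```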

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110829
Approved by: https://github.com/malfet
Author: igm503
Committed by: PyTorch MergeBot, 2023-10-09 16:53:25 +00:00
Commit: 4b881b0da3 (parent: 144cda7f06)
5 changed files with 55 additions and 2 deletions

File: MPSFallback.mm

@@ -92,7 +92,6 @@ TORCH_LIBRARY_IMPL(aten, MPS, m) {
   m.impl("linalg_svd.U", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
   m.impl("im2col", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); // Used in preprocessing by nn.Unfold
   m.impl("col2im", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
-  m.impl("sgn.out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
   m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps);
   m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
 }

File: TensorFactory.cpp

@@ -10,6 +10,14 @@
 #include <ATen/native/ResizeCommon.h>
 #include <ATen/native/mps/Copy.h>
 #include <ATen/native/mps/TensorFactory.h>
+#include <ATen/Dispatch.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#endif
+
+#include <ATen/ops/_efficientzerotensor_native.h>
 
 namespace at::native {
 static inline void maybe_resize_storage_mps(TensorImpl* self, uint64_t new_size) {
@@ -140,4 +148,17 @@ Tensor& set_storage_mps_(Tensor& result, Storage storage, int64_t storage_offset
   return result;
 }
 
+Tensor _efficientzerotensor_mps(IntArrayRef size,
+                                c10::optional<ScalarType> dtype,
+                                c10::optional<Layout> layout,
+                                c10::optional<Device> device,
+                                c10::optional<bool> pin_memory) {
+  auto device_ = device_or_default(device);
+  auto allocator = at::native::ZeroTensorAllocator(device_);
+  auto dtype_ = dtype_or_default(dtype);
+  auto zero_ks = at::DispatchKeySet(c10::DispatchKey::MPS) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
+  auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt);
+  return out;
+}
+
 } // namespace at::native
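For context, _efficientzerotensor creates a tensor that carries the ZeroTensor dispatch key instead of allocating real data. A hypothetical probe through the dispatcher (not part of the PR):

```python
import torch

# The result carries the ZeroTensor dispatch key rather than real storage;
# the kernel above just builds the right key set over an empty allocation.
z = torch.ops.aten._efficientzerotensor([2, 3], device="mps")
print(z.shape, z.dtype)  # torch.Size([2, 3]) torch.float32
```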

File: UnaryOps.mm

@@ -36,9 +36,12 @@
 #include <ATen/ops/logit_native.h>
 #include <ATen/ops/neg_native.h>
 #include <ATen/ops/reciprocal_native.h>
+#include <ATen/ops/reshape.h>
 #include <ATen/ops/round_native.h>
 #include <ATen/ops/rsqrt_native.h>
+#include <ATen/ops/sgn_native.h>
 #include <ATen/ops/sigmoid_native.h>
+#include <ATen/ops/sign.h>
 #include <ATen/ops/sign_native.h>
 #include <ATen/ops/signbit_native.h>
 #include <ATen/ops/sin_native.h>
@@ -47,6 +50,7 @@
 #include <ATen/ops/tan_native.h>
 #include <ATen/ops/tanh_native.h>
 #include <ATen/ops/trunc_native.h>
+#include <ATen/ops/view_as_real.h>
 #endif
 
 namespace at::native {
@@ -453,4 +457,32 @@ TORCH_IMPL_FUNC(cumprod_out_mps)
   return cumulative_op_impl(self, dim, dtype, result, MPSCumulativeOpType::CUMPROD, "cumprod_out_mps");
 }
 
+TORCH_IMPL_FUNC(sgn_out_mps)(const Tensor& self, const Tensor& output) {
+  if (!self.is_complex()) {
+    Tensor output_copy = output.alias();
+    at::sign_out(output_copy, self);
+    output.copy_(output_copy);
+    return;
+  }
+
+  if (!output.is_same_size(self)) {
+    output.resize_(self.sizes());
+  }
+
+  Tensor realInput = at::view_as_real(self);
+  Tensor realOutput = at::view_as_real(output);
+
+  auto complex_sgn_op = [&](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) -> MPSGraphTensor* {
+    MPSGraphTensor* squares = [mpsGraph squareWithTensor:inputTensor name:nil];
+    MPSGraphTensor* sumSquares = [mpsGraph reductionSumWithTensor:squares axis:-1 name:nil];
+    MPSGraphTensor* norm = [mpsGraph squareRootWithTensor:sumSquares name:nil];
+    MPSGraphTensor* zero = [mpsGraph constantWithScalar:0.0 dataType:norm.dataType];
+    MPSGraphTensor* isZero = [mpsGraph equalWithPrimaryTensor:norm secondaryTensor:zero name:nil];
+    MPSGraphTensor* sgnTensor = [mpsGraph divisionWithPrimaryTensor:inputTensor secondaryTensor:norm name:nil];
+    return [mpsGraph selectWithPredicateTensor:isZero truePredicateTensor:zero falsePredicateTensor:sgnTensor name:nil];
+  };
+
+  mps::unary_op(realInput, realOutput, "sgn_out_mps", complex_sgn_op);
+}
+
 } // namespace at::native
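The complex_sgn_op graph above operates on the real view of the tensor, where the last axis holds [re, im]. A rough PyTorch rendering of the same math (a sketch of mine, not code from the PR):

```python
import torch

# Same math as the graph: view complex as (..., 2) real pairs, take the norm
# over the last axis, divide, and select 0 wherever the norm is 0.
z = torch.tensor([3 + 4j, 0j])
pairs = torch.view_as_real(z)                       # [[3., 4.], [0., 0.]]
norm = pairs.square().sum(-1, keepdim=True).sqrt()  # |z|, kept as (..., 1)
sgn = torch.where(norm == 0, torch.zeros_like(pairs), pairs / norm)
print(torch.view_as_complex(sgn))                   # tensor([0.6+0.8j, 0.+0.j])
```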

File: native_functions.yaml

@@ -431,6 +431,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA: sgn_out
+    MPS: sgn_out_mps
     SparseCPU, SparseCUDA: sgn_sparse_out
     SparseCsrCPU, SparseCsrCUDA: sgn_sparse_csr_out
   tags: pointwise
@@ -6294,6 +6295,7 @@
   dispatch:
     CPU: _efficientzerotensor
     CUDA: _efficientzerotensor_cuda
+    MPS: _efficientzerotensor_mps
     Meta: _efficientzerotensor_meta
   autogen: _efficientzerotensor.out

File: test_mps.py

@@ -69,7 +69,6 @@ def mps_ops_grad_modifier(ops):
         # Unimplemented ops
         '__getitem__': [torch.float16],
-        'sgn': [torch.float16, torch.float32],
         '_segment_reduce': [torch.float16, torch.float32],
         'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
         'unfold': [torch.float16, torch.float32],