[MPS] add support for sgn to MPS backend (#110829)
Fixes #86805

Adds support for sgn to the MPS backend.

Notes:
1. @malfet self-assigned this while he was working on implementing polar, but from what I can tell, he didn't end up needing to implement it.
2. @Berzeg implemented this last year, before view_as_complex was supported. Thanks to @malfet's recent contributions, however, @Berzeg's implementation now works. I've removed the part of his implementation that dealt with non-complex dtypes (since those can just be passed to at::sign), matched the more recent pattern we've been using in UnaryOps.mm, and thrown in a simple implementation of _efficientzerotensor for MPS so that the backward function works.
3. @Berzeg deserves a good bit of credit for this, so let me know if there's a way to assign him some without jamming up the PR (he seems to have been AWOL since last working on this).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110829
Approved by: https://github.com/malfet
This commit is contained in:
parent 144cda7f06
commit 4b881b0da3
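For context, sgn generalizes sign to complex numbers: sgn(z) = z / |z| for z != 0, and 0 otherwise. A quick illustration of the user-visible behavior (a hypothetical snippet, not part of the diff):

    import torch

    z = torch.tensor([3 + 4j, 0j, -2 + 0j])
    print(torch.sgn(z))
    # tensor([0.6000+0.8000j, 0.0000+0.0000j, -1.0000+0.0000j])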
@@ -92,7 +92,6 @@ TORCH_LIBRARY_IMPL(aten, MPS, m) {
   m.impl("linalg_svd.U", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
   m.impl("im2col", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); // Used in preprocessing by nn.Unfold
   m.impl("col2im", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
-  m.impl("sgn.out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
   m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps);
   m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
 }
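The hunk above removes the CPU-fallback registration for sgn.out, so sgn on MPS now dispatches to the native kernel added below instead of round-tripping through the CPU. A hypothetical sanity check (assumes an MPS-capable machine):

    import torch

    if torch.backends.mps.is_available():
        z = torch.tensor([3 + 4j, 0j], device="mps")
        y = torch.sgn(z)
        assert y.device.type == "mps"
        print(y)  # expected to match the CPU result: [0.6+0.8j, 0j]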
@@ -10,6 +10,14 @@
 #include <ATen/native/ResizeCommon.h>
 #include <ATen/native/mps/Copy.h>
 #include <ATen/native/mps/TensorFactory.h>
+#include <ATen/Dispatch.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#endif
+#include <ATen/ops/_efficientzerotensor_native.h>
+
 namespace at::native {

 static inline void maybe_resize_storage_mps(TensorImpl* self, uint64_t new_size) {
@@ -140,4 +148,17 @@ Tensor& set_storage_mps_(Tensor& result, Storage storage, int64_t storage_offset
   return result;
 }
+
+Tensor _efficientzerotensor_mps(IntArrayRef size,
+                                c10::optional<ScalarType> dtype,
+                                c10::optional<Layout> layout,
+                                c10::optional<Device> device,
+                                c10::optional<bool> pin_memory) {
+  auto device_ = device_or_default(device);
+  auto allocator = at::native::ZeroTensorAllocator(device_);
+  auto dtype_ = dtype_or_default(dtype);
+  auto zero_ks = at::DispatchKeySet(c10::DispatchKey::MPS) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
+  auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt);
+  return out;
+}

 } // namespace at::native
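_efficientzerotensor creates a tensor carrying the ZeroTensor dispatch key, letting autograd hand back an all-zeros gradient without materializing real storage; per the commit message, sgn's backward relies on it on MPS. A hypothetical illustration via the private Python binding (name and availability may vary by version):

    import torch

    z = torch._efficientzerotensor((2, 3), device="mps")
    print(z._is_zerotensor())  # True: backed by the ZeroTensor key, no real storage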
@@ -36,9 +36,12 @@
 #include <ATen/ops/logit_native.h>
 #include <ATen/ops/neg_native.h>
 #include <ATen/ops/reciprocal_native.h>
+#include <ATen/ops/reshape.h>
 #include <ATen/ops/round_native.h>
 #include <ATen/ops/rsqrt_native.h>
+#include <ATen/ops/sgn_native.h>
 #include <ATen/ops/sigmoid_native.h>
+#include <ATen/ops/sign.h>
 #include <ATen/ops/sign_native.h>
 #include <ATen/ops/signbit_native.h>
 #include <ATen/ops/sin_native.h>
@@ -47,6 +50,7 @@
 #include <ATen/ops/tan_native.h>
 #include <ATen/ops/tanh_native.h>
 #include <ATen/ops/trunc_native.h>
+#include <ATen/ops/view_as_real.h>
 #endif

 namespace at::native {
@@ -453,4 +457,32 @@ TORCH_IMPL_FUNC(cumprod_out_mps)
   return cumulative_op_impl(self, dim, dtype, result, MPSCumulativeOpType::CUMPROD, "cumprod_out_mps");
 }
+
+TORCH_IMPL_FUNC(sgn_out_mps)(const Tensor& self, const Tensor& output) {
+  if (!self.is_complex()) {
+    Tensor output_copy = output.alias();
+    at::sign_out(output_copy, self);
+    output.copy_(output_copy);
+    return;
+  }
+
+  if (!output.is_same_size(self)) {
+    output.resize_(self.sizes());
+  }
+
+  Tensor realInput = at::view_as_real(self);
+  Tensor realOutput = at::view_as_real(output);
+
+  auto complex_sgn_op = [&](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) -> MPSGraphTensor* {
+    MPSGraphTensor* squares = [mpsGraph squareWithTensor:inputTensor name:nil];
+    MPSGraphTensor* sumSquares = [mpsGraph reductionSumWithTensor:squares axis:-1 name:nil];
+    MPSGraphTensor* norm = [mpsGraph squareRootWithTensor:sumSquares name:nil];
+    MPSGraphTensor* zero = [mpsGraph constantWithScalar:0.0 dataType:norm.dataType];
+    MPSGraphTensor* isZero = [mpsGraph equalWithPrimaryTensor:norm secondaryTensor:zero name:nil];
+    MPSGraphTensor* sgnTensor = [mpsGraph divisionWithPrimaryTensor:inputTensor secondaryTensor:norm name:nil];
+    return [mpsGraph selectWithPredicateTensor:isZero truePredicateTensor:zero falsePredicateTensor:sgnTensor name:nil];
+  };
+
+  mps::unary_op(realInput, realOutput, "sgn_out_mps", complex_sgn_op);
+}

 } // namespace at::native
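This hunk (in UnaryOps.mm, per the commit message) handles the complex case by viewing input and output as real tensors with a trailing (re, im) axis, then running one MPSGraph: square, sum over that axis, sqrt to get |z|, divide z by |z|, and select 0 wherever |z| == 0 so 0/0 never produces a NaN. A hypothetical Python reference of the same math (for clarity only, not the shipped kernel):

    import torch

    def complex_sgn_reference(z: torch.Tensor) -> torch.Tensor:
        r = torch.view_as_real(z)                            # (..., 2) float view
        norm = r.square().sum(dim=-1, keepdim=True).sqrt()   # |z|, kept broadcastable
        out = torch.where(norm == 0, torch.zeros_like(r), r / norm)
        return torch.view_as_complex(out)

    z = torch.tensor([3 + 4j, 0j])
    print(complex_sgn_reference(z))  # tensor([0.6000+0.8000j, 0.0000+0.0000j])
    print(torch.sgn(z))              # should match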
@@ -431,6 +431,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA: sgn_out
+    MPS: sgn_out_mps
     SparseCPU, SparseCUDA: sgn_sparse_out
     SparseCsrCPU, SparseCsrCUDA: sgn_sparse_csr_out
   tags: pointwise
@@ -6294,6 +6295,7 @@
   dispatch:
     CPU: _efficientzerotensor
     CUDA: _efficientzerotensor_cuda
+    MPS: _efficientzerotensor_mps
     Meta: _efficientzerotensor_meta
   autogen: _efficientzerotensor.out

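With both native_functions.yaml entries in place, the dispatcher routes aten::sgn.out and aten::_efficientzerotensor to the new MPS kernels. The zero tensor matters because the backward of sgn for real inputs is identically zero, which autograd represents as an efficient zero tensor rather than allocated zeros. A hypothetical end-to-end check (assumes an MPS device):

    import torch

    x = torch.randn(3, device="mps", requires_grad=True)
    torch.sgn(x).sum().backward()
    print(x.grad)  # expected: all zeros, supplied via _efficientzerotensor_mps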
@@ -69,7 +69,6 @@ def mps_ops_grad_modifier(ops):

         # Unimplemented ops
         '__getitem__': [torch.float16],
-        'sgn': [torch.float16, torch.float32],
         '_segment_reduce': [torch.float16, torch.float32],
         'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
         'unfold': [torch.float16, torch.float32],
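The last hunk drops 'sgn' from the list of ops whose gradient tests are skipped on MPS, since float16/float32 grads should now pass. A hypothetical spot-check against CPU (assumes an MPS device):

    import torch

    x_cpu = torch.randn(5, requires_grad=True)
    x_mps = x_cpu.detach().to("mps").requires_grad_()
    torch.sgn(x_cpu).sum().backward()
    torch.sgn(x_mps).sum().backward()
    print(torch.equal(x_cpu.grad, x_mps.grad.cpu()))  # expected: True (both all-zero)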