torch.sgn for complex tensors (#39955)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39955

Resolves https://github.com/pytorch/pytorch/issues/36323 by adding `torch.sgn` for complex tensors.
`torch.sgn` returns `x / abs(x)` for `x != 0` and `0 + 0j` for `x == 0`.
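
As an illustration of the intended semantics (the values follow from the definition above; the printed formatting is approximate, not copied from a test run):

>>> x = torch.tensor([3+4j, 0j, -2+0j])
>>> torch.sgn(x)  # x / abs(x) elementwise; 0+0j where x == 0
tensor([0.6000+0.8000j, 0.0000+0.0000j, -1.0000+0.0000j])
>>> torch.sgn(torch.tensor([-1.5, 0., 2.]))  # non-complex inputs fall back to torch.sign
tensor([-1., 0., 1.])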

This PR doesn't test the correctness of the gradients. That will be done as part of auditing all the ops in the future, once we decide on the autograd behavior (JAX vs. TF) and add gradcheck.
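
Once that lands, the check would look roughly like the sketch below (hypothetical; it assumes `torch.autograd.gradcheck` accepts complex inputs, which is not part of this PR):

>>> x = torch.randn(4, dtype=torch.cdouble, requires_grad=True)
>>> torch.autograd.gradcheck(torch.sgn, (x,), eps=1e-6, atol=1e-4)  # would numerically validate sgn_backward
True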

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D23460526

Pulled By: anjali411

fbshipit-source-id: 70fc4e14e4d66196e27cf188e0422a335fc42f92
anjali411 2020-09-22 08:01:16 -07:00 committed by Facebook GitHub Bot
parent 1b059f2c6d
commit 58b6ab69e5
22 changed files with 196 additions and 13 deletions

View File

@ -611,6 +611,7 @@ _(aten, sigmoid) \
_(aten, sign) \
_(aten, signbit) \
_(aten, silu) \
_(aten, sgn) \
_(aten, sin) \
_(aten, sinh) \
_(aten, size) \

View File

@ -239,6 +239,13 @@ public:
// Specifically map() does not perform the type conversion needed by abs.
return map([](T x) { return static_cast<T>(std::abs(x)); });
}
template <typename other_t_sgn = T,
typename std::enable_if<c10::is_complex<other_t_sgn>::value, int>::type = 0>
Vec256<T> sgn() const {
return map(at::native::sgn_impl);
}
template <typename other_t_angle = T,
typename std::enable_if<!c10::is_complex<other_t_angle>::value, int>::type = 0>
Vec256<T> angle() const {

View File

@ -134,6 +134,16 @@ public:
auto angle = _mm256_permute_pd(angle_(), 0x05); // angle 90-angle
return _mm256_and_pd(angle, real_mask); // angle 0
}
Vec256<c10::complex<double>> sgn() const {
auto abs = abs_();
auto zero = _mm256_setzero_pd();
auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);
auto abs_val = Vec256(abs);
auto div = values / abs_val.values; // x / abs(x)
return blendv(div, zero, mask);
}
__m256d real_() const {
const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));

View File

@ -171,6 +171,16 @@ public:
auto angle = _mm256_permute_ps(angle_(), 0xB1); // angle 90-angle
return _mm256_and_ps(angle, real_mask); // angle 0
}
Vec256<c10::complex<float>> sgn() const {
auto abs = abs_();
auto zero = _mm256_setzero_ps();
auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);
auto abs_val = Vec256(abs);
auto div = values / abs_val.values; // x / abs(x)
return _mm256_blendv_ps(div, zero, mask);
}
__m256 real_() const {
const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));

View File

@ -301,6 +301,17 @@ Tensor& sign_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(
Tensor sign(const Tensor& self) { return unary_op_impl(self, at::sign_out); }
Tensor& sign_(Tensor& self) { return unary_op_impl_(self, at::sign_out); }
Tensor& sgn_out(Tensor& result, const Tensor& self) {
if (self.is_complex()) {
return unary_op_impl_out(result, self, sgn_stub);
} else {
return unary_op_impl_out(result, self, sign_stub);
}
}
Tensor sgn(const Tensor& self) { return unary_op_impl(self, at::sgn_out); }
Tensor& sgn_(Tensor& self) { return unary_op_impl_(self, at::sgn_out); }
Tensor& sin_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sin_stub); }
Tensor sin(const Tensor& self) { return unary_op_impl(self, at::sin_out); }
Tensor& sin_(Tensor& self) { return unary_op_impl_(self, at::sin_out); }
@ -639,6 +650,7 @@ DEFINE_DISPATCH(sigmoid_stub);
DEFINE_DISPATCH(logit_stub);
DEFINE_DISPATCH(sign_stub);
DEFINE_DISPATCH(signbit_stub);
DEFINE_DISPATCH(sgn_stub);
DEFINE_DISPATCH(sin_stub);
DEFINE_DISPATCH(sinh_stub);
DEFINE_DISPATCH(sqrt_stub);

View File

@ -53,6 +53,7 @@ DECLARE_DISPATCH(unary_fn, sigmoid_stub);
DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub);
DECLARE_DISPATCH(unary_fn, sign_stub);
DECLARE_DISPATCH(unary_fn, signbit_stub);
DECLARE_DISPATCH(unary_fn, sgn_stub);
DECLARE_DISPATCH(unary_fn, sin_stub);
DECLARE_DISPATCH(unary_fn, sinh_stub);
DECLARE_DISPATCH(unary_fn, sqrt_stub);

View File

@ -270,16 +270,16 @@ static void sign_kernel(TensorIterator& iter){
auto one_vec = Vec256<scalar_t>(static_cast<scalar_t>(1));
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return (0 < a) - (a < 0); },
[=](Vec256<scalar_t> self_vec){
// Comparison operators return a bitmask.
auto left = Vec256<scalar_t>::blendv(zero_vec, one_vec, zero_vec < self_vec);
auto right = Vec256<scalar_t>::blendv(zero_vec, one_vec, self_vec < zero_vec);
return left - right;
});
});
}
}
@ -290,6 +290,15 @@ static void signbit_kernel(TensorIterator& iter){
});
}
static void sgn_kernel(TensorIterator& iter){
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sgn_cpu", [&]() {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return sgn_impl(a); },
[=](Vec256<scalar_t> a) { return a.sgn(); });
});
}
static void sinh_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "sinh_cpu", [&]() {
cpu_kernel_vec(
@ -639,6 +648,7 @@ REGISTER_DISPATCH(reciprocal_stub, &reciprocal_kernel);
REGISTER_DISPATCH(neg_stub, &neg_kernel);
REGISTER_DISPATCH(sign_stub, &sign_kernel);
REGISTER_DISPATCH(signbit_stub, &signbit_kernel);
REGISTER_DISPATCH(sgn_stub, &sgn_kernel);
REGISTER_DISPATCH(sinh_stub, &sinh_kernel);
REGISTER_DISPATCH(cosh_stub, &cosh_kernel);
REGISTER_DISPATCH(acosh_stub, &acosh_kernel);

View File

@ -138,6 +138,15 @@ inline c10::complex<double> ceil_impl (c10::complex<double> z) {
return c10::complex<double>(std::ceil(z.real()), std::ceil(z.imag()));
}
template<typename T>
inline c10::complex<T> sgn_impl (c10::complex<T> z) {
if (z == c10::complex<T>(0, 0)) {
return c10::complex<T>(0, 0);
} else {
return z / zabs(z);
}
}
template <typename TYPE>
inline TYPE floor_impl (TYPE z) {
return std::floor(z);

View File

@ -51,9 +51,26 @@ void signbit_kernel_cuda(TensorIterator& iter){
});
}
template<typename T>
__host__ __device__ static inline c10::complex<T> sgn_wrapper(c10::complex<T> z) {
if (z == c10::complex<T>(0, 0)) {
return c10::complex<T>(0, 0);
} else {
return z / std::abs(z);
}
}
void sgn_kernel_cuda(TensorIterator& iter){
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sgn_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return sgn_wrapper(a);
});
});
}
REGISTER_DISPATCH(logical_not_stub, &logical_not_kernel_cuda);
REGISTER_DISPATCH(neg_stub, &neg_kernel_cuda);
REGISTER_DISPATCH(sign_stub, &sign_kernel_cuda);
REGISTER_DISPATCH(signbit_stub, &signbit_kernel_cuda);
REGISTER_DISPATCH(sgn_stub, &sgn_kernel_cuda);
}} // namespace at::native

View File

@ -277,6 +277,15 @@
use_c10_dispatcher: full
variants: function
- func: sgn(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
- func: sgn_(Tensor(a!) self) -> Tensor(a!)
variants: method
- func: sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
- func: real(Tensor(a) self) -> Tensor(a)
use_c10_dispatcher: full
variants: function

View File

@ -91,6 +91,7 @@ unary_ops_list = op_bench.op_list(
['sigmoid', torch.sigmoid],
['sigmoid_', torch.sigmoid_],
['sign', torch.sign],
['sgn', torch.sgn],
['sin', torch.sin],
['sin_', torch.sin_],
['sinh', torch.sinh],

View File

@ -197,6 +197,8 @@ If you don't see an operation listed here, but it would help your use case, plea
:meth:`Tensor.sigmoid_`,None
":meth:`Tensor.sign`, :func:`torch.sign`",:ref:`keeps_input_names-doc`
:meth:`Tensor.sign_`,None
":meth:`Tensor.sgn`, :func:`torch.sgn`",:ref:`keeps_input_names-doc`
:meth:`Tensor.sgn_`,None
":meth:`Tensor.sin`, :func:`torch.sin`",:ref:`keeps_input_names-doc`
:meth:`Tensor.sin_`,None
":meth:`Tensor.sinh`, :func:`torch.sinh`",:ref:`keeps_input_names-doc`

View File

@ -532,6 +532,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: sign
.. automethod:: sign_
.. automethod:: signbit
.. automethod:: sgn
.. automethod:: sgn_
.. automethod:: sin
.. automethod:: sin_
.. automethod:: sinh

View File

@ -4692,7 +4692,7 @@ complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone'
'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_', 'round',
'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__'] + separate_complex_tests
'cosh', '__rmul__', 'sgn'] + separate_complex_tests
# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition - @anjali411
# complex_list += ['fill_', 't', '__rdiv__', 'tanh']

View File

@ -11288,6 +11288,19 @@ class TestTorchDeviceType(TestCase):
with self.assertRaisesRegex(RuntimeError, 'signbit is not implemented for complex tensors.'):
torch.signbit(t, out=out)
@dtypes(torch.cfloat, torch.cdouble)
def test_sgn(self, device, dtype):
x = torch.randn(100, dtype=dtype)
angle = x.angle()
out = x.sgn()
self.assertEqual(out.angle(), angle)
self.assertEqual(out.abs(), torch.ones_like(x).real)
x_out = torch.empty_like(x)
torch.sgn(x, out=x_out)
self.assertEqual(x_out.angle(), angle)
self.assertEqual(x_out.abs(), torch.ones_like(x).real)
@dtypes(*(torch.testing.get_all_dtypes(include_bool=False)))
def test_signbit_non_boolean_output(self, device, dtype):
# test non-boolean tensors as the `out=` parameters
@ -14709,6 +14722,8 @@ class TestTorchDeviceType(TestCase):
lambda x, y: x.logit_(1e-6),
lambda x, y: x.sign(),
lambda x, y: x.sign_(),
lambda x, y: x.sgn(),
lambda x, y: x.sgn_(),
lambda x, y: x.sin(),
lambda x, y: x.sin_(),
lambda x, y: x.sinh(),

View File

@ -159,7 +159,7 @@
# NB: The parameter names here MUST be consistent with the parameter names
# in Declarations.yaml
- name: abs(Tensor self) -> Tensor
self: grad * self.sign()
self: grad * self.sgn()
- name: acos(Tensor self) -> Tensor
self: grad * -((-self * self + 1).rsqrt())
@ -397,11 +397,11 @@
# of the higher order derivatives, see https://github.com/pytorch/pytorch/issues/43414
# Note that we don't use "result" because saving it would be BC-breaking when it is used in an inplace operation later
- name: div.Tensor(Tensor self, Tensor other) -> Tensor
self: grad / other
other: -grad * (self / other) / other
self: div_tensor_self_backward(grad, other, self.scalar_type())
other: div_tensor_other_backward(grad, self, other)
- name: div.Scalar(Tensor self, Scalar other) -> Tensor
self: grad / other
self: div_tensor_self_backward(grad, at::scalar_to_tensor(other), self.scalar_type())
- name: dot(Tensor self, Tensor tensor) -> Tensor
self: grad * tensor
@ -928,6 +928,9 @@
- name: sign(Tensor self) -> Tensor
self: zeros_like(grad)
- name: sgn(Tensor self) -> Tensor
self: sgn_backward(result, grad, self)
- name: sin(Tensor self) -> Tensor
self: grad * self.cos().conj()

View File

@ -3121,6 +3121,20 @@ signbit() -> Tensor
See :func:`torch.signbit`
""")
add_docstr_all('sgn',
r"""
sgn() -> Tensor
See :func:`torch.sgn`
""")
add_docstr_all('sgn_',
r"""
sgn_() -> Tensor
In-place version of :meth:`~Tensor.sgn`
""")
add_docstr_all('sin',
r"""
sin() -> Tensor

View File

@ -6603,6 +6603,31 @@ Example::
tensor([ False, True, False, False])
""".format(**common_args))
add_docstr(torch.sgn,
r"""
sgn(input, *, out=None) -> Tensor
For complex tensors, this function returns a new tensor whose elements have the same angle as that of the
elements of :attr:`input` and absolute value 1. For a non-complex tensor, this function
returns the signs of the elements of :attr:`input` (see :func:`torch.sign`).
:math:`\text{out}_{i} = 0`, if :math:`|{\text{{input}}_i}| == 0`
:math:`\text{out}_{i} = \frac{{\text{{input}}_i}}{|{\text{{input}}_i}|}`, otherwise
""" + r"""
Args:
{input}
Keyword args:
{out}
Example::
>>> x=torch.tensor([3+4j, 7-24j, 0, 1+2j])
>>> x.sgn()
tensor([0.6000+0.8000j, 0.2800-0.9600j, 0.0000+0.0000j, 0.4472+0.8944j])
""".format(**common_args))
add_docstr(torch.sin,
r"""
sin(input, out=None) -> Tensor

View File

@ -211,6 +211,17 @@ Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
return grad * args.digamma_().sum(-1);
}
Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) {
if (self.is_complex()) {
auto abs = at::abs(self);
// C -> C
// https://arxiv.org/pdf/1701.00392.pdf Section 4.20
return at::where(abs == 0.0, at::zeros({}, grad.options()), (grad/abs - (at::real(grad/self) * result)));
} else {
return at::zeros_like(grad, at::MemoryFormat::Preserve);
}
}
Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) {
auto result = grad * other.conj();
if (!at::isComplexType(self_st) && result.is_complex()) {
@ -220,6 +231,24 @@ Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) {
return result;
}
Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st) {
auto result = grad / other.conj();
if (!at::isComplexType(self_st) && result.is_complex()) {
// R -> C
result = at::real(result);
}
return result;
}
Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) {
auto result = -grad * ((self / other) / other).conj();
if (!other.is_complex() && result.is_complex()) {
// R -> C
result = at::real(result);
}
return result;
}
Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
// invert the permutation
auto ndims = fwd_dims.size();

View File

@ -44,6 +44,8 @@ at::Tensor pow_backward_self(at::Tensor grad, const at::Tensor & self, const at:
at::Tensor pow_backward_exponent(at::Tensor grad, const at::Tensor& self, const at::Tensor& exponent, at::Tensor result);
at::Tensor pow_backward_exponent(at::Tensor grad, const at::Scalar & base, const at::Tensor& exponent, at::Tensor result);
at::Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st);
at::Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st);
at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other);
at::Tensor mvlgamma_backward(at::Tensor grad, const at::Tensor & self, int64_t p);
at::Tensor permute_backwards(const at::Tensor & grad, at::IntArrayRef fwd_dims);
at::Tensor rad2deg_backward(const at::Tensor& grad);
@ -74,6 +76,7 @@ at::Tensor sum_tensorlist(at::TensorList tl);
at::Tensor repeat_backward(at::Tensor grad, int64_t input_dims, at::IntArrayRef repeats);
at::Tensor _fused_dropout_backward(at::Tensor grad, at::Tensor mask, double p1m);
at::Tensor evenly_distribute_backward(at::Tensor grad, const at::Tensor & input, const at::Tensor & value);
at::Tensor sgn_backward(Tensor result, Tensor grad, Tensor self);
at::Tensor var_backward(const at::Tensor & grad, const at::Tensor & self, bool unbiased);
at::Tensor var_backward(at::Tensor grad, const at::Tensor & self, at::IntArrayRef dim, bool unbiased, bool keepdim);
at::Tensor std_backward(const at::Tensor & result, const at::Tensor & grad, const at::Tensor & self, bool unbiased);

View File

@ -701,6 +701,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.sigmoid: lambda input, out=None: -1,
torch.sign: lambda input, out=None: -1,
torch.signbit: lambda input, out=None: -1,
torch.sgn: lambda input, out=None: -1,
torch.sin: lambda input, out=None: -1,
torch.sinh: lambda input, out=None: -1,
torch.slogdet: lambda input: -1,

View File

@ -696,6 +696,8 @@ def method_tests():
('round', (), NO_ARGS, 'scalar', (True,)),
('sign', (S, S, S), NO_ARGS),
('sign', (), NO_ARGS, 'scalar'),
('sgn', (S, S, S), NO_ARGS),
('sgn', (), NO_ARGS, 'scalar'),
('trunc', (S, S, S), NO_ARGS, '', (True,)),
('trunc', (), NO_ARGS, 'scalar', (True,)),
('floor', (S, S, S), NO_ARGS, '', (True,)),