Complex gradcheck logic (#43208)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43208

This PR adds gradcheck support for complex functions. The logic used for complex gradcheck is described in Section 3.5.3 of https://arxiv.org/pdf/1701.00392.pdf
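
As a rough, self-contained sketch of that logic (not the PR's exact code; the helper name `numerical_wirtinger_grad` is illustrative), the numerical check takes central differences of `s = fn(z)` along the real and imaginary axes and combines them into Wirtinger and conjugate Wirtinger derivatives:

def numerical_wirtinger_grad(fn, z, eps=1e-3, grad_out=1.0 + 0j):
    # central differences of s = fn(z) along the real and imaginary axes
    ds_dx = (fn(z + eps) - fn(z - eps)) / (2 * eps)
    ds_dy = (fn(z + eps * 1j) - fn(z - eps * 1j)) / (2 * eps)
    w_d = 0.5 * (ds_dx - ds_dy * 1j)       # Wirtinger derivative ds/dz
    conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)  # conjugate Wirtinger derivative ds/dz*
    # gradient propagated for a given grad output, per Section 3.5.3 of the paper
    return grad_out.conjugate() * conj_w_d + grad_out * w_d.conjugate()

For a holomorphic function such as `lambda z: z * z` at `z = 1 + 2j`, this returns approximately `2 - 4j` for `grad_out = 1`, i.e. the conjugate of the complex derivative `2 * z`.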

More concretely, this PR introduces the following changes:
1. Updates `get_numerical_jacobian` to take as input a scalar grad output value (v). Adds gradcheck logic for C -> C, C -> R, and R -> C functions. For R -> C functions, only the real part of the gradient is propagated.
2. Adds a backward definition for `torch.complex` and a test to verify it.
3. Updates the backward formulas for `mul`, `sin`, `cos`, `sinh`, and `cosh` (a usage sketch follows this list).
4. Adds tests for `torch.real`, `torch.imag`, `torch.view_as_real`, `torch.view_as_complex`, and `torch.conj`.
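
A minimal usage sketch, assuming a PyTorch build that includes this change, checking one of the updated backward formulas against the new numerical logic:

import torch
from torch.autograd import gradcheck

# complex double input; gradcheck now perturbs it along both the real and imaginary axes
x = torch.randn(3, dtype=torch.cdouble, requires_grad=True)
print(gradcheck(torch.sin, (x,)))  # expected to print True once this PR is applied

As with real inputs, double-precision (`torch.cdouble`) inputs are needed for the default tolerances to hold.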

Follow-up tasks:
1. Add more thorough tests for R -> C cases. Specifically, add R -> C test variants for functions, e.g., `torch.mul(complex_tensor, real_tensor)`.
2. Add back the commented-out tests in `common_methods_invocations.py`.
3. Add more special-case checks to complex gradcheck to make debugging easier.
4. Update the complex autograd note.
5. Disable complex autograd for operators not tested with complex inputs.

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D23655088

Pulled By: anjali411

fbshipit-source-id: caa75e09864b5f6ead0f988f6368dce64cf15deb
Author: anjali411, 2020-09-20 22:01:58 -07:00 (committed by Facebook GitHub Bot)
Parent: da7863f46b
Commit: 9f67176b82
9 changed files with 231 additions and 135 deletions

View File

@@ -4680,18 +4680,22 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
# the tests for these ops which do not have 'complex' in variant should not run for complex
# and only run for floating point
separate_complex_tests = ['log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan']
# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition
separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos'] # ['log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan']
# NOTE: Some non-holomorphic functions are separately tested in TestAutogradComplex until gradcheck works properly
# for non-holomorphic functions
# allow list for complex
complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'zero_', 'clone',
'tril', 'triu', 'fill_', 'eq_', 'ne_', 'permute', 'squeeze', 'unsqueeze',
'chunk', 'split', 'split_with_sizes', 'resize', 'resize_as', 'sin', 'cos',
'__rmul__', '__rdiv__', 'sum', 'transpose', 'round', 'add', 'roll',
'__radd__', 'repeat', 'expand', 'mul', 'tanh', 'flip', 'fliplr', 'flipud',
'rot90'] + separate_complex_tests
complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone',
'repeat', 'expand', 'flip', 'fliplr', 'flipud', 'rot90', 'transpose',
'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_', 'round',
'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__'] + separate_complex_tests
# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition - @anjali411
# complex_list += ['fill_', 't', '__rdiv__', 'tanh']
def add_test(
name,
@@ -4721,7 +4725,7 @@ def add_test(
if dtype.is_complex:
# TODO: remove this. this is temporary while we ramp up the complex support.
if name in complex_list and 'scalar' not in test_name and 'constant' not in test_name:
if name in complex_list:
if name in separate_complex_tests and 'complex' not in variant_name:
continue
if not run_only_complex:
@@ -4787,7 +4791,13 @@ def add_test(
self_variable = create_input((self_size,), requires_grad=True, dtype=dtype)[0][0]
args_variable, kwargs_variable = create_input(args, requires_grad=False, call_kwargs=kwargs, dtype=dtype)
if hasattr(self_variable, name):
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
attribute_result = getattr(self_variable, name)
if callable(attribute_result):
output_variable = attribute_result(*args_variable, **kwargs_variable)
else:
self.assertTrue(len(args_variable) == 0)
self.assertTrue(len(kwargs_variable) == 0)
output_variable = attribute_result
else:
self_and_args_variable = (self_variable,) + args_variable
output_variable = torch_fn(*self_and_args_variable, **kwargs_variable)
@@ -4865,30 +4875,6 @@ def add_test(
setattr(TestAutogradDeviceType, test_name, do_test)
class TestAutogradComplex(TestCase):
# remove this test after gradcheck support is added for non-holomorphic functions
def test_real(self):
x = torch.randn(3, 4, 5, dtype=torch.cdouble, requires_grad=True)
x.real.sum().backward()
self.assertEqual(x.grad, torch.ones_like(x))
# remove this test after gradcheck support is added for non-holomorphic functions
def test_imag(self):
x = torch.randn(3, 4, 5, dtype=torch.cdouble, requires_grad=True)
x.imag.sum().backward()
self.assertEqual(x.grad, -1j * torch.ones_like(x))
# remove this test after gradcheck support is added for non-holomorphic functions
def test_view_as_real(self):
x = torch.randn(10, dtype=torch.cdouble, requires_grad=True)
torch.view_as_real(x).sum().backward()
self.assertEqual(x.grad, torch.full_like(x, 1 - 1j))
# remove this test after gradcheck support is added for non-holomorphic functions
def test_view_as_complex(self):
x = torch.randn(10, 2, dtype=torch.double, requires_grad=True)
torch.view_as_complex(x).sum().backward()
self.assertEqual(x.grad, torch.tensor([1, 0], dtype=torch.double).expand_as(x))
def test_view_func_for_complex_views(self):
# case 1: both parent and child have view_func
x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)

View File

@@ -15583,7 +15583,7 @@ def add_autograd_test(
# Disable complex tests
# TODO: Add complex support for jit
if 'complex' in variant_name:
if 'complex' in variant_name or name in ['view_as_complex', 'complex']:
return
# Skips aliases, which are tested in test_op_aliases.py

View File

@@ -89,35 +89,41 @@ class TestGradients(TestCase):
return self._check_helper(device, dtype, op, variant, 'gradgradcheck')
# Tests that gradients are computed correctly
@dtypes(torch.double, torch.cdouble)
# TODO(@anjali411) enable this for torch.cdouble.
@dtypes(torch.double)
@ops(op_db)
def test_fn_grad(self, device, dtype, op):
self._grad_test_helper(device, dtype, op, op.get_op())
@dtypes(torch.double, torch.cdouble)
# TODO(@anjali411) enable this for torch.cdouble.
@dtypes(torch.double)
@ops(op_db)
def test_method_grad(self, device, dtype, op):
self._grad_test_helper(device, dtype, op, op.get_method())
@dtypes(torch.double, torch.cdouble)
# TODO(@anjali411) enable this for torch.cdouble.
@dtypes(torch.double)
@ops(op_db)
def test_inplace_grad(self, device, dtype, op):
if not op.test_inplace_grad:
self.skipTest("Skipped! Inplace gradcheck marked to skip.")
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
# TODO(@anjali411) enable this for torch.cdouble.
# Test that gradients of gradients are computed correctly
@dtypes(torch.double, torch.cdouble)
@dtypes(torch.double)
@ops(op_db)
def test_fn_gradgrad(self, device, dtype, op):
self._gradgrad_test_helper(device, dtype, op, op.get_op())
@dtypes(torch.double, torch.cdouble)
# TODO(@anjali411) enable this for torch.cdouble.
@dtypes(torch.double)
@ops(op_db)
def test_method_gradgrad(self, device, dtype, op):
self._gradgrad_test_helper(device, dtype, op, op.get_method())
@dtypes(torch.double, torch.cdouble)
# TODO(@anjali411) enable this for torch.cdouble.
@dtypes(torch.double)
@ops(op_db)
def test_inplace_gradgrad(self, device, dtype, op):
if not op.test_inplace_grad:

View File

@@ -694,6 +694,9 @@ class Wrapper:
def __add__(self, other):
return self.__torch_function__(torch.add, (Wrapper,), (self, other))
def __mul__(self, other):
return self.__torch_function__(torch.mul, (Wrapper,), (self, other))
def __sub__(self, other):
return self.__torch_function__(torch.sub, (Wrapper,), (self, other))
@@ -757,51 +760,51 @@ class TestEinsumOverride(TestCase):
self.assertTrue(torch.allclose(torch.einsum('ik,jkl,il->ij', [a, b, c]),
torch.nn.functional.bilinear(a, c, b)))
# TODO(@anjali411): re-enable this test
# class TestGradCheckOverride(TestCase):
# "Test that wrappers work with gradcheck."
# def test_gradcheck(self):
# from torch.autograd import gradcheck
class TestGradCheckOverride(TestCase):
"Test that wrappers work with gradcheck."
def test_gradcheck(self):
from torch.autograd import gradcheck
# a = wrap(torch.tensor(5.0, dtype=torch.double))
# b = wrap(torch.tensor(6.0, dtype=torch.double))
a = wrap(torch.tensor(5.0, dtype=torch.double))
b = wrap(torch.tensor(6.0, dtype=torch.double))
# a.requires_grad = True
# b.requires_grad = True
a.requires_grad = True
b.requires_grad = True
# gradcheck(torch.add, (a, b), raise_exception=False)
gradcheck(torch.add, (a, b), raise_exception=False)
# total_used_attrs = a.used_attrs.union(b.used_attrs)
# total_used_calls = a.used_calls.union(b.used_calls)
total_used_attrs = a.used_attrs.union(b.used_attrs)
total_used_calls = a.used_calls.union(b.used_calls)
# # These attributes (and the functions below) may change
# # if the gradcheck implementation changes. It's best to
# # aim for attributes that may be commonly present on other
# # Tensor-likes.
# self.assertEqual(total_used_attrs, {
# 'data',
# 'dtype',
# 'is_floating_point',
# 'is_sparse',
# 'layout',
# 'nelement',
# 'new_zeros',
# 'requires_grad',
# 'retain_grad',
# 'size',
# 'stride',
# })
# These attributes (and the functions below) may change
# if the gradcheck implementation changes. It's best to
# aim for attributes that may be commonly present on other
# Tensor-likes.
self.assertEqual(total_used_attrs, {
'data',
'dtype',
'is_floating_point',
'is_sparse',
'layout',
'nelement',
'new_zeros',
'requires_grad',
'retain_grad',
'size',
'stride',
})
self.assertEqual(total_used_calls, {
torch.Tensor.new_zeros,
torch.Tensor.size,
torch.Tensor.is_floating_point,
torch.Tensor.nelement,
torch.Tensor.retain_grad,
torch.Tensor.stride,
torch.autograd.grad,
torch.add,
})
# self.assertEqual(total_used_calls, {
# torch.Tensor.new_zeros,
# torch.Tensor.size,
# torch.Tensor.is_floating_point,
# torch.Tensor.nelement,
# torch.Tensor.retain_grad,
# torch.Tensor.stride,
# torch.autograd.grad,
# torch.add,
# })
if __name__ == '__main__':

View File

@@ -330,8 +330,8 @@
self: grad
- name: complex(Tensor real, Tensor imag) -> Tensor
real: not_implemented("complex real")
imag: not_implemented("complex imag")
real: at::real(grad)
imag: at::imag(grad)
- name: polar(Tensor abs, Tensor angle) -> Tensor
abs: not_implemented("polar abs")
@@ -341,10 +341,10 @@
self: grad.conj()
- name: cos(Tensor self) -> Tensor
self: grad * -self.sin()
self: grad * -self.sin().conj()
- name: cosh(Tensor self) -> Tensor
self: grad * self.sinh()
self: grad * self.sinh().conj()
- name: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
self: not_implemented("count_nonzero")
@@ -736,11 +736,11 @@
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
- name: mul.Tensor(Tensor self, Tensor other) -> Tensor
self: grad * other
other: grad * self
self: mul_tensor_backward(grad, other, self.scalar_type())
other: mul_tensor_backward(grad, self, other.scalar_type())
- name: mul.Scalar(Tensor self, Scalar other) -> Tensor
self: grad * other
self: mul_tensor_backward(grad, at::scalar_to_tensor(other), self.scalar_type())
- name: mv(Tensor self, Tensor vec) -> Tensor
self: grad.ger(vec)
@@ -929,10 +929,10 @@
self: zeros_like(grad)
- name: sin(Tensor self) -> Tensor
self: grad * self.cos()
self: grad * self.cos().conj()
- name: sinh(Tensor self) -> Tensor
self: grad * self.cosh()
self: grad * self.cosh().conj()
- name: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
self: slice_backward(grad, self.sizes(), dim, start, end, step)
@@ -1104,10 +1104,10 @@
self: grad.reshape(self.sizes())
- name: view_as_real(Tensor(a) self) -> Tensor(a)
self: at::view_as_complex(grad.contiguous()).conj() # gx0 - i gx1
self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1
- name: view_as_complex(Tensor(a) self) -> Tensor(a)
self: at::view_as_real(grad.contiguous().conj()) # [gx, -gy]
self: at::view_as_real(grad.contiguous()) # [gx, gy]
- name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
condition: non_differentiable

View File

@@ -43,10 +43,11 @@ def iter_tensors(x, only_requiring_grad=False):
for result in iter_tensors(elem, only_requiring_grad):
yield result
def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
def get_numerical_jacobian(fn, input, target=None, eps=1e-3, grad_out=1.0):
"""
input: input to `fn`
target: the Tensors wrt whom Jacobians are calculated (default=`input`)
grad_out: grad output value used to calculate gradients.
Note that `target` may not even be part of `input` to `fn`, so please be
**very careful** in this to not clone `target`.
@@ -62,7 +63,13 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
x_tensors = iter_tensors(target, True)
j_tensors = iter_tensors(jacobian)
def compute_gradient(x, idx, is_mkldnn=False):
def update_jacobians(x, idx, d, d_idx, is_mkldnn=False):
# compute_gradient only works for pure real
# or pure imaginary delta
def compute_gradient(delta):
# we currently assume that the norm of delta equals eps
assert(delta == eps or delta == (eps * 1j))
def fn_out():
if not is_mkldnn:
@@ -73,19 +80,40 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
return fn([x.to_mkldnn()])
orig = x[idx].item()
x[idx] = orig - eps
x[idx] = orig - delta
outa = fn_out()
x[idx] = orig + eps
x[idx] = orig + delta
outb = fn_out()
x[idx] = orig
r = (outb - outa) / (2 * eps)
return r.detach().reshape(-1)
# for details on the algorithm used here, refer to:
# Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf
# s = fn(z) where z = x for real valued input
# and z = x + yj for complex valued input
ds_dx = compute_gradient(eps)
if x.is_complex(): # C -> C, C -> R
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()
elif ds_dx.is_complex(): # R -> C
# w_d = conj_w_d = 0.5 * ds_dx
dL_dz_conj = 0.5 * (grad_out.conjugate() * ds_dx + grad_out * ds_dx.conj())
# The above formula is derived for a C -> C function that is part of a
# bigger function with a real valued output. From separate calculations,
# it can be verified that the gradient for an R -> C function
# equals the real part of the result obtained from the generic formula for
# C -> C functions used above.
d[d_idx] = torch.real(dL_dz_conj)
else: # R -> R
d[d_idx] = ds_dx * grad_out
# TODO: compare structure
for x_tensor, d_tensor in zip(x_tensors, j_tensors):
is_complex = x_tensor.dtype.is_complex
if is_complex:
eps *= (1 + 1j)
if x_tensor.is_sparse:
def get_stride(size):
dim = len(size)
@@ -110,7 +138,7 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
indices = x_indices[i].tolist() + list(x_idx)
d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
d_tensor[d_idx] = compute_gradient(x_value, x_idx)
update_jacobians(x_value, x_idx, d_tensor, d_idx)
elif x_tensor.layout == torch._mkldnn:
# Use .data here to get around the version check
x_tensor = x_tensor.data
@@ -121,17 +149,17 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
# this is really inefficient, but without indexing implemented, there's
# not really a better way than converting back and forth
x_tensor_dense = x_tensor.to_dense()
d_tensor[d_idx] = compute_gradient(x_tensor_dense, x_idx, is_mkldnn=True)
update_jacobians(x_tensor_dense, x_idx, d_tensor, d_idx, is_mkldnn=True)
else:
# Use .data here to get around the version check
x_tensor = x_tensor.data
for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
d_tensor[d_idx] = compute_gradient(x_tensor, x_idx)
update_jacobians(x_tensor, x_idx, d_tensor, d_idx)
return jacobian
def get_analytical_jacobian(input, output, nondet_tol=0.0):
def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0):
# it is easier to call to_dense() on the sparse output than
# to modify analytical jacobian
if output.is_sparse:
@@ -151,7 +179,7 @@ def get_analytical_jacobian(input, output, nondet_tol=0.0):
for i in range(flat_grad_output.numel()):
flat_grad_output.zero_()
flat_grad_output[i] = 1
flat_grad_output[i] = grad_out
for jacobian_c in (jacobian, jacobian_reentrant):
grads_input = torch.autograd.grad(output, diff_input_list, grad_output,
retain_graph=True, allow_unused=True)
@@ -215,6 +243,13 @@ def gradcheck(
The check between numerical and analytical gradients uses :func:`~torch.allclose`.
For complex functions, no notion of Jacobian exists. Gradcheck verifies that the numerical and
analytical values of the Wirtinger and conjugate Wirtinger derivatives are consistent. The gradient
computation is done under the assumption that the overall function has a real valued output.
For functions with complex output, gradcheck compares the numerical and analytical gradients
for two values of :attr:`grad_output`: 1 and 1j. For more details, check out
:ref:`complex_autograd-doc`.
.. note::
The default values are designed for :attr:`input` of double precision.
This check will likely fail if :attr:`input` is of less precision, e.g.,
@@ -309,23 +344,63 @@ def gradcheck(
nondet_tol=nondet_tol)
numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
if not correct_grad_sizes:
return fail_test('Analytical gradient has incorrect size')
out_is_complex = o.is_complex()
if out_is_complex:
# analytical vjp with grad_out = 1.0j
analytical_with_imag_grad_out, reentrant_with_imag_grad_out, \
correct_grad_sizes_with_imag_grad_out, correct_grad_types_with_imag_grad_out \
= get_analytical_jacobian(tupled_inputs, o, nondet_tol=nondet_tol, grad_out=1j)
numerical_with_imag_grad_out = get_numerical_jacobian(fn, tupled_inputs, eps=eps, grad_out=1j)
if not correct_grad_types and check_grad_dtypes:
return fail_test('Gradient has dtype mismatch')
for j, (a, n) in enumerate(zip(analytical, numerical)):
if a.numel() != 0 or n.numel() != 0:
if out_is_complex and not correct_grad_types_with_imag_grad_out and check_grad_dtypes:
return fail_test('Gradient (calculated using complex valued grad output) has dtype mismatch')
if not correct_grad_sizes:
return fail_test('Analytical gradient has incorrect size')
if out_is_complex and not correct_grad_sizes_with_imag_grad_out:
return fail_test('Analytical gradient (calculated using complex valued grad output) has incorrect size')
def checkIfNumericalAnalyticAreClose(a, n, j, error_str=''):
if not torch.allclose(a, n, rtol, atol):
return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
return fail_test(error_str + 'Jacobian mismatch for output %d with respect to input %d,\n'
'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
inp_tensors = iter_tensors(tupled_inputs, True)
for j, (a, n, inp) in enumerate(zip(analytical, numerical, inp_tensors)):
if a.numel() != 0 or n.numel() != 0:
if o.is_complex():
# C -> C, R -> C
a_with_imag_grad_out = analytical_with_imag_grad_out[j]
n_with_imag_grad_out = numerical_with_imag_grad_out[j]
checkIfNumericalAnalyticAreClose(a_with_imag_grad_out, n_with_imag_grad_out, j,
"Gradients failed to compare equal for grad output = 1j. ")
if inp.is_complex():
# C -> R, C -> C
checkIfNumericalAnalyticAreClose(a, n, j,
"Gradients failed to compare equal for grad output = 1. ")
else:
# R -> R, R -> C
checkIfNumericalAnalyticAreClose(a, n, j)
def not_reentrant_error(error_str=''):
error_msg = "Backward" + error_str + " is not reentrant, i.e., running backward with same \
input and grad_output multiple times gives different values, \
although analytical gradient matches numerical gradient. \
The tolerance for nondeterminism was {}.".format(nondet_tol)
return fail_test(error_msg)
if not reentrant:
return fail_test('Backward is not reentrant, i.e., running backward with same '
'input and grad_output multiple times gives different values, '
'although analytical gradient matches numerical gradient. '
'The tolerance for nondeterminism was {}.'.format(nondet_tol))
return not_reentrant_error()
if out_is_complex and not reentrant_with_imag_grad_out:
return not_reentrant_error(' (calculated using complex valued grad output)')
# check if the backward multiplies by grad_output
output = _differentiable_outputs(func(*tupled_inputs))

View File

@@ -211,6 +211,15 @@ Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
return grad * args.digamma_().sum(-1);
}
Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) {
auto result = grad * other.conj();
if (!at::isComplexType(self_st) && result.is_complex()) {
// R -> C
result = at::real(result);
}
return result;
}
Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
// invert the permutation
auto ndims = fwd_dims.size();

View File

@@ -43,6 +43,7 @@ at::Tensor pow_backward(at::Tensor grad, const at::Tensor & self, const at::Scal
at::Tensor pow_backward_self(at::Tensor grad, const at::Tensor & self, const at::Tensor & exponent);
at::Tensor pow_backward_exponent(at::Tensor grad, const at::Tensor& self, const at::Tensor& exponent, at::Tensor result);
at::Tensor pow_backward_exponent(at::Tensor grad, const at::Scalar & base, const at::Tensor& exponent, at::Tensor result);
at::Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st);
at::Tensor mvlgamma_backward(at::Tensor grad, const at::Tensor & self, int64_t p);
at::Tensor permute_backwards(const at::Tensor & grad, at::IntArrayRef fwd_dims);
at::Tensor rad2deg_backward(const at::Tensor& grad);

View File

@@ -154,7 +154,6 @@ class UnaryUfuncInfo(OpInfo):
- they typically have method and inplace variants
- they typically support the out kwarg
- they typically have NumPy or SciPy references
See NumPy's universal function documentation
(https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
about the concept of ufuncs.
@@ -522,6 +521,9 @@ def method_tests():
('mul', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)),
('mul', (S, S, S), (3.14,), 'constant', (True,)),
('mul', (), (3.14,), 'scalar_constant', (True,)),
# TODO(@anjali411): enable these tests
# ('mul', (S, S, S), (3.14j,), 'imaginary_constant', (True,)),
# ('mul', (), (3.14j,), 'imaginary_scalar_constant', (True,)),
('__rmul__', (S, S, S), (3.14,), 'constant', (True, 'aten::mul')),
('__rmul__', (), (3.14,), 'scalar_constant', (True, 'aten::mul')),
('div', (S, S, S), (torch.rand(S, S, S) + 0.1,), '', (True,)),
@@ -626,12 +628,13 @@ def method_tests():
('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)),
('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
('log2', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
('log', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),
('log', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)),
('log10', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),
('log10', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)),
('log2', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),
('log2', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)),
# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition.
# ('log', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),
# ('log', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)),
# ('log10', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),
# ('log10', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)),
# ('log2', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),
# ('log2', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)),
('tanh', (S, S, S), NO_ARGS, '', (True,)),
('tanh', (), NO_ARGS, 'scalar', (True,)),
('sigmoid', (S, S, S), NO_ARGS, '', (True,)),
@@ -644,6 +647,12 @@ def method_tests():
('sinh', (), NO_ARGS, 'scalar', (True,)),
('cosh', (S, S, S), NO_ARGS, '', (True,)),
('cosh', (), NO_ARGS, 'scalar', (True,)),
('conj', (S, S, S), NO_ARGS),
('real', (S, S, S), NO_ARGS, 'complex'),
('imag', (S, S, S), NO_ARGS, 'complex'),
('view_as_real', (S, S, S), NO_ARGS, 'complex'),
('view_as_complex', (S, S, 2), NO_ARGS),
('complex', (S, S, S), ((S, S, S),), ''),
('abs', (S, S, S), NO_ARGS, '', (True,)),
('abs', (), NO_ARGS, 'scalar', (True,)),
('clamp', (S, S, S), (0, 1), '', (True,)),
@@ -660,7 +669,8 @@ def method_tests():
('cos', (S, S, S), NO_ARGS, '', (True,)),
('cos', (), NO_ARGS, 'scalar', (True,)),
('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS, '', (True,)),
('tan', (S, S, S), NO_ARGS, 'complex', (True,)),
# TODO(@anjali411): add the commented test back after updating the formula based on tensorflow definition.
# ('tan', (S, S, S), NO_ARGS, 'complex', (True,)),
('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)),
('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)),
('atan', (S, S, S), NO_ARGS, '', (True,)),
@@ -672,8 +682,9 @@ def method_tests():
('atan2', (S, 1, S), ((S, S),), 'broadcast_all'),
('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS, '', (True,)),
('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar', (True,)),
('reciprocal', torch.randn(S, S, S, dtype=torch.cdouble) + 0.1, NO_ARGS, 'complex', (True,)),
('reciprocal', uniform_scalar(0.1j), NO_ARGS, 'complex_scalar', (True,)),
# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition.
# ('reciprocal', torch.randn(S, S, S, dtype=torch.cdouble) + 0.1, NO_ARGS, 'complex', (True,)),
# ('reciprocal', uniform_scalar(0.1j), NO_ARGS, 'complex_scalar', (True,)),
('round', (S, S, S), NO_ARGS, '', (True,)),
('round', (), NO_ARGS, 'scalar', (True,)),
('sign', (S, S, S), NO_ARGS),
@@ -1643,6 +1654,11 @@ def exclude_tensor_method(name, test_name):
'test_std_mean_dim_1d',
'test_std_mean_dim',
'test_std_mean',
'test_view_as_complex',
'test_view_as_real_complex',
'test_real_complex',
'test_imag_complex',
'test_complex'
}
# there are no out-of-place tensor equivalents for these
exclude_outplace_tensor_method = {