Conjugate View (#54987)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54987

Based off of ezyang (https://github.com/pytorch/pytorch/pull/44799) and bdhirsh (https://github.com/pytorch/pytorch/pull/43702) 's prototype:

Here's a summary of the changes in this PR:
This PR adds a new dispatch key called Conjugate. This enables us to make conjugate operation a view and leverage the specialized library functions that fast path with the hermitian operation (conj + transpose).

1. Conjugate operation will now return a view with conj bit (1) for complex tensors and returns self for non-complex tensors as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical` which would return the real tensor agnostic of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor.
2. NEW API:
    a) `.conj()` -- now returning a view.
    b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated value in its memory.
    c) `.conj_physical_()`, and `out=` variant
    d) `.resolve_conj()`  -- materializes the conjugation. returns self if the conj bit is unset, else returns a new tensor with conjugated values and conj bit set to 0.
    e) `.resolve_conj_()` in-place version of (d)
    f) `view_as_real_physical` -- as described in (1), it's functionally same as `view_as_real`, just that it doesn't error out on conjugated tensors.
    g) `view_as_real` -- existing function, but now errors out on conjugated tensors.
3. Conjugate Fallback
    a) Vast majority of PyTorch functions would currently use this fallback when they are called on a conjugated tensor.
    b) This fallback is well equipped to handle the following cases:
        - functional operation e.g., `torch.sin(input)`
        - Mutable inputs and in-place operations e.g., `tensor.add_(2)`
        - out-of-place operation e.g., `torch.sin(input, out=out)`
        - Tensorlist input args
        - NOTE: Meta tensors don't work with conjugate fallback.
4. Autograd
    a) `resolve_conj()` is an identity function w.r.t. autograd
    b) Everything else works as expected.
5. Testing:
    a) All method_tests run with conjugate view tensors.
    b) OpInfo tests that run with conjugate views
        - test_variant_consistency_eager/jit
        - gradcheck, gradgradcheck
        - test_conj_views (that only run for `torch.cfloat` dtype)

NOTE: functions like `empty_like`, `zero_like`, `randn_like`, `clone` don't propagate the conjugate bit.

Follow up work:
1. conjugate view RFC
2. Add neg bit to re-enable view operation on conjugated tensors
3. Update linalg functions to call into specialized functions that fast path with the hermitian operation.

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D28227315

Pulled By: anjali411

fbshipit-source-id: acab9402b9d6a970c6d512809b627a290c8def5f
This commit is contained in:
anjali411 2021-06-04 14:11:23 -07:00 committed by Facebook GitHub Bot
parent 19985d6f84
commit 3607478ecd
43 changed files with 783 additions and 180 deletions

View File

@ -1031,7 +1031,6 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("sum.dim_IntList", sum_batching_rule); m.impl("sum.dim_IntList", sum_batching_rule);
m.impl("is_complex", native::is_complex); m.impl("is_complex", native::is_complex);
m.impl("conj", native::conj);
// inplace operations // inplace operations
m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule); m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
@ -1085,7 +1084,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
UNARY_POINTWISE(ceil); UNARY_POINTWISE(ceil);
UNARY_POINTWISE(cos); UNARY_POINTWISE(cos);
UNARY_POINTWISE(cosh); UNARY_POINTWISE(cosh);
UNARY_POINTWISE(_conj); UNARY_POINTWISE(conj_physical);
UNARY_POINTWISE(digamma); UNARY_POINTWISE(digamma);
UNARY_POINTWISE(exp); UNARY_POINTWISE(exp);
UNARY_POINTWISE(expm1); UNARY_POINTWISE(expm1);
@ -1181,6 +1180,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
TRIVIAL_OP(imag) TRIVIAL_OP(imag)
TRIVIAL_OP(real); TRIVIAL_OP(real);
TRIVIAL_OP(view_as_real); TRIVIAL_OP(view_as_real);
TRIVIAL_OP(_view_as_real_physical);
TRIVIAL_OP(conj);
TRIVIAL_OP(_conj);
TRIVIAL_OP(resolve_conj);
m.impl("view_as_complex", view_as_complex_batching_rule); m.impl("view_as_complex", view_as_complex_batching_rule);
#undef TRIVIAL #undef TRIVIAL

View File

@ -0,0 +1,152 @@
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/NativeFunctions.h>
namespace at {
void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
// Situations to handle:
// 1. Out-of-place operation. Easy: materialize all inputs and
// call it a day.
// 2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
// Materialize other inputs as in (1).
// 3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
// Materialize other inputs as in (1).
//
// It is important to be able to tell if we READ from an argument and if we
// WRITE from an argument. Conservative approach is to assume that we always
// READ from an argument, but in out-of-place operations you can skip
// conjugating inputs on entry that never get used. In current schema we
// can't easily tell if inplace situation has happened, so don't do it.
const auto& arguments = op.schema().arguments();
const auto num_arguments = arguments.size();
const auto stack_start = stack->size() - num_arguments;
c10::optional<bool> is_write;
for (int64_t i = 0; i < num_arguments; ++i) {
const auto& alias_info = arguments[i].alias_info();
// Three possible states:
// 1. alias_info has no value --> out-of-place operation
// 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
// 3. alias_info does have a value, alias_info->is_write=False --> view operation
if (alias_info.has_value()) {
if (is_write.has_value()) {
TORCH_CHECK(*is_write == alias_info->isWrite(),
"Unsupported operator for conjugate fallback: ", op.schema().name(),
"Conjugate fallback doesn't work for operators with a mix "
"mutable and non-mutable inputs that alias with outputs, "
"this must be implemented manually. "
"If you got this error on a core op, please report a bug to PyTorch.");
} else {
is_write = alias_info->isWrite();
}
}
}
if (is_write.has_value() && !*is_write) {
// We assume that view operators automatically handle conjugation
// correctly by propagating the Conjugate dispatch key in key_set.
// This is not necessarily always right, so you should test these cases.
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
return;
}
// Mutable inputs to be tracked separately
std::vector<Tensor> mutable_inputs;
for (int64_t i = 0; i < num_arguments; ++i) {
auto& ivalue = (*stack)[stack_start + i];
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
continue;
}
const auto& argument = arguments[i];
bool mut_arg = false;
if (argument.alias_info()) {
// View operations were already filtered above, so only in-place/out= operations should get here.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
mut_arg = true;
}
if (ivalue.isTensor()) {
auto* impl = ivalue.unsafeToTensorImpl();
if (!impl->is_conj()) {
continue;
}
auto tensor = std::move(ivalue).toTensor();
TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), "Conjugate Fallback does not support meta tensors.");
if (mut_arg) {
// TODO: This is a waste if the argument is write only
tensor._set_conj(false);
at::conj_physical_(tensor);
mutable_inputs.emplace_back(tensor);
} else {
tensor = at::resolve_conj(tensor);
}
(*stack)[stack_start + i] = std::move(tensor);
} else if (ivalue.isTensorList()) {
auto tensors = std::move(ivalue).toTensorList();
if (mut_arg) {
for(const auto j : c10::irange(tensors.size())) {
Tensor t = tensors[j];
t._set_conj(false);
at::conj_physical_(t);
mutable_inputs.emplace_back(t);
}
} else {
for(const auto j : c10::irange(tensors.size())) {
tensors[j] = at::resolve_conj(tensors[j]);
}
}
(*stack)[stack_start + i] = std::move(tensors);
}
}
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
for (auto& mutable_input : mutable_inputs) {
at::conj_physical_(mutable_input);
mutable_input._set_conj(true);
}
}
TORCH_LIBRARY_IMPL(_, Conjugate, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
}
TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
m.impl("set_.source_Storage_storage_offset", torch::CppFunction::makeFallthrough());
m.impl("set_.source_Tensor", torch::CppFunction::makeFallthrough());
m.impl("set_", torch::CppFunction::makeFallthrough());
m.impl("copy_", torch::CppFunction::makeFallthrough());
m.impl("clone", torch::CppFunction::makeFallthrough());
m.impl("conj", torch::CppFunction::makeFallthrough());
m.impl("_conj", torch::CppFunction::makeFallthrough());
m.impl("_conj_physical", torch::CppFunction::makeFallthrough());
m.impl("conj_physical", torch::CppFunction::makeFallthrough());
m.impl("conj_physical_", torch::CppFunction::makeFallthrough());
m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
m.impl("empty_like", torch::CppFunction::makeFallthrough());
m.impl("empty.memory_format", torch::CppFunction::makeFallthrough());
m.impl("empty.out", torch::CppFunction::makeFallthrough());
m.impl("empty_strided", torch::CppFunction::makeFallthrough());
m.impl("full_like", torch::CppFunction::makeFallthrough());
m.impl("stride.int", torch::CppFunction::makeFallthrough());
m.impl("stride.Dimname", torch::CppFunction::makeFallthrough());
m.impl("size.int", torch::CppFunction::makeFallthrough());
m.impl("size.Dimname", torch::CppFunction::makeFallthrough());
m.impl("is_complex", torch::CppFunction::makeFallthrough());
m.impl("_view_as_real_physical", torch::CppFunction::makeFallthrough());
m.impl("view_as_real", torch::CppFunction::makeFallthrough());
m.impl("imag", torch::CppFunction::makeFallthrough());
m.impl("real", torch::CppFunction::makeFallthrough());
m.impl("view", torch::CppFunction::makeFallthrough());
m.impl("reshape", torch::CppFunction::makeFallthrough());
}
} // namespace at

View File

@ -117,7 +117,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("clamp_min_.Tensor", CppFunction::makeFallthrough()); m.impl("clamp_min_.Tensor", CppFunction::makeFallthrough());
m.impl("clone", CppFunction::makeFallthrough()); m.impl("clone", CppFunction::makeFallthrough());
m.impl("conj", CppFunction::makeFallthrough()); m.impl("conj", CppFunction::makeFallthrough());
m.impl("conj.out", CppFunction::makeFallthrough());
m.impl("contiguous", CppFunction::makeFallthrough()); m.impl("contiguous", CppFunction::makeFallthrough());
m.impl("copy_", CppFunction::makeFallthrough()); m.impl("copy_", CppFunction::makeFallthrough());
m.impl("cos", CppFunction::makeFallthrough()); m.impl("cos", CppFunction::makeFallthrough());

View File

@ -238,6 +238,9 @@ _(aten, coalesce) \
_(aten, combinations) \ _(aten, combinations) \
_(aten, _conj) \ _(aten, _conj) \
_(aten, conj) \ _(aten, conj) \
_(aten, conj_physical) \
_(aten, conj_physical_) \
_(aten, resolve_conj) \
_(aten, complex) \ _(aten, complex) \
_(aten, copysign) \ _(aten, copysign) \
_(aten, polar) \ _(aten, polar) \
@ -764,6 +767,7 @@ _(aten, zeros_like) \
_(aten, real) \ _(aten, real) \
_(aten, imag) \ _(aten, imag) \
_(aten, view_as_real) \ _(aten, view_as_real) \
_(aten, _view_as_real_physical) \
_(aten, view_as_complex) \ _(aten, view_as_complex) \
/* nothing */ /* nothing */

View File

@ -2,6 +2,9 @@
#include <ATen/ATen.h> #include <ATen/ATen.h>
// WARNING: this header contains non-inline functions and should be only
// included from ONE cpp file
namespace at { namespace native { namespace at { namespace native {
// View tensor with new dtype, storage offset, sizes and strides // View tensor with new dtype, storage offset, sizes and strides
@ -9,8 +12,9 @@ inline Tensor view_tensor(
const Tensor &tensor, ScalarType dtype, const Tensor &tensor, ScalarType dtype,
int64_t offset, IntArrayRef sizes, IntArrayRef strides) { int64_t offset, IntArrayRef sizes, IntArrayRef strides) {
Storage storage = tensor.storage(); Storage storage = tensor.storage();
auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
auto new_tensor = detail::make_tensor<TensorImpl>( auto new_tensor = detail::make_tensor<TensorImpl>(
c10::TensorImpl::VIEW, std::move(storage), tensor.key_set(), scalarTypeToTypeMeta(dtype)); c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
auto * impl = new_tensor.unsafeGetTensorImpl(); auto * impl = new_tensor.unsafeGetTensorImpl();
impl->set_storage_offset(offset); impl->set_storage_offset(offset);
impl->set_sizes_and_strides(sizes, strides); impl->set_sizes_and_strides(sizes, strides);
@ -30,7 +34,12 @@ inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) {
// with corresponding real dtype containing the complex values // with corresponding real dtype containing the complex values
// in the last two dimensions // in the last two dimensions
Tensor view_as_real(const Tensor& self) { Tensor view_as_real(const Tensor& self) {
TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors"); TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
return native::_view_as_real_physical(self);
}
Tensor _view_as_real_physical(const Tensor& self) {
TORCH_CHECK(self.is_complex(), "view_as_real_physical is only supported for complex tensors");
auto old_sizes = self.sizes(); auto old_sizes = self.sizes();
DimVector new_sizes(old_sizes.size() + 1); DimVector new_sizes(old_sizes.size() + 1);
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin()); std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
@ -39,7 +48,8 @@ Tensor view_as_real(const Tensor& self) {
auto new_strides = computeStrideForViewAsReal(self.strides()); auto new_strides = computeStrideForViewAsReal(self.strides());
auto new_storage_offset = 2 * self.storage_offset(); auto new_storage_offset = 2 * self.storage_offset();
const auto float_type = c10::toValueType(self.scalar_type()); const auto float_type = c10::toValueType(self.scalar_type());
return view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides); auto real_tensor = view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides);
return real_tensor;
} }
inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) { inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) {

View File

@ -150,7 +150,7 @@ Tensor fft_r2c(c10::string_view function_name,
if (!forward) { if (!forward) {
// FIXME: _fft_r2c doesn't support native r2c IFFT // FIXME: _fft_r2c doesn't support native r2c IFFT
return out.defined() ? at::conj_out(out, ret) : at::conj(ret); return out.defined() ? at::conj_physical_out(out, ret) : at::conj(ret);
} else { } else {
return ret; return ret;
} }

View File

@ -349,6 +349,8 @@ Tensor empty_like(
namedinference::propagate_names(result, self.names()); namedinference::propagate_names(result, self.names());
} }
// never propagate Conjugate key
result._set_conj(false);
return result; return result;
} }

View File

@ -30,6 +30,10 @@ bool is_signed(const Tensor &self) {
return self.is_signed(); return self.is_signed();
} }
bool is_conj(const Tensor& self) {
return self.is_conj();
}
bool is_sparse(const Tensor& self) { bool is_sparse(const Tensor& self) {
return self.is_sparse(); return self.is_sparse();
} }

View File

@ -28,7 +28,8 @@ namespace at {
namespace meta { namespace meta {
// Unary float operations always produce floating point // Unary float operations always produce floating point
// outputs, even if their inputs are integer // outputs for floating point and integral types
// For complex inputs, the output type should be the same as input type.
#define CREATE_UNARY_FLOAT_META_FUNC(func) \ #define CREATE_UNARY_FLOAT_META_FUNC(func) \
TORCH_META_FUNC(func) (const Tensor& self) { \ TORCH_META_FUNC(func) (const Tensor& self) { \
build_unary_float_op(maybe_get_output(), self); \ build_unary_float_op(maybe_get_output(), self); \
@ -363,7 +364,8 @@ Tensor angle(const Tensor& self) {
Tensor real(const Tensor& self) { Tensor real(const Tensor& self) {
if (self.is_complex()) { if (self.is_complex()) {
auto real_tensor = at::view_as_real(self); // real is never affected by conjugate bit, safe to use physical version
auto real_tensor = at::_view_as_real_physical(self);
return at::select(real_tensor, real_tensor.dim() - 1, 0); return at::select(real_tensor, real_tensor.dim() - 1, 0);
} else { } else {
TORCH_CHECK(false, "real is not implemented for tensors with non-complex dtypes."); TORCH_CHECK(false, "real is not implemented for tensors with non-complex dtypes.");
@ -379,17 +381,44 @@ Tensor imag(const Tensor& self) {
} }
} }
Tensor& conj_out(const Tensor& self, Tensor& result) { Tensor& conj_physical_out(const Tensor& self, Tensor& result) {
return unary_op_impl_out(result, self, conj_stub); return unary_op_impl_out(result, self, conj_physical_stub);
} }
Tensor _conj(const Tensor& self) { return unary_op_impl(self, at::conj_out); } Tensor _conj_physical(const Tensor& self) {
if (self.is_conj()) {
return self.conj().clone();
}
return unary_op_impl(self, at::conj_physical_out);
}
Tensor conj_physical(const Tensor& self) {
if (!self.is_complex()) return self;
return at::_conj_physical(self);
}
Tensor& conj_physical_(Tensor& self) {
if (!self.is_complex()) return self;
return unary_op_impl_out(self, self, conj_physical_stub);
}
Tensor resolve_conj(const Tensor& self) {
if (!self.is_conj()) { return self; }
// conjugation is handled in `copy_()` that clone ultimately calls into
return self.clone(self.suggest_memory_format());
}
Tensor _conj(const Tensor& self) {
Tensor self_ = self.alias();
self_._set_conj(!self.is_conj());
namedinference::propagate_names(self_, self);
return self_;
}
Tensor conj(const Tensor& self) { Tensor conj(const Tensor& self) {
if (!self.is_complex()) { // This might look like an infinite recursion but it's not.
return self; // This actually calls into `conj()` defined in the Tensor class.
} return self.conj();
return at::_conj(self);
} }
// special_exp2, alias for exp2 // special_exp2, alias for exp2
@ -689,7 +718,7 @@ DEFINE_DISPATCH(abs_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-va
DEFINE_DISPATCH(angle_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(angle_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(real_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(real_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(imag_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(imag_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(conj_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(conj_physical_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(acos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(acos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(acosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(acosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(asinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(asinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -14,7 +14,7 @@ DECLARE_DISPATCH(unary_fn, abs_stub);
DECLARE_DISPATCH(unary_fn, angle_stub); DECLARE_DISPATCH(unary_fn, angle_stub);
DECLARE_DISPATCH(unary_fn, real_stub); DECLARE_DISPATCH(unary_fn, real_stub);
DECLARE_DISPATCH(unary_fn, imag_stub); DECLARE_DISPATCH(unary_fn, imag_stub);
DECLARE_DISPATCH(unary_fn, conj_stub); DECLARE_DISPATCH(unary_fn, conj_physical_stub);
DECLARE_DISPATCH(unary_fn, acos_stub); DECLARE_DISPATCH(unary_fn, acos_stub);
DECLARE_DISPATCH(unary_fn, acosh_stub); DECLARE_DISPATCH(unary_fn, acosh_stub);
DECLARE_DISPATCH(unary_fn, asinh_stub); DECLARE_DISPATCH(unary_fn, asinh_stub);

View File

@ -5,6 +5,7 @@
#include <ATen/native/TensorIterator.h> #include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h> #include <ATen/native/cpu/Loops.h>
#include <c10/util/TypeCast.h> #include <c10/util/TypeCast.h>
#include <ATen/native/cpu/zmath.h>
namespace at { namespace at {
namespace native { namespace native {
@ -13,6 +14,15 @@ namespace {
static void copy_kernel(TensorIterator& iter, bool non_blocking) { static void copy_kernel(TensorIterator& iter, bool non_blocking) {
ScalarType dtype = iter.dtype(0); ScalarType dtype = iter.dtype(0);
if (dtype == iter.dtype(1)) { if (dtype == iter.dtype(1)) {
// TODO: as the majority of these operations can be done treating
// their datatypes as opaque bit patterns, we don't actually need
// separate instantiations per dtype; we only need a separate
// instantiation per dtype size. This would probably save us a
// little bit of code size here
// TODO: not sure if optimizer is able to compile two levels of
// conditionals into a single jump table. We should have a
// single jump table here; might be worth just writing out the
// dispatch statement by hand instead of using AT_DISPATCH
if (dtype == ScalarType::Half) { if (dtype == ScalarType::Half) {
cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; }); cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
} else if (dtype == ScalarType::ComplexHalf) { } else if (dtype == ScalarType::ComplexHalf) {
@ -25,12 +35,21 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; }); [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
}); });
} else if (isComplexType(dtype)) { } else if (isComplexType(dtype)) {
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] { if (iter.tensor(0).is_conj() == iter.tensor(1).is_conj()) {
cpu_kernel_vec( AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
iter, cpu_kernel_vec(
[=](scalar_t a) -> scalar_t { return a; }, iter,
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; }); [=](scalar_t a) -> scalar_t { return a; },
}); [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
});
} else {
AT_DISPATCH_COMPLEX_TYPES(dtype, "conj_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return conj_impl(a); },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.conj(); });
});
}
} else { } else {
AT_DISPATCH_ALL_TYPES_AND2( AT_DISPATCH_ALL_TYPES_AND2(
ScalarType::Bool, ScalarType::BFloat16,dtype, "copy_kernel", [&] { ScalarType::Bool, ScalarType::BFloat16,dtype, "copy_kernel", [&] {
@ -62,6 +81,9 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
return c10::static_cast_with_inter_type<dest_t, scalar_t>::apply(src); }); return c10::static_cast_with_inter_type<dest_t, scalar_t>::apply(src); });
}); });
}); });
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
iter.tensor(0).conj_physical_();
}
} }
} }

View File

@ -193,6 +193,7 @@ static void imag_kernel(TensorIteratorBase& iter) {
}); });
} }
// NB: Ignores the negative bit on tensors
static void conj_kernel(TensorIteratorBase& iter) { static void conj_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cpu", [&]() { kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cpu", [&]() {
@ -716,7 +717,7 @@ REGISTER_DISPATCH(real_stub, &CPU_CAPABILITY::real_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(imag_stub, &CPU_CAPABILITY::imag_kernel); REGISTER_DISPATCH(imag_stub, &CPU_CAPABILITY::imag_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(conj_stub, &CPU_CAPABILITY::conj_kernel); REGISTER_DISPATCH(conj_physical_stub, &CPU_CAPABILITY::conj_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(exp2_stub, &CPU_CAPABILITY::exp2_kernel); REGISTER_DISPATCH(exp2_stub, &CPU_CAPABILITY::exp2_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -25,7 +25,8 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
// We can memcpy the memory if both tensors have the same type AND both // We can memcpy the memory if both tensors have the same type AND both
// tensors are contiguous after dimension coalescing and reordering. // tensors are contiguous after dimension coalescing and reordering.
bool same_type = iter.dtype(0) == iter.dtype(1); bool same_type = iter.dtype(0) == iter.dtype(1);
bool memcpy_eligible = same_type && iter.is_contiguous(); bool same_conj = iter.tensor(0).is_conj() == iter.tensor(1).is_conj();
bool memcpy_eligible = same_type && same_conj && iter.is_contiguous();
Device dst_device = iter.device(0); Device dst_device = iter.device(0);
Device src_device = iter.device(1); Device src_device = iter.device(1);
@ -71,10 +72,17 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
}); });
} else { } else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( if (!same_conj && same_type) {
kHalf, kBool, kBFloat16, dtype, "copy_", [&] { AT_DISPATCH_COMPLEX_TYPES(
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); dtype, "copy_conj_", [&] {
}); gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return std::conj(x); });
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
}
} }
} }
@ -152,6 +160,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
src_contig = iter.tensor(1).expand_as(dst).contiguous(); src_contig = iter.tensor(1).expand_as(dst).contiguous();
} }
// propagate the correct conjugate bit
dst_contig._set_conj(dst.is_conj());
src_contig._set_conj(iter.tensor(1).is_conj());
// perform a same-dtype copy on contiguous tensors // perform a same-dtype copy on contiguous tensors
TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes())); TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type()); TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
@ -201,6 +213,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
AT_CUDA_CHECK(cudaStreamSynchronize(stream)); AT_CUDA_CHECK(cudaStreamSynchronize(stream));
#endif #endif
} }
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
iter.tensor(0).conj_physical_();
}
} }
REGISTER_DISPATCH(copy_stub, &copy_kernel_cuda); REGISTER_DISPATCH(copy_stub, &copy_kernel_cuda);

View File

@ -80,6 +80,7 @@ __host__ __device__ static inline c10::complex<T> conj_wrapper(c10::complex<T> v
return std::conj(v); return std::conj(v);
} }
// NB: Ignores the negative bit on tensors
void conj_kernel_cuda(TensorIteratorBase& iter) { void conj_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cuda", [&]() { kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cuda", [&]() {
@ -92,6 +93,6 @@ void conj_kernel_cuda(TensorIteratorBase& iter) {
REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda); REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
REGISTER_DISPATCH(real_stub, &real_kernel_cuda); REGISTER_DISPATCH(real_stub, &real_kernel_cuda);
REGISTER_DISPATCH(imag_stub, &imag_kernel_cuda); REGISTER_DISPATCH(imag_stub, &imag_kernel_cuda);
REGISTER_DISPATCH(conj_stub, &conj_kernel_cuda); REGISTER_DISPATCH(conj_physical_stub, &conj_kernel_cuda);
}} // namespace at::native }} // namespace at::native

View File

@ -271,6 +271,11 @@
dispatch: dispatch:
CPU, CUDA: view_as_real CPU, CUDA: view_as_real
- func: _view_as_real_physical(Tensor(a) self) -> Tensor(a)
variants: function
dispatch:
CPU, CUDA: _view_as_real_physical
- func: view_as_complex(Tensor(a) self) -> Tensor(a) - func: view_as_complex(Tensor(a) self) -> Tensor(a)
variants: function variants: function
dispatch: dispatch:
@ -298,21 +303,36 @@
device_check: NoCheck # TensorIterator device_check: NoCheck # TensorIterator
variants: function variants: function
- func: conj(Tensor(a) self) -> Tensor(a) - func: _conj(Tensor(a) self) -> Tensor(a)
device_check: NoCheck # TensorIterator
variants: function, method variants: function, method
- func: conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: conj_out
SparseCPU, SparseCUDA: conj_out_sparse
- func: _conj(Tensor self) -> Tensor
variants: function
dispatch: dispatch:
CompositeExplicitAutograd: _conj CompositeExplicitAutograd: _conj
- func: conj(Tensor(a) self) -> Tensor(a)
variants: function, method
manual_cpp_binding: True
- func: _conj_physical(Tensor self) -> Tensor
variants: function, method
dispatch:
CompositeExplicitAutograd: _conj_physical
- func: conj_physical(Tensor self) -> Tensor
variants: function, method
- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: conj_physical_out
SparseCPU, SparseCUDA: conj_physical_out_sparse
- func: conj_physical_(Tensor(a!) self) -> Tensor(a!)
variants: function, method
dispatch:
CompositeExplicitAutograd: conj_physical_
- func: resolve_conj(Tensor(a) self) -> Tensor(a)
variants: function, method
- func: acos(Tensor self) -> Tensor - func: acos(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator device_check: NoCheck # TensorIterator
variants: function, method variants: function, method
@ -2249,6 +2269,11 @@
device_guard: False device_guard: False
manual_cpp_binding: True manual_cpp_binding: True
- func: is_conj(Tensor self) -> bool
variants: function, method
device_guard: False
manual_cpp_binding: True
- func: isreal(Tensor self) -> Tensor - func: isreal(Tensor self) -> Tensor
variants: function, method variants: function, method

View File

@ -1645,15 +1645,15 @@ Tensor& bmm_out_sparse_cpu(const SparseTensor& self, const Tensor& mat2, Tensor&
return result; return result;
} }
Tensor conj_sparse(const Tensor& input) { // Tensor conj_physical_sparse(const Tensor& input) {
if (!input.is_complex()) { // if (!input.is_complex()) {
return input; // return input;
} // }
Tensor result = at::native::empty_like(input); // Tensor result = at::native::empty_like(input);
return conj_out_sparse(input, result); // return conj_physical_out_sparse(input, result);
} // }
Tensor& conj_out_sparse(const Tensor& input, Tensor& result) { Tensor& conj_physical_out_sparse(const Tensor& input, Tensor& result) {
TORCH_INTERNAL_ASSERT(input.is_sparse()); TORCH_INTERNAL_ASSERT(input.is_sparse());
if (input.numel() == 0) { if (input.numel() == 0) {
return result; return result;
@ -1665,7 +1665,7 @@ Tensor& conj_out_sparse(const Tensor& input, Tensor& result) {
return result; return result;
} }
Tensor result_values = result._values(); Tensor result_values = result._values();
at::conj_out(result_values, input._values()); at::conj_physical_out(result_values, input._values());
return result; return result;
} }

View File

@ -216,4 +216,12 @@ inline bool is_inference(const Tensor& tensor) {
return tensor.is_inference(); return tensor.is_inference();
} }
inline bool is_conj(const Tensor& tensor) {
return tensor.is_conj();
}
inline Tensor conj(const Tensor& tensor) {
return tensor.conj();
}
} }

View File

@ -136,6 +136,17 @@ class TORCH_API Tensor {
} }
} }
Tensor conj() const {
if (!this->is_complex()) {
return *this;
} else {
if (this->is_sparse()) {
return this->conj_physical();
}
return this->_conj();
}
}
/// Should be used if *this can reasonably be expected to be contiguous and /// Should be used if *this can reasonably be expected to be contiguous and
/// performance is important. /// performance is important.
/// Compared to contiguous, it saves a reference count /// Compared to contiguous, it saves a reference count
@ -363,6 +374,18 @@ class TORCH_API Tensor {
return !at::impl::variable_excluded_from_dispatch(); return !at::impl::variable_excluded_from_dispatch();
} }
inline bool is_conj() const {
return impl_->is_conj();
}
// sets the conjugate bit of a tensor.
// NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are extremely sure
// that's what you want. Changing this might lead to incorrect behavior since conjugation is
// a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
inline void _set_conj(bool conjugate) const {
impl_->_set_conj(conjugate);
}
/// Returns a `Tensor`'s layout. /// Returns a `Tensor`'s layout.
Layout layout() const noexcept { Layout layout() const noexcept {
return impl_->layout(); return impl_->layout();

View File

@ -65,6 +65,8 @@ const char* toString(DispatchKey t) {
case DispatchKey::PrivateUse3: case DispatchKey::PrivateUse3:
return "PrivateUse3"; return "PrivateUse3";
case DispatchKey::Conjugate:
return "Conjugate";
case DispatchKey::Meta: case DispatchKey::Meta:
return "Meta"; return "Meta";

View File

@ -142,6 +142,11 @@ enum class DispatchKey : uint8_t {
// constituent parts. // constituent parts.
Named, Named,
// The Conjugate dispatch key is set for any tensors that need to perform
// conjugation
// This is implemented at a dispatch level right before any backends run
Conjugate,
// See Note [Out-of-tree vmap+grad prototype]. The purpose of this key // See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
// is to insert code after the "autograd subsystem" runs, so this key should // is to insert code after the "autograd subsystem" runs, so this key should
// be directly after ADInplaceOrView and all of the autograd keys. // be directly after ADInplaceOrView and all of the autograd keys.

View File

@ -938,6 +938,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
*/ */
const at::Tensor& grad() const; const at::Tensor& grad() const;
/**
* Whether or not the imaginary part of the tensor should be negated
*/
inline bool is_conj() const {
return key_set_.has(DispatchKey::Conjugate);
}
/**
* Set whether or not to take the conjugate of the tensor (flip the imaginary
* bit).
*/
void _set_conj(bool value) {
if (value) {
key_set_ = key_set_.add(DispatchKey::Conjugate);
TORCH_INTERNAL_ASSERT(isComplexType(typeMetaToScalarType(dtype())));
} else {
key_set_ = key_set_.remove(DispatchKey::Conjugate);
}
}
/** /**
* Return the accumulated gradient of a tensor. This gradient is computed * Return the accumulated gradient of a tensor. This gradient is computed
* using forward mode AD. * using forward mode AD.

View File

@ -64,7 +64,6 @@ For reference, heres a full list of view ops in PyTorch:
- :attr:`~torch.Tensor.real` - :attr:`~torch.Tensor.real`
- :attr:`~torch.Tensor.imag` - :attr:`~torch.Tensor.imag`
- :meth:`~torch.Tensor.view_as_real` - :meth:`~torch.Tensor.view_as_real`
- :meth:`~torch.Tensor.view_as_imag`
- :meth:`~torch.Tensor.unflatten` - :meth:`~torch.Tensor.unflatten`
- :meth:`~torch.Tensor.unfold` - :meth:`~torch.Tensor.unfold`
- :meth:`~torch.Tensor.unsqueeze` - :meth:`~torch.Tensor.unsqueeze`

View File

@ -276,6 +276,9 @@ Tensor class reference
Tensor.contiguous Tensor.contiguous
Tensor.copy_ Tensor.copy_
Tensor.conj Tensor.conj
Tensor.conj_physical
Tensor.conj_physical_
Tensor.resolve_conj
Tensor.copysign Tensor.copysign
Tensor.copysign_ Tensor.copysign_
Tensor.cos Tensor.cos
@ -410,6 +413,7 @@ Tensor class reference
Tensor.isnan Tensor.isnan
Tensor.is_contiguous Tensor.is_contiguous
Tensor.is_complex Tensor.is_complex
Tensor.is_conj
Tensor.is_floating_point Tensor.is_floating_point
Tensor.is_inference Tensor.is_inference
Tensor.is_leaf Tensor.is_leaf

View File

@ -19,6 +19,7 @@ Tensors
is_tensor is_tensor
is_storage is_storage
is_complex is_complex
is_conj
is_floating_point is_floating_point
is_nonzero is_nonzero
set_default_dtype set_default_dtype
@ -86,6 +87,7 @@ Indexing, Slicing, Joining, Mutating Ops
:nosignatures: :nosignatures:
cat cat
conj
chunk chunk
dsplit dsplit
column_stack column_stack
@ -293,7 +295,7 @@ Pointwise Ops
ceil ceil
clamp clamp
clip clip
conj conj_physical
copysign copysign
cos cos
cosh cosh
@ -511,6 +513,7 @@ Other Operations
vander vander
view_as_real view_as_real
view_as_complex view_as_complex
resolve_conj
BLAS and LAPACK Operations BLAS and LAPACK Operations

View File

@ -91,6 +91,9 @@ allow_list = [
("aten::linalg_vector_norm", datetime.date(2021, 5, 15)), ("aten::linalg_vector_norm", datetime.date(2021, 5, 15)),
("aten::repeat_interleave", datetime.date(2021, 6, 26)), ("aten::repeat_interleave", datetime.date(2021, 6, 26)),
("aten::one_hot", datetime.date(2021, 6, 15)), ("aten::one_hot", datetime.date(2021, 6, 15)),
("aten::conj", datetime.date(2021, 8, 1)),
("aten::_conj", datetime.date(2021, 8, 1)),
("aten::conj.out", datetime.date(2021, 8, 1)),
] ]
def allow_listed(schema, allow_list): def allow_listed(schema, allow_list):

View File

@ -4279,7 +4279,7 @@ class TestAutograd(TestCase):
gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode) gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
def basic_mul(x): def basic_mul(x):
return torch.view_as_real(x * 1j) return torch.view_as_real(torch.resolve_conj(x * 1j))
gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode) gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)
# Test for one input and one output being complex # Test for one input and one output being complex
@ -5459,7 +5459,7 @@ complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone'
'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu', 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
'chunk', 'split', 'split_with_sizes', 'zero_', 'chunk', 'split', 'split_with_sizes', 'zero_',
'__radd__', 'mul', '__rmul__', 'diagonal', 'fill_', 'sub', 'narrow', '__radd__', 'mul', '__rmul__', 'diagonal', 'fill_', 'sub', 'narrow',
'swapaxes', 'swapdims', 'tensor_split'] + separate_complex_tests 'swapaxes', 'swapdims', 'tensor_split', 'select', 'clone'] + separate_complex_tests
# deny list for batched grad computation # deny list for batched grad computation
EXCLUDE_BATCHED_GRAD_TESTS = set([ EXCLUDE_BATCHED_GRAD_TESTS = set([
@ -5591,7 +5591,9 @@ def add_test(
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable) output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
if not isinstance(output_variable, tuple): if not isinstance(output_variable, tuple):
output_variable = (output_variable,) output_variable = (output_variable,)
inplace_self_variable = deepcopy(self_variable) inplace_self_variable = deepcopy(self_variable)
self.assertEqual(inplace_self_variable, self_variable)
inplace_self_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i inplace_self_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
for i in (inplace_self_variable,)) for i in (inplace_self_variable,))
inplace_args_variable = deepcopy(args_variable) inplace_args_variable = deepcopy(args_variable)

View File

@ -1946,6 +1946,7 @@ known_failures = [
# If your OpInfo test causes this test to fail, add it here # If your OpInfo test causes this test to fail, add it here
skip_ops = [ skip_ops = [
'conj'
] ]
def get_name(op): def get_name(op):

View File

@ -17,7 +17,7 @@ from torch.testing._internal.common_jit import JitCommonTestCase, check_against_
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \ from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
check_alias_annotation check_alias_annotation
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
from collections.abc import Sequence
# Get names of all the operators which have entry in `method_tests` (legacy testing infra) # Get names of all the operators which have entry in `method_tests` (legacy testing infra)
method_tested_operators = set(map(lambda test_details: test_details[0], method_tests())) method_tested_operators = set(map(lambda test_details: test_details[0], method_tests()))
@ -110,7 +110,9 @@ class TestGradients(TestCase):
return variant.__wrapped__ is op.get_inplace() return variant.__wrapped__ is op.get_inplace()
return variant is op.get_inplace() return variant is op.get_inplace()
samples = op.sample_inputs(device, dtype, requires_grad=True) include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs)
for sample in samples: for sample in samples:
if sample.broadcasts_input and is_inplace(variant): if sample.broadcasts_input and is_inplace(variant):
continue continue
@ -280,7 +282,8 @@ class TestCommon(JitCommonTestCase):
_requires_grad = (op.supports_autograd and _requires_grad = (op.supports_autograd and
(dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type))) (dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type)))
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad) include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
def _test_consistency_helper(samples, variants): def _test_consistency_helper(samples, variants):
for sample in samples: for sample in samples:
@ -381,7 +384,8 @@ class TestCommon(JitCommonTestCase):
_requires_grad = op.supports_autograd and (dtype.is_floating_point or _requires_grad = op.supports_autograd and (dtype.is_floating_point or
op.supports_complex_autograd(torch.device(device).type)) op.supports_complex_autograd(torch.device(device).type))
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad) include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
for sample in samples: for sample in samples:
# Acquires variants to test # Acquires variants to test
@ -734,6 +738,103 @@ class TestCommon(JitCommonTestCase):
with self.assertRaises(RuntimeError, msg=msg_fail): with self.assertRaises(RuntimeError, msg=msg_fail):
op_out(out=out) op_out(out=out)
# Tests that
# 1. The operator's output for physically conjugated tensors and conjugate view tensors
# produces the same value
# 2. The gradients are same in both cases mentioned in (1)
# 3. If the operator's inplace variant is supported, tests that the inplace operation
# produces the correct value when called on a conjugate view tensor and that the output
# has its conj bit set to true
# This test only runs for C -> R and C -> C functions
# TODO: add tests for `R->C` functions
# Note: This test runs for functions that take both tensors and tensorlists as input.
@ops(op_db, allowed_dtypes=(torch.cfloat,))
def test_conj_view(self, device, dtype, op):
if not op.test_conjugated_samples:
self.skipTest("Operation doesn't support conjugated inputs.")
_requires_grad = (op.supports_autograd and op.supports_complex_autograd(torch.device(device).type))
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
inplace_variant = op.inplace_variant
# helper function to physically conjugate the tensor
def conjugate_physical(input):
if isinstance(input, torch.Tensor):
tensor_requires_grad = input.requires_grad
with torch.no_grad():
input = input.conj_physical()
return input.requires_grad_(tensor_requires_grad)
if isinstance(input, Sequence):
out = list(map(clone_input_helper, input))
out[0] = conjugate_physical(out[0])
return tuple(out)
# helper function to clone and conjugate the input if its a tensor
# else clone the sequence and conjugate the first element in the sequence
# If a requires_grad argument is provided the tensor being conjugated will
# have its requires_grad set to that value.
def clone_conj_input_helper(input, **kwargs):
if isinstance(input, torch.Tensor):
requires_grad = kwargs.get('requires_grad', input.requires_grad)
with torch.no_grad():
input = input.clone()
# Note: .conj() is not called under no_grad mode since it's not allowed to modify a
# view created in no_grad mode. Here it's ok to do so, so as a workaround we call conj
# before resetting the requires_grad field for input
input = input.conj()
assert input.is_leaf
return input.requires_grad_(requires_grad)
if isinstance(input, Sequence):
out = list(map(clone_input_helper, input))
out[0] = clone_conj_input_helper(out[0])
return tuple(out)
for sample in samples:
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
cloned1 = clone_conj_input_helper(sample.input)
sample.input = conjugate_physical(sample.input)
# Computes function forward value with a physically conjugated tensor and
# a conj view tensor and verifies that the output in both case are equal.
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
forward_with_conjview = op(cloned1, *sample.args, **sample.kwargs)
self.assertEqual(expected_forward, forward_with_conjview)
# If the op has an inplace variant, and the input doesn't require broadcasting
# and has the same dtype as output, verify that the inplace operation on a conjugated
# input produces correct output, and the output tensor has the conj bit set to True
if inplace_variant is not None and not sample.broadcasts_input:
cloned2 = clone_conj_input_helper(tensor, requires_grad=False)
if (isinstance(expected_forward, torch.Tensor) and
expected_forward.dtype is tensor.dtype):
inplace_forward = inplace_variant(cloned2, *sample.args, **sample.kwargs)
self.assertTrue(inplace_forward.is_conj())
self.assertEqual(inplace_forward, expected_forward)
# TODO: backward consistency only supported for single tensor outputs
# TODO: backward consistency only checked on sample.input, not all
# tensor inputs
# TODO: update to handle checking grads of all tensor inputs as
# derived from each tensor output
if isinstance(expected_forward, torch.Tensor) and expected_forward.requires_grad:
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
expected_forward.sum().backward(retain_graph=True)
forward_with_conjview.sum().backward(retain_graph=True)
if tensor.grad is not None:
cloned1_tensor = cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
self.assertEqual(tensor.grad, cloned1_tensor.grad)
tensor.grad, cloned1_tensor.grad = None, None
# a repeat of the above test if output is not complex valued
if (expected_forward.is_complex()):
grad = torch.randn_like(expected_forward)
expected_forward.backward(grad.conj_physical())
forward_with_conjview.backward(grad.conj())
self.assertEqual(tensor.grad, cloned1_tensor.grad)
instantiate_device_type_tests(TestOpInfo, globals()) instantiate_device_type_tests(TestOpInfo, globals())
instantiate_device_type_tests(TestGradients, globals()) instantiate_device_type_tests(TestGradients, globals())

View File

@ -346,6 +346,14 @@ class TestViewOps(TestCase):
self.assertEqual(a[5:].real, a.real[5:]) self.assertEqual(a[5:].real, a.real[5:])
self.assertEqual(a[5:].imag, a.imag[5:]) self.assertEqual(a[5:].imag, a.imag[5:])
@onlyOnCPUAndCUDA
@dtypes(*torch.testing.get_all_complex_dtypes())
def test_conj_view(self, device, dtype) -> None:
t = _make_tensor((4, 5,), dtype, device)
v = t.conj()
self.assertTrue(self.is_view_of(t, v))
self.assertEqual(v, torch.from_numpy(t.cpu().numpy().conj()).to(device=device))
@onlyOnCPUAndCUDA @onlyOnCPUAndCUDA
@dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes())) @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
@suppress_warnings @suppress_warnings

View File

@ -380,16 +380,23 @@
- name: complex(Tensor real, Tensor imag) -> Tensor - name: complex(Tensor real, Tensor imag) -> Tensor
real: at::real(grad) real: at::real(grad)
imag: at::imag(grad) imag: at::imag(grad.resolve_conj())
result: at::complex(real_t, imag_t) result: at::complex(real_t, imag_t)
- name: polar(Tensor abs, Tensor angle) -> Tensor - name: polar(Tensor abs, Tensor angle) -> Tensor
abs, angle: polar_backward(grad, result) abs, angle: polar_backward(grad, result)
- name: _conj(Tensor self) -> Tensor - name: _conj(Tensor(a) self) -> Tensor(a)
self: grad.conj() self: grad.conj()
result: self_t.conj() result: self_t.conj()
- name: _conj_physical(Tensor self) -> Tensor
self: grad.conj_physical()
result: self_t.conj_physical()
- name: conj_physical_(Tensor(a!) self) -> Tensor(a!)
self: grad.conj_physical()
- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor - name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
self: copysign_tensor_self_backward(grad, self, result) self: copysign_tensor_self_backward(grad, self, result)
other: zeros_like(other) other: zeros_like(other)
@ -1306,12 +1313,16 @@
- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) - name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
output_differentiability: [False] output_differentiability: [False]
# If the conj bit for self is set, this is effectively a fused conj() and view_as_real
- name: _view_as_real_physical(Tensor(a) self) -> Tensor(a)
self: "self.is_conj() ? at::view_as_complex(grad.contiguous()) : at::view_as_complex(grad.contiguous()).conj()"
- name: view_as_real(Tensor(a) self) -> Tensor(a) - name: view_as_real(Tensor(a) self) -> Tensor(a)
self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1 self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1
result: at::view_as_real(self_t) result: at::view_as_real(self_t)
- name: view_as_complex(Tensor(a) self) -> Tensor(a) - name: view_as_complex(Tensor(a) self) -> Tensor(a)
self: at::view_as_real(grad.contiguous()) # [gx, gy] self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy]
result: at::view_as_complex(self_t) result: at::view_as_complex(self_t)
- name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor - name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor

View File

@ -28,7 +28,7 @@ from .gen_trace_type import (
# #
# A map: function name => name of the argument that all outputs are view of # A map: function name => name of the argument that all outputs are view of
VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_real', 'view_as_complex'] VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_complex', '_view_as_real_physical', 'view_as_real', '_conj']
VIEW_FUNCTIONS = { VIEW_FUNCTIONS = {
'numpy_T': 'self', 'numpy_T': 'self',

View File

@ -100,7 +100,8 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward', 'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub', 'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward', 'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', 'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_view_as_real_physical', '_conj_physical',
'conj_physical_'
} }
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = { GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {

View File

@ -10,7 +10,7 @@ import yaml
from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo, from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
SavedAttribute, ForwardDerivative) SavedAttribute, ForwardDerivative)
from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, BaseCType, VectorCType, from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, BaseCType, VectorCType,
intArrayRefT, tensorOptionsT, typeAndSizeT, intT, intArrayRefT, tensorOptionsT, typeAndSizeT, intT, boolT,
tensorGeometryT, scalarTypeT, SpecialArgName, tensorGeometryT, scalarTypeT, SpecialArgName,
OptionalCType, stringT) OptionalCType, stringT)
from tools.codegen.api import cpp from tools.codegen.api import cpp
@ -498,6 +498,11 @@ def saved_variables(
'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)), 'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
'expr': stride_expr, 'expr': stride_expr,
}), }),
# replace self.is_conj() with self_conjugate
(r'{}.is_conj\(\)', {
'suffix': '_conjugate',
'nctype': lambda name: NamedCType(name, BaseCType(boolT)),
})
] ]
# find which arguments need to be saved # find which arguments need to be saved

View File

@ -86,8 +86,10 @@ class Tensor(torch._C._TensorBase):
self.requires_grad, self.requires_grad,
self._backward_hooks) self._backward_hooks)
else: else:
new_tensor = self.new() new_tensor = self.new_empty([])
new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride()) new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
if self.is_conj():
new_tensor = new_tensor.conj_physical()
new_tensor.requires_grad = self.requires_grad new_tensor.requires_grad = self.requires_grad
if self.grad is not None: if self.grad is not None:
new_tensor.grad = self.grad.__deepcopy__(memo) new_tensor.grad = self.grad.__deepcopy__(memo)

View File

@ -893,6 +893,27 @@ conj() -> Tensor
See :func:`torch.conj` See :func:`torch.conj`
""") """)
add_docstr_all('conj_physical',
r"""
conj_physical() -> Tensor
See :func:`torch.conj_physical`
""")
add_docstr_all('conj_physical_',
r"""
conj_physical_() -> Tensor
In-place version of :meth:`~Tensor.conj_physical`
""")
add_docstr_all('resolve_conj',
r"""
resolve_conj() -> Tensor
See :func:`torch.resolve_conj`
""")
add_docstr_all('copysign', add_docstr_all('copysign',
r""" r"""
copysign(other) -> Tensor copysign(other) -> Tensor
@ -1977,6 +1998,13 @@ is_inference() -> bool
See :func:`torch.is_inference` See :func:`torch.is_inference`
""") """)
add_docstr_all('is_conj',
r"""
is_conj() -> bool
Returns True if the conjugate bit of :attr:`self` is set to true.
""")
add_docstr_all('is_signed', add_docstr_all('is_signed',
r""" r"""
is_signed() -> bool is_signed() -> bool

View File

@ -153,11 +153,12 @@ class _Formatter(object):
def _scalar_str(self, formatter1, formatter2=None): def _scalar_str(self, formatter1, formatter2=None):
if formatter2 is not None: if formatter2 is not None:
real_str = _scalar_str(self.real, formatter1) real_str = _scalar_str(self.real, formatter1)
imag_str = _scalar_str(self.imag, formatter2) + "j" imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
if self.imag < 0: # handles negative numbers, +0.0, -0.0
return real_str + imag_str.lstrip() if imag_str[0] == '+' or imag_str[0] == '-':
return real_str + imag_str
else: else:
return real_str + "+" + imag_str.lstrip() return real_str + "+" + imag_str
else: else:
return formatter1.format(self.item()) return formatter1.format(self.item())
@ -174,11 +175,12 @@ def _vector_str(self, indent, summarize, formatter1, formatter2=None):
def _val_formatter(val, formatter1=formatter1, formatter2=formatter2): def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
if formatter2 is not None: if formatter2 is not None:
real_str = formatter1.format(val.real) real_str = formatter1.format(val.real)
imag_str = formatter2.format(val.imag) + "j" imag_str = (formatter2.format(val.imag) + "j").lstrip()
if val.imag < 0: # handles negative numbers, +0.0, -0.0
return real_str + imag_str.lstrip() if imag_str[0] == '+' or imag_str[0] == '-':
return real_str + imag_str
else: else:
return real_str + "+" + imag_str.lstrip() return real_str + "+" + imag_str
else: else:
return formatter1.format(val) return formatter1.format(val)
@ -235,6 +237,7 @@ def _tensor_str(self, indent):
self = self.float() self = self.float()
if self.dtype.is_complex: if self.dtype.is_complex:
self = self.resolve_conj()
real_formatter = _Formatter(get_summarized_data(self.real) if summarize else self.real) real_formatter = _Formatter(get_summarized_data(self.real) if summarize else self.real)
imag_formatter = _Formatter(get_summarized_data(self.imag) if summarize else self.imag) imag_formatter = _Formatter(get_summarized_data(self.imag) if summarize else self.imag)
return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter) return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter)

View File

@ -2181,15 +2181,18 @@ Example::
tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128) tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128)
""") """)
add_docstr(torch.conj, add_docstr(torch.conj_physical,
r""" r"""
conj(input, *, out=None) -> Tensor conj_physical(input, *, out=None) -> Tensor
Computes the element-wise conjugate of the given :attr:`input` tensor. If :attr:`input` has a non-complex dtype, Computes the element-wise conjugate of the given :attr:`input` tensor.
this function just returns :attr:`input`. If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`.
.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of .. note::
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj` This performs the conjugate operation regardless of the fact conjugate bit is set or not.
.. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical`
when :attr:`input` is of non-complex dtype to be compatible with this change. when :attr:`input` is of non-complex dtype to be compatible with this change.
.. math:: .. math::
@ -2203,10 +2206,63 @@ Keyword args:
Example:: Example::
>>> torch.conj(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) >>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))
tensor([-1 - 1j, -2 - 2j, 3 + 3j]) tensor([-1 - 1j, -2 - 2j, 3 + 3j])
""".format(**common_args)) """.format(**common_args))
add_docstr(torch.conj,
r"""
conj(input) -> Tensor
Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype,
this function just returns :attr:`input`.
.. note::
:func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized
at any time using :func:`torch.resolve_conj`.
.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical`
when :attr:`input` is of non-complex dtype to be compatible with this change.
Args:
{input}
Example::
>>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
>>> x.is_conj()
False
>>> y = torch.conj(x)
>>> y.is_conj()
True
""")
add_docstr(torch.resolve_conj,
r"""
resolve_conj(input) -> Tensor
Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`,
else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`.
Args:
{input}
Example::
>>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
>>> y = x.conj()
>>> y.is_conj()
True
>>> z = y.resolve_conj()
>>> z
tensor([-1 - 1j, -2 - 2j, 3 + 3j])
>>> z.is_conj()
False
""")
add_docstr(torch.copysign, add_docstr(torch.copysign,
r""" r"""
copysign(input, other, *, out=None) -> Tensor copysign(input, other, *, out=None) -> Tensor
@ -4201,6 +4257,15 @@ Args:
{input} {input}
""".format(**common_args)) """.format(**common_args))
add_docstr(torch.is_conj, r"""
is_conj(input) -> (bool)
Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`.
Args:
{input}
""".format(**common_args))
add_docstr(torch.is_nonzero, r""" add_docstr(torch.is_nonzero, r"""
is_nonzero(input) -> (bool) is_nonzero(input) -> (bool)

View File

@ -545,7 +545,7 @@ def _get_analytical_vJu_backward_mode(inputs, outputs, nondet_tol, check_grad_dt
# the error checking logic from slow mode # the error checking logic from slow mode
vJ = vJ.T.squeeze(0) vJ = vJ.T.squeeze(0)
if vJ.is_complex(): # C -> R if vJ.is_complex(): # C -> R
tv = torch.view_as_real(vJ) tv = torch.view_as_real(vJ.resolve_conj())
tr = tv.select(-1, 0) tr = tv.select(-1, 0)
ti = tv.select(-1, 1) ti = tv.select(-1, 1)
jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1])) jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
@ -894,7 +894,11 @@ def _real_and_imag_output(fn):
outs = _as_tuple(fn(*inputs)) outs = _as_tuple(fn(*inputs))
return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs) return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)
return wrapped_fn return wrapped_fn
return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)
# TODO(@anjali411): remove this workaround once neg bit is added.
def torch_imag(x):
return x.resolve_conj().imag
return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch_imag)
def _real_and_imag_input(fn, complex_inp_indices): def _real_and_imag_input(fn, complex_inp_indices):
# returns new functions that take real inputs instead of complex inputs and compute fn(x + 0 * 1j) # returns new functions that take real inputs instead of complex inputs and compute fn(x + 0 * 1j)
@ -935,7 +939,7 @@ def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, e
if complex_inp_indices: if complex_inp_indices:
real_fn, imag_fn = _real_and_imag_input(func, complex_inp_indices) real_fn, imag_fn = _real_and_imag_input(func, complex_inp_indices)
imag_inputs = [inp.imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs] imag_inputs = [inp.resolve_conj().imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs]
imag_func_out = imag_fn(*imag_inputs) imag_func_out = imag_fn(*imag_inputs)
diff_imag_func_out = _differentiable_outputs(imag_func_out) diff_imag_func_out = _differentiable_outputs(imag_func_out)
gradcheck_fn(imag_fn, imag_func_out, imag_inputs, diff_imag_func_out, eps, gradcheck_fn(imag_fn, imag_func_out, imag_inputs, diff_imag_func_out, eps,

View File

@ -248,7 +248,7 @@ Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & expone
Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) { Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) {
Tensor cond; Tensor cond;
if (exponent.is_complex()) { if (exponent.is_complex()) {
auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0); auto is_real_exp = at::logical_and(at::imag(exponent.resolve_conj()) == 0, at::real(exponent) >= 0);
cond = at::logical_and(self == 0, is_real_exp); cond = at::logical_and(self == 0, is_real_exp);
} else { } else {
cond = at::logical_and(self == 0, exponent >= 0); cond = at::logical_and(self == 0, exponent >= 0);
@ -264,7 +264,7 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp
if (base.equal(0.0)) { if (base.equal(0.0)) {
auto cond = [](auto exp) { auto cond = [](auto exp) {
if (exp.is_complex()) { if (exp.is_complex()) {
return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0); return at::logical_and(at::imag(exp.resolve_conj()) == 0, at::real(exp) >= 0);
} else { } else {
return exp >=0; return exp >=0;
} }
@ -2154,7 +2154,7 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
if (self.is_complex() && gu.defined()) { if (self.is_complex() && gu.defined()) {
Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1); Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1);
at::real(L).zero_(); at::real(L).zero_();
at::imag(L).mul_(sigma_inv); at::imag(L.resolve_conj()).mul_(sigma_inv);
Tensor imag_term = at::matmul(u * L.unsqueeze(-2), vh); Tensor imag_term = at::matmul(u * L.unsqueeze(-2), vh);
return u_term + sigma_term + v_term + imag_term; return u_term + sigma_term + v_term + imag_term;
} }
@ -2195,7 +2195,7 @@ Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const T
} }
// path for torch.linalg.eig with always a complex tensor of eigenvalues // path for torch.linalg.eig with always a complex tensor of eigenvalues
else { else {
is_imag_eigvals_zero = (at::imag(D) == 0.0).min().item<bool>(); is_imag_eigvals_zero = (at::imag(D.resolve_conj()) == 0.0).min().item<bool>();
// insert an additional dimension to be compatible with torch.eig. // insert an additional dimension to be compatible with torch.eig.
// Recall that it produces 2D tensors. // Recall that it produces 2D tensors.
// We extract only the real parts as there is no support for // We extract only the real parts as there is no support for

View File

@ -83,23 +83,6 @@ void nnc_aten_sgn(
} catch (...) { } catch (...) {
} }
} }
void nnc_aten_conj(
int64_t bufs_num,
void** buf_data,
int64_t* buf_ranks,
int64_t* buf_dims,
int8_t* buf_dtypes,
int64_t args_num,
int64_t* extra_args) {
std::vector<at::Tensor> tensors =
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
at::Tensor& r = tensors[0];
const at::Tensor& self = tensors[1];
try {
at::conj_out(r, self);
} catch (...) {
}
}
void nnc_aten_acos( void nnc_aten_acos(
int64_t bufs_num, int64_t bufs_num,
void** buf_data, void** buf_data,
@ -2758,9 +2741,6 @@ const static RegisterNNCExternalFunction nnc_angle(
"nnc_aten_angle", "nnc_aten_angle",
nnc_aten_angle); nnc_aten_angle);
const static RegisterNNCExternalFunction nnc_sgn("nnc_aten_sgn", nnc_aten_sgn); const static RegisterNNCExternalFunction nnc_sgn("nnc_aten_sgn", nnc_aten_sgn);
const static RegisterNNCExternalFunction nnc_conj(
"nnc_aten_conj",
nnc_aten_conj);
const static RegisterNNCExternalFunction nnc_acos( const static RegisterNNCExternalFunction nnc_acos(
"nnc_aten_acos", "nnc_aten_acos",
nnc_aten_acos); nnc_aten_acos);

View File

@ -228,6 +228,8 @@ def get_ignored_functions() -> Set[Callable]:
Tensor.to_sparse_csr, Tensor.to_sparse_csr,
Tensor._reduce_ex_internal, Tensor._reduce_ex_internal,
Tensor._fix_weakref, Tensor._fix_weakref,
Tensor._conj,
Tensor._conj_physical,
} }
@ -349,6 +351,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.polar: lambda abs, ang: -1, torch.polar: lambda abs, ang: -1,
torch.linalg.cond: lambda input, ord=None: -1, torch.linalg.cond: lambda input, ord=None: -1,
torch.conj: lambda input, out=None: -1, torch.conj: lambda input, out=None: -1,
torch.conj_physical: lambda input, out=None: -1,
torch.resolve_conj: lambda input, out=None: -1,
torch.constant_pad_nd: lambda input, pad, value=0: -1, torch.constant_pad_nd: lambda input, pad, value=0: -1,
torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
@ -500,6 +504,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.linalg.inv: lambda input, out=None: -1, torch.linalg.inv: lambda input, out=None: -1,
torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1, torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
torch.is_complex: lambda input: -1, torch.is_complex: lambda input: -1,
torch.is_conj: lambda input: -1,
torch.is_distributed: lambda input: -1, torch.is_distributed: lambda input: -1,
torch.is_inference: lambda input: -1, torch.is_inference: lambda input: -1,
torch.is_floating_point: lambda input: -1, torch.is_floating_point: lambda input: -1,
@ -797,6 +802,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.real: lambda input, out=None: -1, torch.real: lambda input, out=None: -1,
torch.vdot: lambda input, other, out=None: -1, torch.vdot: lambda input, other, out=None: -1,
torch.view_as_real: lambda input: -1, torch.view_as_real: lambda input: -1,
torch._view_as_real_physical: lambda input: -1,
torch.view_as_complex: lambda input: -1, torch.view_as_complex: lambda input: -1,
torch.reciprocal: lambda input, out=None: -1, torch.reciprocal: lambda input, out=None: -1,
torch.relu: lambda input, inplace=False: -1, torch.relu: lambda input, inplace=False: -1,

View File

@ -134,6 +134,8 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e
# Compares complex tensors' real and imaginary parts separately. # Compares complex tensors' real and imaginary parts separately.
# (see NOTE Test Framework Tensor "Equality") # (see NOTE Test Framework Tensor "Equality")
if a.is_complex(): if a.is_complex():
a = a.resolve_conj()
b = b.resolve_conj()
if equal_nan == "relaxed": if equal_nan == "relaxed":
a = a.clone() a = a.clone()
b = b.clone() b = b.clone()

View File

@ -10,7 +10,6 @@ import random
import torch import torch
import numpy as np import numpy as np
from torch._six import inf from torch._six import inf
from torch.autograd import Variable
import collections.abc import collections.abc
from typing import List, Sequence, Tuple, Dict, Any, Union from typing import List, Sequence, Tuple, Dict, Any, Union
@ -231,6 +230,7 @@ class OpInfo(object):
# function around gradcheck (testing._internal.common_utils.gradcheck) # function around gradcheck (testing._internal.common_utils.gradcheck)
inplace_variant=_NOTHING, # explicitly pass the inplace variant of the operator if required inplace_variant=_NOTHING, # explicitly pass the inplace variant of the operator if required
method_variant=_NOTHING, # explicitly pass the method variant of the operator if required method_variant=_NOTHING, # explicitly pass the method variant of the operator if required
test_conjugated_samples=True,
): ):
# Validates the dtypes are generated from the dispatch-related functions # Validates the dtypes are generated from the dispatch-related functions
@ -300,6 +300,8 @@ class OpInfo(object):
if aliases is not None: if aliases is not None:
self.aliases = tuple(AliasInfo(a) for a in aliases) # type: ignore[assignment] self.aliases = tuple(AliasInfo(a) for a in aliases) # type: ignore[assignment]
self.test_conjugated_samples = test_conjugated_samples
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
"""Calls the function variant of the operator.""" """Calls the function variant of the operator."""
return self.op(*args, **kwargs) return self.op(*args, **kwargs)
@ -326,6 +328,37 @@ class OpInfo(object):
""" """
return self.operator_variant return self.operator_variant
def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
"""Returns an iterable of SampleInputs but with the tensor input or first
tensor in a sequence input conjugated.
"""
# TODO: Remove the try/except once all operators have sample_inputs_func with
# **kwargs in their signature.
try:
samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
except TypeError:
samples = self.sample_inputs_func(self, device, dtype, requires_grad)
conj_samples = list(samples)
def conjugate(tensor):
_requires_grad = tensor.requires_grad
with torch.no_grad():
tensor = tensor.conj()
return tensor.requires_grad_(_requires_grad)
for i in range(len(samples)):
sample = conj_samples[i]
# Note: it is assumed that the input here is either a tensor or tensorlist
if isinstance(sample.input, torch.Tensor):
sample.input = conjugate(sample.input)
else:
with torch.no_grad():
sample.input[0] = conjugate(sample.input[0])
return tuple(conj_samples)
def sample_inputs(self, device, dtype, requires_grad=False, **kwargs): def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
"""Returns an iterable of SampleInputs. """Returns an iterable of SampleInputs.
@ -339,6 +372,13 @@ class OpInfo(object):
samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
except TypeError: except TypeError:
samples = self.sample_inputs_func(self, device, dtype, requires_grad) samples = self.sample_inputs_func(self, device, dtype, requires_grad)
if 'include_conjugated_inputs' in kwargs and kwargs.get('include_conjugated_inputs'):
conj_samples = self.conjugate_sample_inputs(device, dtype, requires_grad, **kwargs)
samples_list = list(samples)
samples_list.extend(conj_samples)
samples = tuple(samples_list)
return samples return samples
# Returns True if the test should be skipped and False otherwise # Returns True if the test should be skipped and False otherwise
@ -1001,21 +1041,21 @@ def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs):
return tuple(sample_inputs) return tuple(sample_inputs)
def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs): def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
test_cases = [((S, S), (S, S), (S, S)), test_cases = [(((S, S), (S, S), (S, S)), False),
((S, S), (S, 1), (1, S)), (((S, S), (S, 1), (1, S)), False),
((1,), (S, S, 1), (1, S)), (((1,), (S, S, 1), (1, S)), True),
((), (), ()), (((), (), ()), False),
((S, S), (), ()), (((S, S), (), ()), True),
((), (S, S, 1), (1, S)), (((), (S, S, 1), (1, S)), True)
] ]
sample_inputs = [] sample_inputs = []
for input_args in test_cases: for input_args, broadcasts_input in test_cases:
args = tuple(make_tensor(arg, device, dtype, requires_grad=requires_grad) if isinstance(arg, tuple) else arg args = tuple(make_tensor(arg, device, dtype, requires_grad=requires_grad) if isinstance(arg, tuple) else arg
for arg in input_args) for arg in input_args)
sample_inputs.append(SampleInput(args[0], args=args[1:])) sample_inputs.append(SampleInput(args[0], args=args[1:], broadcasts_input=broadcasts_input))
sample_inputs.append(SampleInput(args[0], args=args[1:], kwargs=dict(value=3.14))) sample_inputs.append(SampleInput(args[0], args=args[1:], kwargs=dict(value=3.14), broadcasts_input=broadcasts_input))
return tuple(sample_inputs) return tuple(sample_inputs)
@ -1057,7 +1097,8 @@ def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad), make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=( args=(
make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad), make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad))) make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
broadcasts_input=True)
if dtype.is_complex: if dtype.is_complex:
alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j
@ -1078,7 +1119,8 @@ def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
args=( args=(
make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad), make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)), make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
kwargs=dict(beta=beta, alpha=alpha)) kwargs=dict(beta=beta, alpha=alpha),
broadcasts_input=True)
return (input1, input2, input3, input4) return (input1, input2, input3, input4)
@ -2196,6 +2238,32 @@ def np_unary_ufunc_integer_promotion_wrapper(fn):
return wrapped_fn return wrapped_fn
def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
nd_tensor = make_tensor((S, S + 1, S + 2), device, dtype, low=None, high=None,
requires_grad=requires_grad)
tensor = make_tensor((31,), device, dtype, low=None, high=None,
requires_grad=requires_grad)
if self.ndimensional:
return [
SampleInput(nd_tensor, kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(s=(8,))),
SampleInput(tensor),
*(SampleInput(nd_tensor, kwargs=dict(dim=dim))
for dim in [-1, -2, -3, (0, -1)]),
]
else:
return [
SampleInput(nd_tensor, kwargs=dict(n=10, dim=1, norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(n=7)),
SampleInput(tensor),
*(SampleInput(nd_tensor, kwargs=dict(dim=dim))
for dim in [-1, -2, -3]),
]
# Metadata class for Fast Fourier Transforms in torch.fft. # Metadata class for Fast Fourier Transforms in torch.fft.
class SpectralFuncInfo(OpInfo): class SpectralFuncInfo(OpInfo):
@ -2207,6 +2275,7 @@ class SpectralFuncInfo(OpInfo):
ref=None, # Reference implementation (probably in np.fft namespace) ref=None, # Reference implementation (probably in np.fft namespace)
dtypes=floating_and_complex_types(), dtypes=floating_and_complex_types(),
ndimensional: bool, # Whether dim argument can be a tuple ndimensional: bool, # Whether dim argument can be a tuple
sample_inputs_func=sample_inputs_spectral_ops,
decorators=None, decorators=None,
**kwargs): **kwargs):
decorators = list(decorators) if decorators is not None else [] decorators = list(decorators) if decorators is not None else []
@ -2220,39 +2289,12 @@ class SpectralFuncInfo(OpInfo):
super().__init__(name=name, super().__init__(name=name,
dtypes=dtypes, dtypes=dtypes,
decorators=decorators, decorators=decorators,
sample_inputs_func=sample_inputs_func,
**kwargs) **kwargs)
self.ref = ref if ref is not None else _getattr_qual(np, name) self.ref = ref if ref is not None else _getattr_qual(np, name)
self.ndimensional = ndimensional self.ndimensional = ndimensional
def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
nd_tensor = make_tensor((S, S + 1, S + 2), device, dtype, low=None, high=None,
requires_grad=requires_grad)
tensor = make_tensor((31,), device, dtype, low=None, high=None,
requires_grad=requires_grad)
if self.ndimensional:
return [
SampleInput(nd_tensor, kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(s=(8,))),
SampleInput(tensor),
*(SampleInput(nd_tensor, kwargs=dict(dim=dim))
for dim in [-1, -2, -3, (0, -1)]),
]
else:
return [
SampleInput(nd_tensor, kwargs=dict(n=10, dim=1, norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(n=7)),
SampleInput(tensor),
*(SampleInput(nd_tensor, kwargs=dict(dim=dim))
for dim in [-1, -2, -3]),
]
class ShapeFuncInfo(OpInfo): class ShapeFuncInfo(OpInfo):
"""Early version of a specialized OpInfo for Shape manipulating operations like tile and roll""" """Early version of a specialized OpInfo for Shape manipulating operations like tile and roll"""
def __init__(self, def __init__(self,
@ -2737,13 +2779,13 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
test_cases = ( test_cases = (
((2, 2), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, False), ((2, 2), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, False),
((2, 2), 0, 5, 1e-3, requires_grad, (1,), 0, 1, 0.1, requires_grad, False), ((2, 2), 0, 5, 1e-3, requires_grad, (1,), 0, 1, 0.1, requires_grad, False),
((), 1e-3, 1e-3 + 1, 0, True, (), 0.1, 1.1, 0, False, False), ((), 1e-3, 1e-3 + 1, 0, requires_grad, (), 0.1, 1.1, 0, False, False),
((2, 2), 0, 5, 1e-3, requires_grad, (), 0.1, 1.1, 1, False, False), ((2, 2), 0, 5, 1e-3, requires_grad, (), 0.1, 1.1, 1, False, False),
) )
tests_require_resizing = ( tests_require_resizing = (
((1,), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, True), ((1,), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, requires_grad),
((2, 1, 2), 0, 5, 1e-3, requires_grad, (1, 2, 1), 0, 1, 0.1, requires_grad, True), ((2, 1, 2), 0, 5, 1e-3, requires_grad, (1, 2, 1), 0, 1, 0.1, requires_grad, requires_grad),
((), 1e-3, 1e-3 + 1, 0, True, (1, S, 1), 0, 1, 0.1, requires_grad, True), ((), 1e-3, 1e-3 + 1, 0, requires_grad, (1, S, 1), 0, 1, 0.1, requires_grad, requires_grad),
) )
cases = test_cases + tests_require_resizing cases = test_cases + tests_require_resizing
samples = list(SampleInput(make_tensor(shape_b, low=low_b, high=high_b, samples = list(SampleInput(make_tensor(shape_b, low=low_b, high=high_b,
@ -2757,7 +2799,7 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
high_e, additive_e, e_grad, broadcasts_input in cases) high_e, additive_e, e_grad, broadcasts_input in cases)
tensor_scalar_inputs = ( tensor_scalar_inputs = (
((2, 2), 0, 5, 1e-3, requires_grad, (3.14,)), ((2, 2), 0, 5, 1e-3, requires_grad, (3.14,)),
((), 1e-3, 1e-3 + 1, 0, True, (3.14,)) ((), 1e-3, 1e-3 + 1, 0, requires_grad, (3.14,))
) )
more_samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device, more_samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device,
high=high, low=low, high=high, low=low,
@ -2768,8 +2810,8 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
elif dtype in [torch.complex64, torch.complex128]: elif dtype in [torch.complex64, torch.complex128]:
args_tuple = ( args_tuple = (
((2, 2), 0, 5, requires_grad, (3.14,)), ((2, 2), 0, 5, requires_grad, (3.14,)),
((), 0, 1, True, (3.14,)), ((), 0, 1, requires_grad, (3.14,)),
((), 0, 1, True, (3.14j,)) ((), 0, 1, requires_grad, (3.14j,))
) )
samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device, samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device,
high=high, low=low, high=high, low=low,
@ -4298,7 +4340,7 @@ op_db: List[OpInfo] = [
SkipInfo('TestGradients', 'test_method_grad', SkipInfo('TestGradients', 'test_method_grad',
device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
SkipInfo('TestGradients', 'test_forward_mode_AD', SkipInfo('TestGradients', 'test_forward_mode_AD',
dtypes=[torch.cdouble], active_if=IS_WINDOWS), dtypes=[torch.cdouble]),
)), )),
OpInfo('add', OpInfo('add',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
@ -4670,28 +4712,30 @@ op_db: List[OpInfo] = [
supports_forward_ad=True, supports_forward_ad=True,
), ),
UnaryUfuncInfo('conj', UnaryUfuncInfo('conj',
ref=np.conj,
dtypes=all_types_and_complex_and(torch.bool,
torch.bfloat16, torch.half),
supports_forward_ad=True,
supports_out=False),
UnaryUfuncInfo('conj_physical',
ref=np.conj, ref=np.conj,
dtypes=all_types_and_complex_and(torch.bool, dtypes=all_types_and_complex_and(torch.bool,
torch.bfloat16, torch.half), torch.bfloat16, torch.half),
supports_forward_ad=True, supports_forward_ad=True,
skips=( skips=(
# File "test_unary_ufuncs.py", line 289, in test_reference_numerics SkipInfo('TestCommon', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
# if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype):
# KeyError: <class 'numpy.intc'>
# Following error in Windows CI
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
dtypes=[torch.int],
active_if=IS_WINDOWS),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
dtypes=[torch.int],
active_if=IS_WINDOWS),
# TODO fix the formula for complex forward AD
SkipInfo('TestGradients', 'test_forward_mode_AD'),
)), )),
OpInfo('resolve_conj',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_view_as_real,
supports_forward_ad=True,
supports_out=False,
),
OpInfo('view_as_real', OpInfo('view_as_real',
dtypes=complex_types(), dtypes=complex_types(),
supports_forward_ad=True, supports_forward_ad=True,
sample_inputs_func=sample_inputs_view_as_real, sample_inputs_func=sample_inputs_view_as_real,
test_conjugated_samples=False,
), ),
OpInfo('view_as_complex', OpInfo('view_as_complex',
dtypes=floating_types_and(torch.half), dtypes=floating_types_and(torch.half),
@ -5120,7 +5164,9 @@ op_db: List[OpInfo] = [
ref=np.imag, ref=np.imag,
dtypes=complex_types(), dtypes=complex_types(),
supports_out=False, supports_out=False,
supports_autograd=False, supports_forward_ad=True,
# TODO(@anjali411): Test this once neg bit is added.
test_conjugated_samples=False,
skips=( skips=(
# Skip since real and imag don't have out variants. # Skip since real and imag don't have out variants.
SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'), SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
@ -5251,8 +5297,6 @@ op_db: List[OpInfo] = [
supports_autograd=False, supports_autograd=False,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
skips=( skips=(
# skip because `linalg_lstsq` is not differentiable
SkipInfo('TestGradients', 'test_fn_grad'),
SkipInfo('TestCommon', 'test_variant_consistency_jit'), SkipInfo('TestCommon', 'test_variant_consistency_jit'),
)), )),
OpInfo('linalg.matrix_power', OpInfo('linalg.matrix_power',
@ -5471,7 +5515,8 @@ op_db: List[OpInfo] = [
# "RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when # "RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when
# calling cublasGemmStridedBatchedExFix." # calling cublasGemmStridedBatchedExFix."
SkipInfo('TestOpInfo', 'test_supported_backward', SkipInfo('TestOpInfo', 'test_supported_backward',
device_type='cuda', dtypes=(torch.bfloat16,)),)), device_type='cuda', dtypes=(torch.bfloat16,)),
SkipInfo('TestCommon', 'test_conj_view', device_type='cpu'),),),
OpInfo('max', OpInfo('max',
op=torch.max, op=torch.max,
variant_test_name='binary', variant_test_name='binary',
@ -5699,7 +5744,9 @@ op_db: List[OpInfo] = [
assert_autodiffed=True), assert_autodiffed=True),
OpInfo('float_power', OpInfo('float_power',
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
sample_inputs_func=sample_inputs_pow), sample_inputs_func=sample_inputs_pow,
skips=(
SkipInfo('TestCommon', 'test_conj_view', device_type='cuda'),),),
OpInfo('prod', OpInfo('prod',
dtypes=all_types_and_complex_and(torch.bool), dtypes=all_types_and_complex_and(torch.bool),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
@ -5739,7 +5786,6 @@ op_db: List[OpInfo] = [
ref=np.real, ref=np.real,
dtypes=complex_types(), dtypes=complex_types(),
supports_out=False, supports_out=False,
supports_autograd=False,
skips=( skips=(
# Skip since real and imag don't have out variants. # Skip since real and imag don't have out variants.
SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'), SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
@ -7068,23 +7114,26 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
def maybe_non_contig(tensor): def maybe_non_contig(tensor):
return tensor if not non_contiguous else make_non_contiguous(tensor) return tensor if not non_contiguous else make_non_contiguous(tensor)
def conjugate(tensor):
return tensor.conj()
if isinstance(arg, torch.Size) or isinstance(arg, dont_convert): if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
return arg return arg
elif isinstance(arg, tuple) and len(arg) == 0: elif isinstance(arg, tuple) and len(arg) == 0:
var = torch.randn((), dtype=dtype, device=device) var = conjugate(torch.randn((), dtype=dtype, device=device))
var.requires_grad = requires_grad var.requires_grad = requires_grad
return var return var
elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor): elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
return Variable(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device)), requires_grad=requires_grad) return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)
# double check casting # double check casting
elif isinstance(arg, non_differentiable): elif isinstance(arg, non_differentiable):
if isinstance(arg.tensor, torch.Tensor): if isinstance(arg.tensor, torch.Tensor):
if arg.tensor.dtype == torch.float: if arg.tensor.dtype == torch.float:
return maybe_non_contig(arg.tensor.to(dtype=torch.double, device=device)) return maybe_non_contig(arg.tensor.to(dtype=torch.double, device=device))
if arg.tensor.dtype == torch.cfloat: if arg.tensor.dtype == torch.cfloat:
return maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device)) return conjugate(maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device)))
return maybe_non_contig(arg.tensor.to(device=device)) return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
return maybe_non_contig(arg.tensor.to(device=device)) return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
elif isinstance(arg, torch.Tensor): elif isinstance(arg, torch.Tensor):
if arg.dtype == torch.float: if arg.dtype == torch.float:
arg = arg.double() arg = arg.double()
@ -7094,7 +7143,7 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ", raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ",
"which is not supported for now") "which is not supported for now")
# NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
v = maybe_non_contig(arg).detach().to(device=device).clone() v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex()) v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
return v return v
elif callable(arg): elif callable(arg):