Conjugate View (#54987)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54987

Based on ezyang's (https://github.com/pytorch/pytorch/pull/44799) and bdhirsh's (https://github.com/pytorch/pytorch/pull/43702) prototypes.

Here's a summary of the changes in this PR:
This PR adds a new dispatch key called Conjugate. This enables us to make the conjugate operation a view and to leverage specialized library functions that fast-path the Hermitian operation (conj + transpose).

1. The conjugate operation now returns a view with the conj bit set (1) for complex tensors, and returns `self` for non-complex tensors as before. This also means `torch.view_as_real` can no longer be a view on conjugated complex tensors and is hence disabled for them. To fill the gap, we have added `torch.view_as_real_physical`, which returns the real view regardless of the conjugate bit on the input complex tensor. The conjugation status of the old tensor can be obtained by calling `.is_conj()` on the new tensor.
2. NEW API (see the usage sketch after this list):
    a) `.conj()` -- now returns a view.
    b) `.conj_physical()` -- performs the conjugation in memory. If the conj bit on the input is set, you get `self.clone()`; otherwise you get a new tensor with conjugated values in its memory.
    c) `.conj_physical_()`, and an `out=` variant
    d) `.resolve_conj()` -- materializes the conjugation: returns `self` if the conj bit is unset, else returns a new tensor with conjugated values and the conj bit set to 0.
    e) `.resolve_conj_()` -- in-place version of (d)
    f) `view_as_real_physical` -- as described in (1), it's functionally the same as `view_as_real`, except that it doesn't error out on conjugated tensors.
    g) `view_as_real` -- existing function, but now errors out on conjugated tensors.
3. Conjugate Fallback
    a) The vast majority of PyTorch functions currently use this fallback when called on a conjugated tensor.
    b) This fallback is well equipped to handle the following cases:
        - functional operations, e.g., `torch.sin(input)`
        - mutable inputs and in-place operations, e.g., `tensor.add_(2)`
        - `out=` operations, e.g., `torch.sin(input, out=out)`
        - Tensorlist input args
        - NOTE: Meta tensors don't work with conjugate fallback.
4. Autograd
    a) `resolve_conj()` is an identity function w.r.t. autograd
    b) Everything else works as expected.
5. Testing:
    a) All method_tests run with conjugate view tensors.
    b) OpInfo tests that run with conjugate views
        - test_variant_consistency_eager/jit
        - gradcheck, gradgradcheck
        - test_conj_view (runs only for the `torch.cfloat` dtype)

NOTE: functions like `empty_like`, `zeros_like`, `randn_like`, `clone` don't propagate the conjugate bit.
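As a quick illustration of the API above, here is a minimal usage sketch based on the descriptions in this summary (illustrative values only):

```python
import torch

x = torch.tensor([1 + 1j, 2 - 2j])

y = x.conj()                         # lazy: a view with the conj bit set
assert y.is_conj() and not x.is_conj()
assert y.data_ptr() == x.data_ptr()  # y aliases x's storage

z = y.resolve_conj()                 # materializes the conjugation
assert not z.is_conj()
assert torch.equal(z, torch.tensor([1 - 1j, 2 + 2j]))

w = x.conj_physical()                # eager: conjugated values written to memory
assert not w.is_conj()
assert torch.equal(w, z)
```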

Follow-up work:
1. conjugate view RFC
2. Add the neg bit to re-enable `view_as_real` on conjugated tensors
3. Update linalg functions to call into specialized functions that fast-path the Hermitian operation (a sketch of the idea follows).
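For (3), the point of the conj bit is that the Hermitian adjoint becomes a pure view: both conj and transpose are lazy, so linalg kernels can detect and fast-path it. A small sketch (assuming, as the fallback comments below note, that view ops propagate the conj bit):

```python
import torch

A = torch.randn(3, 3, dtype=torch.cfloat)
A_h = A.conj().transpose(-2, -1)     # Hermitian adjoint, no data copied
assert A_h.is_conj()                 # transpose keeps the conj bit
assert A_h.data_ptr() == A.data_ptr()
```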

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D28227315

Pulled By: anjali411

fbshipit-source-id: acab9402b9d6a970c6d512809b627a290c8def5f
Commit 3607478ecd (parent 19985d6f84)
anjali411 authored on 2021-06-04 14:11:23 -07:00; committed by Facebook GitHub Bot
43 changed files with 783 additions and 180 deletions

View File

@ -1031,7 +1031,6 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("sum.dim_IntList", sum_batching_rule);
m.impl("is_complex", native::is_complex);
m.impl("conj", native::conj);
// inplace operations
m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
@ -1085,7 +1084,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
UNARY_POINTWISE(ceil);
UNARY_POINTWISE(cos);
UNARY_POINTWISE(cosh);
UNARY_POINTWISE(_conj);
UNARY_POINTWISE(conj_physical);
UNARY_POINTWISE(digamma);
UNARY_POINTWISE(exp);
UNARY_POINTWISE(expm1);
@ -1181,6 +1180,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
TRIVIAL_OP(imag)
TRIVIAL_OP(real);
TRIVIAL_OP(view_as_real);
TRIVIAL_OP(_view_as_real_physical);
TRIVIAL_OP(conj);
TRIVIAL_OP(_conj);
TRIVIAL_OP(resolve_conj);
m.impl("view_as_complex", view_as_complex_batching_rule);
#undef TRIVIAL

View File

@ -0,0 +1,152 @@
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/NativeFunctions.h>
namespace at {
void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
// Situations to handle:
// 1. Out-of-place operation. Easy: materialize all inputs and
// call it a day.
// 2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
// Materialize other inputs as in (1).
// 3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
// Materialize other inputs as in (1).
//
// It is important to be able to tell if we READ from an argument and if we
// WRITE to an argument. The conservative approach is to assume that we always
// READ from an argument, but in out-of-place operations you can skip
// conjugating inputs on entry that never get used. With the current schema we
// can't easily tell if the in-place situation has happened, so we don't do it.
const auto& arguments = op.schema().arguments();
const auto num_arguments = arguments.size();
const auto stack_start = stack->size() - num_arguments;
c10::optional<bool> is_write;
for (int64_t i = 0; i < num_arguments; ++i) {
const auto& alias_info = arguments[i].alias_info();
// Three possible states:
// 1. alias_info has no value --> out-of-place operation
// 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
// 3. alias_info does have a value, alias_info->is_write=False --> view operation
if (alias_info.has_value()) {
if (is_write.has_value()) {
TORCH_CHECK(*is_write == alias_info->isWrite(),
"Unsupported operator for conjugate fallback: ", op.schema().name(),
"Conjugate fallback doesn't work for operators with a mix "
"mutable and non-mutable inputs that alias with outputs, "
"this must be implemented manually. "
"If you got this error on a core op, please report a bug to PyTorch.");
} else {
is_write = alias_info->isWrite();
}
}
}
if (is_write.has_value() && !*is_write) {
// We assume that view operators automatically handle conjugation
// correctly by propagating the Conjugate dispatch key in key_set.
// This is not necessarily always right, so you should test these cases.
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
return;
}
// Mutable inputs to be tracked separately
std::vector<Tensor> mutable_inputs;
for (int64_t i = 0; i < num_arguments; ++i) {
auto& ivalue = (*stack)[stack_start + i];
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
continue;
}
const auto& argument = arguments[i];
bool mut_arg = false;
if (argument.alias_info()) {
// View operations were already filtered above, so only in-place/out= operations should get here.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
mut_arg = true;
}
if (ivalue.isTensor()) {
auto* impl = ivalue.unsafeToTensorImpl();
if (!impl->is_conj()) {
continue;
}
auto tensor = std::move(ivalue).toTensor();
TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), "Conjugate Fallback does not support meta tensors.");
if (mut_arg) {
// TODO: This is a waste if the argument is write only
tensor._set_conj(false);
at::conj_physical_(tensor);
mutable_inputs.emplace_back(tensor);
} else {
tensor = at::resolve_conj(tensor);
}
(*stack)[stack_start + i] = std::move(tensor);
} else if (ivalue.isTensorList()) {
auto tensors = std::move(ivalue).toTensorList();
if (mut_arg) {
for(const auto j : c10::irange(tensors.size())) {
Tensor t = tensors[j];
t._set_conj(false);
at::conj_physical_(t);
mutable_inputs.emplace_back(t);
}
} else {
for(const auto j : c10::irange(tensors.size())) {
tensors[j] = at::resolve_conj(tensors[j]);
}
}
(*stack)[stack_start + i] = std::move(tensors);
}
}
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
for (auto& mutable_input : mutable_inputs) {
at::conj_physical_(mutable_input);
mutable_input._set_conj(true);
}
}
TORCH_LIBRARY_IMPL(_, Conjugate, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
}
TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
m.impl("set_.source_Storage_storage_offset", torch::CppFunction::makeFallthrough());
m.impl("set_.source_Tensor", torch::CppFunction::makeFallthrough());
m.impl("set_", torch::CppFunction::makeFallthrough());
m.impl("copy_", torch::CppFunction::makeFallthrough());
m.impl("clone", torch::CppFunction::makeFallthrough());
m.impl("conj", torch::CppFunction::makeFallthrough());
m.impl("_conj", torch::CppFunction::makeFallthrough());
m.impl("_conj_physical", torch::CppFunction::makeFallthrough());
m.impl("conj_physical", torch::CppFunction::makeFallthrough());
m.impl("conj_physical_", torch::CppFunction::makeFallthrough());
m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
m.impl("empty_like", torch::CppFunction::makeFallthrough());
m.impl("empty.memory_format", torch::CppFunction::makeFallthrough());
m.impl("empty.out", torch::CppFunction::makeFallthrough());
m.impl("empty_strided", torch::CppFunction::makeFallthrough());
m.impl("full_like", torch::CppFunction::makeFallthrough());
m.impl("stride.int", torch::CppFunction::makeFallthrough());
m.impl("stride.Dimname", torch::CppFunction::makeFallthrough());
m.impl("size.int", torch::CppFunction::makeFallthrough());
m.impl("size.Dimname", torch::CppFunction::makeFallthrough());
m.impl("is_complex", torch::CppFunction::makeFallthrough());
m.impl("_view_as_real_physical", torch::CppFunction::makeFallthrough());
m.impl("view_as_real", torch::CppFunction::makeFallthrough());
m.impl("imag", torch::CppFunction::makeFallthrough());
m.impl("real", torch::CppFunction::makeFallthrough());
m.impl("view", torch::CppFunction::makeFallthrough());
m.impl("reshape", torch::CppFunction::makeFallthrough());
}
} // namespace at
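A sketch of the observable behavior the fallback above implements, following the three cases in its comments (illustrative Python, not part of the diff):

```python
import torch

x = torch.tensor([1 + 1j, 2 - 2j]).conj()    # conj bit set

# Case 1: out-of-place -- inputs are materialized; the output is a plain tensor.
out = torch.sin(x)
assert not out.is_conj()

# Case 2: in-place -- x is physically conjugated, mutated, and conjugated back,
# so the conj bit on x is preserved.
x.add_(2)
assert x.is_conj()

# Case 3: out= -- the result is written into the out tensor via copy_.
y = torch.empty(2, dtype=torch.cfloat)
torch.sin(x, out=y)
assert not y.is_conj()
```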

View File

@ -117,7 +117,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("clamp_min_.Tensor", CppFunction::makeFallthrough());
m.impl("clone", CppFunction::makeFallthrough());
m.impl("conj", CppFunction::makeFallthrough());
m.impl("conj.out", CppFunction::makeFallthrough());
m.impl("contiguous", CppFunction::makeFallthrough());
m.impl("copy_", CppFunction::makeFallthrough());
m.impl("cos", CppFunction::makeFallthrough());

View File

@ -238,6 +238,9 @@ _(aten, coalesce) \
_(aten, combinations) \
_(aten, _conj) \
_(aten, conj) \
_(aten, conj_physical) \
_(aten, conj_physical_) \
_(aten, resolve_conj) \
_(aten, complex) \
_(aten, copysign) \
_(aten, polar) \
@ -764,6 +767,7 @@ _(aten, zeros_like) \
_(aten, real) \
_(aten, imag) \
_(aten, view_as_real) \
_(aten, _view_as_real_physical) \
_(aten, view_as_complex) \
/* nothing */

View File

@ -2,6 +2,9 @@
#include <ATen/ATen.h>
// WARNING: this header contains non-inline functions and should be only
// included from ONE cpp file
namespace at { namespace native {
// View tensor with new dtype, storage offset, sizes and strides
@ -9,8 +12,9 @@ inline Tensor view_tensor(
const Tensor &tensor, ScalarType dtype,
int64_t offset, IntArrayRef sizes, IntArrayRef strides) {
Storage storage = tensor.storage();
auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
auto new_tensor = detail::make_tensor<TensorImpl>(
c10::TensorImpl::VIEW, std::move(storage), tensor.key_set(), scalarTypeToTypeMeta(dtype));
c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
auto * impl = new_tensor.unsafeGetTensorImpl();
impl->set_storage_offset(offset);
impl->set_sizes_and_strides(sizes, strides);
@ -30,7 +34,12 @@ inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) {
// with corresponding real dtype containing the complex values
// in the last two dimensions
Tensor view_as_real(const Tensor& self) {
TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
return native::_view_as_real_physical(self);
}
Tensor _view_as_real_physical(const Tensor& self) {
TORCH_CHECK(self.is_complex(), "view_as_real_physical is only supported for complex tensors");
auto old_sizes = self.sizes();
DimVector new_sizes(old_sizes.size() + 1);
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
@ -39,7 +48,8 @@ Tensor view_as_real(const Tensor& self) {
auto new_strides = computeStrideForViewAsReal(self.strides());
auto new_storage_offset = 2 * self.storage_offset();
const auto float_type = c10::toValueType(self.scalar_type());
return view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides);
auto real_tensor = view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides);
return real_tensor;
}
inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) {
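A sketch of what the two TORCH_CHECKs above mean in practice (illustrative Python):

```python
import torch

x = torch.tensor([1 + 2j]).conj()
try:
    torch.view_as_real(x)                    # errors: x has its conj bit set
except RuntimeError as e:
    print(e)

# resolve_conj() materializes first; the result no longer aliases x.
print(torch.view_as_real(x.resolve_conj())) # tensor([[ 1., -2.]])
```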

View File

@ -150,7 +150,7 @@ Tensor fft_r2c(c10::string_view function_name,
if (!forward) {
// FIXME: _fft_r2c doesn't support native r2c IFFT
return out.defined() ? at::conj_out(out, ret) : at::conj(ret);
return out.defined() ? at::conj_physical_out(out, ret) : at::conj(ret);
} else {
return ret;
}

View File

@ -349,6 +349,8 @@ Tensor empty_like(
namedinference::propagate_names(result, self.names());
}
// never propagate Conjugate key
result._set_conj(false);
return result;
}

View File

@ -30,6 +30,10 @@ bool is_signed(const Tensor &self) {
return self.is_signed();
}
bool is_conj(const Tensor& self) {
return self.is_conj();
}
bool is_sparse(const Tensor& self) {
return self.is_sparse();
}

View File

@ -28,7 +28,8 @@ namespace at {
namespace meta {
// Unary float operations always produce floating point
// outputs, even if their inputs are integer
// outputs for floating point and integral types
// For complex inputs, the output type should be the same as input type.
#define CREATE_UNARY_FLOAT_META_FUNC(func) \
TORCH_META_FUNC(func) (const Tensor& self) { \
build_unary_float_op(maybe_get_output(), self); \
@ -363,7 +364,8 @@ Tensor angle(const Tensor& self) {
Tensor real(const Tensor& self) {
if (self.is_complex()) {
auto real_tensor = at::view_as_real(self);
// real is never affected by conjugate bit, safe to use physical version
auto real_tensor = at::_view_as_real_physical(self);
return at::select(real_tensor, real_tensor.dim() - 1, 0);
} else {
TORCH_CHECK(false, "real is not implemented for tensors with non-complex dtypes.");
@ -379,17 +381,44 @@ Tensor imag(const Tensor& self) {
}
}
Tensor& conj_out(const Tensor& self, Tensor& result) {
return unary_op_impl_out(result, self, conj_stub);
Tensor& conj_physical_out(const Tensor& self, Tensor& result) {
return unary_op_impl_out(result, self, conj_physical_stub);
}
Tensor _conj(const Tensor& self) { return unary_op_impl(self, at::conj_out); }
Tensor _conj_physical(const Tensor& self) {
if (self.is_conj()) {
return self.conj().clone();
}
return unary_op_impl(self, at::conj_physical_out);
}
Tensor conj_physical(const Tensor& self) {
if (!self.is_complex()) return self;
return at::_conj_physical(self);
}
Tensor& conj_physical_(Tensor& self) {
if (!self.is_complex()) return self;
return unary_op_impl_out(self, self, conj_physical_stub);
}
Tensor resolve_conj(const Tensor& self) {
if (!self.is_conj()) { return self; }
// conjugation is handled in `copy_()` that clone ultimately calls into
return self.clone(self.suggest_memory_format());
}
Tensor _conj(const Tensor& self) {
Tensor self_ = self.alias();
self_._set_conj(!self.is_conj());
namedinference::propagate_names(self_, self);
return self_;
}
Tensor conj(const Tensor& self) {
if (!self.is_complex()) {
return self;
}
return at::_conj(self);
// This might look like an infinite recursion but it's not.
// This actually calls into `conj()` defined in the Tensor class.
return self.conj();
}
// special_exp2, alias for exp2
@ -689,7 +718,7 @@ DEFINE_DISPATCH(abs_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-va
DEFINE_DISPATCH(angle_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(real_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(imag_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(conj_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(conj_physical_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(acos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(acosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(asinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
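Since `_conj` above just flips the bit on an alias, conjugating twice yields a plain view of the original values; a minimal sketch:

```python
import torch

x = torch.tensor([1 + 1j])
y = x.conj()
z = y.conj()          # flips the bit back
assert y.is_conj() and not z.is_conj()
assert torch.equal(z, x)
```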

View File

@ -14,7 +14,7 @@ DECLARE_DISPATCH(unary_fn, abs_stub);
DECLARE_DISPATCH(unary_fn, angle_stub);
DECLARE_DISPATCH(unary_fn, real_stub);
DECLARE_DISPATCH(unary_fn, imag_stub);
DECLARE_DISPATCH(unary_fn, conj_stub);
DECLARE_DISPATCH(unary_fn, conj_physical_stub);
DECLARE_DISPATCH(unary_fn, acos_stub);
DECLARE_DISPATCH(unary_fn, acosh_stub);
DECLARE_DISPATCH(unary_fn, asinh_stub);

View File

@ -5,6 +5,7 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/TypeCast.h>
#include <ATen/native/cpu/zmath.h>
namespace at {
namespace native {
@ -13,6 +14,15 @@ namespace {
static void copy_kernel(TensorIterator& iter, bool non_blocking) {
ScalarType dtype = iter.dtype(0);
if (dtype == iter.dtype(1)) {
// TODO: as the majority of these operations can be done treating
// their datatypes as opaque bit patterns, we don't actually need
// separate instantiations per dtype; we only need a separate
// instantiation per dtype size. This would probably save us a
// little bit of code size here
// TODO: not sure if optimizer is able to compile two levels of
// conditionals into a single jump table. We should have a
// single jump table here; might be worth just writing out the
// dispatch statement by hand instead of using AT_DISPATCH
if (dtype == ScalarType::Half) {
cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
} else if (dtype == ScalarType::ComplexHalf) {
@ -25,12 +35,21 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
});
} else if (isComplexType(dtype)) {
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
});
if (iter.tensor(0).is_conj() == iter.tensor(1).is_conj()) {
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
});
} else {
AT_DISPATCH_COMPLEX_TYPES(dtype, "conj_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return conj_impl(a); },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.conj(); });
});
}
} else {
AT_DISPATCH_ALL_TYPES_AND2(
ScalarType::Bool, ScalarType::BFloat16,dtype, "copy_kernel", [&] {
@ -62,6 +81,9 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
return c10::static_cast_with_inter_type<dest_t, scalar_t>::apply(src); });
});
});
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
iter.tensor(0).conj_physical_();
}
}
}
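A sketch of the behavior this kernel change enables: copying between tensors whose conj bits differ conjugates on the fly, so `copy_` (and hence `clone`/`resolve_conj`) materializes lazy conjugation:

```python
import torch

src = torch.tensor([1 + 1j, 2 + 2j]).conj()
dst = torch.empty(2, dtype=torch.cfloat)
dst.copy_(src)        # conj bits differ -> values are conjugated while copying
assert not dst.is_conj()
assert torch.equal(dst, torch.tensor([1 - 1j, 2 - 2j]))
```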

View File

@ -193,6 +193,7 @@ static void imag_kernel(TensorIteratorBase& iter) {
});
}
// NB: Ignores the negative bit on tensors
static void conj_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cpu", [&]() {
@ -716,7 +717,7 @@ REGISTER_DISPATCH(real_stub, &CPU_CAPABILITY::real_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(imag_stub, &CPU_CAPABILITY::imag_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(conj_stub, &CPU_CAPABILITY::conj_kernel);
REGISTER_DISPATCH(conj_physical_stub, &CPU_CAPABILITY::conj_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(exp2_stub, &CPU_CAPABILITY::exp2_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -25,7 +25,8 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
// We can memcpy the memory if both tensors have the same type AND both
// tensors are contiguous after dimension coalescing and reordering.
bool same_type = iter.dtype(0) == iter.dtype(1);
bool memcpy_eligible = same_type && iter.is_contiguous();
bool same_conj = iter.tensor(0).is_conj() == iter.tensor(1).is_conj();
bool memcpy_eligible = same_type && same_conj && iter.is_contiguous();
Device dst_device = iter.device(0);
Device src_device = iter.device(1);
@ -71,10 +72,17 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
if (!same_conj && same_type) {
AT_DISPATCH_COMPLEX_TYPES(
dtype, "copy_conj_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return std::conj(x); });
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
}
}
}
@ -152,6 +160,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
src_contig = iter.tensor(1).expand_as(dst).contiguous();
}
// propagate the correct conjugate bit
dst_contig._set_conj(dst.is_conj());
src_contig._set_conj(iter.tensor(1).is_conj());
// perform a same-dtype copy on contiguous tensors
TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
@ -201,6 +213,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
#endif
}
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
iter.tensor(0).conj_physical_();
}
}
REGISTER_DISPATCH(copy_stub, &copy_kernel_cuda);

View File

@ -80,6 +80,7 @@ __host__ __device__ static inline c10::complex<T> conj_wrapper(c10::complex<T> v
return std::conj(v);
}
// NB: Ignores the negative bit on tensors
void conj_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cuda", [&]() {
@ -92,6 +93,6 @@ void conj_kernel_cuda(TensorIteratorBase& iter) {
REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
REGISTER_DISPATCH(real_stub, &real_kernel_cuda);
REGISTER_DISPATCH(imag_stub, &imag_kernel_cuda);
REGISTER_DISPATCH(conj_stub, &conj_kernel_cuda);
REGISTER_DISPATCH(conj_physical_stub, &conj_kernel_cuda);
}} // namespace at::native

View File

@ -271,6 +271,11 @@
dispatch:
CPU, CUDA: view_as_real
- func: _view_as_real_physical(Tensor(a) self) -> Tensor(a)
variants: function
dispatch:
CPU, CUDA: _view_as_real_physical
- func: view_as_complex(Tensor(a) self) -> Tensor(a)
variants: function
dispatch:
@ -298,21 +303,36 @@
device_check: NoCheck # TensorIterator
variants: function
- func: conj(Tensor(a) self) -> Tensor(a)
device_check: NoCheck # TensorIterator
- func: _conj(Tensor(a) self) -> Tensor(a)
variants: function, method
- func: conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: conj_out
SparseCPU, SparseCUDA: conj_out_sparse
- func: _conj(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: _conj
- func: conj(Tensor(a) self) -> Tensor(a)
variants: function, method
manual_cpp_binding: True
- func: _conj_physical(Tensor self) -> Tensor
variants: function, method
dispatch:
CompositeExplicitAutograd: _conj_physical
- func: conj_physical(Tensor self) -> Tensor
variants: function, method
- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: conj_physical_out
SparseCPU, SparseCUDA: conj_physical_out_sparse
- func: conj_physical_(Tensor(a!) self) -> Tensor(a!)
variants: function, method
dispatch:
CompositeExplicitAutograd: conj_physical_
- func: resolve_conj(Tensor(a) self) -> Tensor(a)
variants: function, method
- func: acos(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
@ -2249,6 +2269,11 @@
device_guard: False
manual_cpp_binding: True
- func: is_conj(Tensor self) -> bool
variants: function, method
device_guard: False
manual_cpp_binding: True
- func: isreal(Tensor self) -> Tensor
variants: function, method
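The `Tensor(a)` annotation on `conj` above records that the output aliases the input, so writes through the conj view are visible in the base, with logical values conjugated on the way in; a sketch:

```python
import torch

x = torch.tensor([1 + 1j, 2 + 2j])
v = x.conj()           # Tensor(a): v aliases x
v[0] = 5 + 3j          # write the logical value through the view
assert x[0] == 5 - 3j  # the base holds its conjugate
assert v[0] == 5 + 3j
```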

View File

@ -1645,15 +1645,15 @@ Tensor& bmm_out_sparse_cpu(const SparseTensor& self, const Tensor& mat2, Tensor&
return result;
}
Tensor conj_sparse(const Tensor& input) {
if (!input.is_complex()) {
return input;
}
Tensor result = at::native::empty_like(input);
return conj_out_sparse(input, result);
}
// Tensor conj_physical_sparse(const Tensor& input) {
// if (!input.is_complex()) {
// return input;
// }
// Tensor result = at::native::empty_like(input);
// return conj_physical_out_sparse(input, result);
// }
Tensor& conj_out_sparse(const Tensor& input, Tensor& result) {
Tensor& conj_physical_out_sparse(const Tensor& input, Tensor& result) {
TORCH_INTERNAL_ASSERT(input.is_sparse());
if (input.numel() == 0) {
return result;
@ -1665,7 +1665,7 @@ Tensor& conj_out_sparse(const Tensor& input, Tensor& result) {
return result;
}
Tensor result_values = result._values();
at::conj_out(result_values, input._values());
at::conj_physical_out(result_values, input._values());
return result;
}

View File

@ -216,4 +216,12 @@ inline bool is_inference(const Tensor& tensor) {
return tensor.is_inference();
}
inline bool is_conj(const Tensor& tensor) {
return tensor.is_conj();
}
inline Tensor conj(const Tensor& tensor) {
return tensor.conj();
}
}

View File

@ -136,6 +136,17 @@ class TORCH_API Tensor {
}
}
Tensor conj() const {
if (!this->is_complex()) {
return *this;
} else {
if (this->is_sparse()) {
return this->conj_physical();
}
return this->_conj();
}
}
/// Should be used if *this can reasonably be expected to be contiguous and
/// performance is important.
/// Compared to contiguous, it saves a reference count
@ -363,6 +374,18 @@ class TORCH_API Tensor {
return !at::impl::variable_excluded_from_dispatch();
}
inline bool is_conj() const {
return impl_->is_conj();
}
// sets the conjugate bit of a tensor.
// NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are extremely sure
// that's what you want. Changing this might lead to incorrect behavior since conjugation is
// a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
inline void _set_conj(bool conjugate) const {
impl_->_set_conj(conjugate);
}
/// Returns a `Tensor`'s layout.
Layout layout() const noexcept {
return impl_->layout();
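Per `Tensor::conj()` above, dense complex tensors take the lazy `_conj` path while sparse ones fall back to a physical conjugation; a sketch (assuming sparse complex tensors are constructible on this build):

```python
import torch

dense = torch.tensor([1 + 1j]).conj()
assert dense.is_conj()        # lazy conj view

sparse = torch.tensor([[1 + 1j]]).to_sparse().conj()
assert not sparse.is_conj()   # conjugated physically, no bit set
```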

View File

@ -65,6 +65,8 @@ const char* toString(DispatchKey t) {
case DispatchKey::PrivateUse3:
return "PrivateUse3";
case DispatchKey::Conjugate:
return "Conjugate";
case DispatchKey::Meta:
return "Meta";

View File

@ -142,6 +142,11 @@ enum class DispatchKey : uint8_t {
// constituent parts.
Named,
// The Conjugate dispatch key is set for any tensors that need to perform
// conjugation
// This is implemented at a dispatch level right before any backends run
Conjugate,
// See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
// is to insert code after the "autograd subsystem" runs, so this key should
// be directly after ADInplaceOrView and all of the autograd keys.

View File

@ -938,6 +938,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
*/
const at::Tensor& grad() const;
/**
* Whether or not the imaginary part of the tensor should be negated
*/
inline bool is_conj() const {
return key_set_.has(DispatchKey::Conjugate);
}
/**
* Set whether or not to take the conjugate of the tensor (flip the imaginary
* bit).
*/
void _set_conj(bool value) {
if (value) {
key_set_ = key_set_.add(DispatchKey::Conjugate);
TORCH_INTERNAL_ASSERT(isComplexType(typeMetaToScalarType(dtype())));
} else {
key_set_ = key_set_.remove(DispatchKey::Conjugate);
}
}
/**
* Return the accumulated gradient of a tensor. This gradient is computed
* using forward mode AD.

View File

@ -64,7 +64,6 @@ For reference, here's a full list of view ops in PyTorch:
- :attr:`~torch.Tensor.real`
- :attr:`~torch.Tensor.imag`
- :meth:`~torch.Tensor.view_as_real`
- :meth:`~torch.Tensor.view_as_imag`
- :meth:`~torch.Tensor.unflatten`
- :meth:`~torch.Tensor.unfold`
- :meth:`~torch.Tensor.unsqueeze`

View File

@ -276,6 +276,9 @@ Tensor class reference
Tensor.contiguous
Tensor.copy_
Tensor.conj
Tensor.conj_physical
Tensor.conj_physical_
Tensor.resolve_conj
Tensor.copysign
Tensor.copysign_
Tensor.cos
@ -410,6 +413,7 @@ Tensor class reference
Tensor.isnan
Tensor.is_contiguous
Tensor.is_complex
Tensor.is_conj
Tensor.is_floating_point
Tensor.is_inference
Tensor.is_leaf

View File

@ -19,6 +19,7 @@ Tensors
is_tensor
is_storage
is_complex
is_conj
is_floating_point
is_nonzero
set_default_dtype
@ -86,6 +87,7 @@ Indexing, Slicing, Joining, Mutating Ops
:nosignatures:
cat
conj
chunk
dsplit
column_stack
@ -293,7 +295,7 @@ Pointwise Ops
ceil
clamp
clip
conj
conj_physical
copysign
cos
cosh
@ -511,6 +513,7 @@ Other Operations
vander
view_as_real
view_as_complex
resolve_conj
BLAS and LAPACK Operations

View File

@ -91,6 +91,9 @@ allow_list = [
("aten::linalg_vector_norm", datetime.date(2021, 5, 15)),
("aten::repeat_interleave", datetime.date(2021, 6, 26)),
("aten::one_hot", datetime.date(2021, 6, 15)),
("aten::conj", datetime.date(2021, 8, 1)),
("aten::_conj", datetime.date(2021, 8, 1)),
("aten::conj.out", datetime.date(2021, 8, 1)),
]
def allow_listed(schema, allow_list):

View File

@ -4279,7 +4279,7 @@ class TestAutograd(TestCase):
gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
def basic_mul(x):
return torch.view_as_real(x * 1j)
return torch.view_as_real(torch.resolve_conj(x * 1j))
gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)
# Test for one input and one output being complex
@ -5459,7 +5459,7 @@ complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone'
'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
'chunk', 'split', 'split_with_sizes', 'zero_',
'__radd__', 'mul', '__rmul__', 'diagonal', 'fill_', 'sub', 'narrow',
'swapaxes', 'swapdims', 'tensor_split'] + separate_complex_tests
'swapaxes', 'swapdims', 'tensor_split', 'select', 'clone'] + separate_complex_tests
# deny list for batched grad computation
EXCLUDE_BATCHED_GRAD_TESTS = set([
@ -5591,7 +5591,9 @@ def add_test(
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
if not isinstance(output_variable, tuple):
output_variable = (output_variable,)
inplace_self_variable = deepcopy(self_variable)
self.assertEqual(inplace_self_variable, self_variable)
inplace_self_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
for i in (inplace_self_variable,))
inplace_args_variable = deepcopy(args_variable)

View File

@ -1946,6 +1946,7 @@ known_failures = [
# If your OpInfo test causes this test to fail, add it here
skip_ops = [
'conj'
]
def get_name(op):

View File

@ -17,7 +17,7 @@ from torch.testing._internal.common_jit import JitCommonTestCase, check_against_
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
check_alias_annotation
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
from collections.abc import Sequence
# Get names of all the operators which have entry in `method_tests` (legacy testing infra)
method_tested_operators = set(map(lambda test_details: test_details[0], method_tests()))
@ -110,7 +110,9 @@ class TestGradients(TestCase):
return variant.__wrapped__ is op.get_inplace()
return variant is op.get_inplace()
samples = op.sample_inputs(device, dtype, requires_grad=True)
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs)
for sample in samples:
if sample.broadcasts_input and is_inplace(variant):
continue
@ -280,7 +282,8 @@ class TestCommon(JitCommonTestCase):
_requires_grad = (op.supports_autograd and
(dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type)))
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
def _test_consistency_helper(samples, variants):
for sample in samples:
@ -381,7 +384,8 @@ class TestCommon(JitCommonTestCase):
_requires_grad = op.supports_autograd and (dtype.is_floating_point or
op.supports_complex_autograd(torch.device(device).type))
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
for sample in samples:
# Acquires variants to test
@ -734,6 +738,103 @@ class TestCommon(JitCommonTestCase):
with self.assertRaises(RuntimeError, msg=msg_fail):
op_out(out=out)
# Tests that
# 1. The operator's output for physically conjugated tensors and conjugate view tensors
# produces the same value
# 2. The gradients are same in both cases mentioned in (1)
# 3. If the operator's inplace variant is supported, tests that the inplace operation
# produces the correct value when called on a conjugate view tensor and that the output
# has its conj bit set to true
# This test only runs for C -> R and C -> C functions
# TODO: add tests for `R->C` functions
# Note: This test runs for functions that take both tensors and tensorlists as input.
@ops(op_db, allowed_dtypes=(torch.cfloat,))
def test_conj_view(self, device, dtype, op):
if not op.test_conjugated_samples:
self.skipTest("Operation doesn't support conjugated inputs.")
_requires_grad = (op.supports_autograd and op.supports_complex_autograd(torch.device(device).type))
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
inplace_variant = op.inplace_variant
# helper function to physically conjugate the tensor
def conjugate_physical(input):
if isinstance(input, torch.Tensor):
tensor_requires_grad = input.requires_grad
with torch.no_grad():
input = input.conj_physical()
return input.requires_grad_(tensor_requires_grad)
if isinstance(input, Sequence):
out = list(map(clone_input_helper, input))
out[0] = conjugate_physical(out[0])
return tuple(out)
# helper function to clone and conjugate the input if its a tensor
# else clone the sequence and conjugate the first element in the sequence
# If a requires_grad argument is provided the tensor being conjugated will
# have its requires_grad set to that value.
def clone_conj_input_helper(input, **kwargs):
if isinstance(input, torch.Tensor):
requires_grad = kwargs.get('requires_grad', input.requires_grad)
with torch.no_grad():
input = input.clone()
# Note: .conj() is not called under no_grad mode since it's not allowed to modify a
# view created in no_grad mode. Here it's ok to do so, so as a workaround we call conj
# before resetting the requires_grad field for input
input = input.conj()
assert input.is_leaf
return input.requires_grad_(requires_grad)
if isinstance(input, Sequence):
out = list(map(clone_input_helper, input))
out[0] = clone_conj_input_helper(out[0])
return tuple(out)
for sample in samples:
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
cloned1 = clone_conj_input_helper(sample.input)
sample.input = conjugate_physical(sample.input)
# Computes function forward value with a physically conjugated tensor and
# a conj view tensor and verifies that the output in both case are equal.
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
forward_with_conjview = op(cloned1, *sample.args, **sample.kwargs)
self.assertEqual(expected_forward, forward_with_conjview)
# If the op has an inplace variant, and the input doesn't require broadcasting
# and has the same dtype as output, verify that the inplace operation on a conjugated
# input produces correct output, and the output tensor has the conj bit set to True
if inplace_variant is not None and not sample.broadcasts_input:
cloned2 = clone_conj_input_helper(tensor, requires_grad=False)
if (isinstance(expected_forward, torch.Tensor) and
expected_forward.dtype is tensor.dtype):
inplace_forward = inplace_variant(cloned2, *sample.args, **sample.kwargs)
self.assertTrue(inplace_forward.is_conj())
self.assertEqual(inplace_forward, expected_forward)
# TODO: backward consistency only supported for single tensor outputs
# TODO: backward consistency only checked on sample.input, not all
# tensor inputs
# TODO: update to handle checking grads of all tensor inputs as
# derived from each tensor output
if isinstance(expected_forward, torch.Tensor) and expected_forward.requires_grad:
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
expected_forward.sum().backward(retain_graph=True)
forward_with_conjview.sum().backward(retain_graph=True)
if tensor.grad is not None:
cloned1_tensor = cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
self.assertEqual(tensor.grad, cloned1_tensor.grad)
tensor.grad, cloned1_tensor.grad = None, None
# a repeat of the above test with an explicit grad output, run only when the output is complex valued
if (expected_forward.is_complex()):
grad = torch.randn_like(expected_forward)
expected_forward.backward(grad.conj_physical())
forward_with_conjview.backward(grad.conj())
self.assertEqual(tensor.grad, cloned1_tensor.grad)
instantiate_device_type_tests(TestOpInfo, globals())
instantiate_device_type_tests(TestGradients, globals())

View File

@ -346,6 +346,14 @@ class TestViewOps(TestCase):
self.assertEqual(a[5:].real, a.real[5:])
self.assertEqual(a[5:].imag, a.imag[5:])
@onlyOnCPUAndCUDA
@dtypes(*torch.testing.get_all_complex_dtypes())
def test_conj_view(self, device, dtype) -> None:
t = _make_tensor((4, 5,), dtype, device)
v = t.conj()
self.assertTrue(self.is_view_of(t, v))
self.assertEqual(v, torch.from_numpy(t.cpu().numpy().conj()).to(device=device))
@onlyOnCPUAndCUDA
@dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
@suppress_warnings

View File

@ -380,16 +380,23 @@
- name: complex(Tensor real, Tensor imag) -> Tensor
real: at::real(grad)
imag: at::imag(grad)
imag: at::imag(grad.resolve_conj())
result: at::complex(real_t, imag_t)
- name: polar(Tensor abs, Tensor angle) -> Tensor
abs, angle: polar_backward(grad, result)
- name: _conj(Tensor self) -> Tensor
- name: _conj(Tensor(a) self) -> Tensor(a)
self: grad.conj()
result: self_t.conj()
- name: _conj_physical(Tensor self) -> Tensor
self: grad.conj_physical()
result: self_t.conj_physical()
- name: conj_physical_(Tensor(a!) self) -> Tensor(a!)
self: grad.conj_physical()
- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
self: copysign_tensor_self_backward(grad, self, result)
other: zeros_like(other)
@ -1306,12 +1313,16 @@
- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
output_differentiability: [False]
# If the conj bit for self is set, this is effectively a fused conj() and view_as_real
- name: _view_as_real_physical(Tensor(a) self) -> Tensor(a)
self: "self.is_conj() ? at::view_as_complex(grad.contiguous()) : at::view_as_complex(grad.contiguous()).conj()"
- name: view_as_real(Tensor(a) self) -> Tensor(a)
self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1
result: at::view_as_real(self_t)
- name: view_as_complex(Tensor(a) self) -> Tensor(a)
self: at::view_as_real(grad.contiguous()) # [gx, gy]
self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy]
result: at::view_as_complex(self_t)
- name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor

View File

@ -28,7 +28,7 @@ from .gen_trace_type import (
#
# A map: function name => name of the argument that all outputs are view of
VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_real', 'view_as_complex']
VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_complex', '_view_as_real_physical', 'view_as_real', '_conj']
VIEW_FUNCTIONS = {
'numpy_T': 'self',

View File

@ -100,7 +100,8 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm',
'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_view_as_real_physical', '_conj_physical',
'conj_physical_'
}
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {

View File

@ -10,7 +10,7 @@ import yaml
from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
SavedAttribute, ForwardDerivative)
from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, BaseCType, VectorCType,
intArrayRefT, tensorOptionsT, typeAndSizeT, intT,
intArrayRefT, tensorOptionsT, typeAndSizeT, intT, boolT,
tensorGeometryT, scalarTypeT, SpecialArgName,
OptionalCType, stringT)
from tools.codegen.api import cpp
@ -498,6 +498,11 @@ def saved_variables(
'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
'expr': stride_expr,
}),
# replace self.is_conj() with self_conjugate
(r'{}.is_conj\(\)', {
'suffix': '_conjugate',
'nctype': lambda name: NamedCType(name, BaseCType(boolT)),
})
]
# find which arguments need to be saved

View File

@ -86,8 +86,10 @@ class Tensor(torch._C._TensorBase):
self.requires_grad,
self._backward_hooks)
else:
new_tensor = self.new()
new_tensor = self.new_empty([])
new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
if self.is_conj():
new_tensor = new_tensor.conj_physical()
new_tensor.requires_grad = self.requires_grad
if self.grad is not None:
new_tensor.grad = self.grad.__deepcopy__(memo)

View File

@ -893,6 +893,27 @@ conj() -> Tensor
See :func:`torch.conj`
""")
add_docstr_all('conj_physical',
r"""
conj_physical() -> Tensor
See :func:`torch.conj_physical`
""")
add_docstr_all('conj_physical_',
r"""
conj_physical_() -> Tensor
In-place version of :meth:`~Tensor.conj_physical`
""")
add_docstr_all('resolve_conj',
r"""
resolve_conj() -> Tensor
See :func:`torch.resolve_conj`
""")
add_docstr_all('copysign',
r"""
copysign(other) -> Tensor
@ -1977,6 +1998,13 @@ is_inference() -> bool
See :func:`torch.is_inference`
""")
add_docstr_all('is_conj',
r"""
is_conj() -> bool
Returns True if the conjugate bit of :attr:`self` is set to true.
""")
add_docstr_all('is_signed',
r"""
is_signed() -> bool

View File

@ -153,11 +153,12 @@ class _Formatter(object):
def _scalar_str(self, formatter1, formatter2=None):
if formatter2 is not None:
real_str = _scalar_str(self.real, formatter1)
imag_str = _scalar_str(self.imag, formatter2) + "j"
if self.imag < 0:
return real_str + imag_str.lstrip()
imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
# handles negative numbers, +0.0, -0.0
if imag_str[0] == '+' or imag_str[0] == '-':
return real_str + imag_str
else:
return real_str + "+" + imag_str.lstrip()
return real_str + "+" + imag_str
else:
return formatter1.format(self.item())
@ -174,11 +175,12 @@ def _vector_str(self, indent, summarize, formatter1, formatter2=None):
def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
if formatter2 is not None:
real_str = formatter1.format(val.real)
imag_str = formatter2.format(val.imag) + "j"
if val.imag < 0:
return real_str + imag_str.lstrip()
imag_str = (formatter2.format(val.imag) + "j").lstrip()
# handles negative numbers, +0.0, -0.0
if imag_str[0] == '+' or imag_str[0] == '-':
return real_str + imag_str
else:
return real_str + "+" + imag_str.lstrip()
return real_str + "+" + imag_str
else:
return formatter1.format(val)
@ -235,6 +237,7 @@ def _tensor_str(self, indent):
self = self.float()
if self.dtype.is_complex:
self = self.resolve_conj()
real_formatter = _Formatter(get_summarized_data(self.real) if summarize else self.real)
imag_formatter = _Formatter(get_summarized_data(self.imag) if summarize else self.imag)
return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter)

View File

@ -2181,15 +2181,18 @@ Example::
tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128)
""")
add_docstr(torch.conj,
add_docstr(torch.conj_physical,
r"""
conj(input, *, out=None) -> Tensor
conj_physical(input, *, out=None) -> Tensor
Computes the element-wise conjugate of the given :attr:`input` tensor. If :attr:`input` has a non-complex dtype,
this function just returns :attr:`input`.
Computes the element-wise conjugate of the given :attr:`input` tensor.
If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`.
.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj`
.. note::
This performs the conjugate operation regardless of whether the conjugate bit is set on :attr:`input`.
.. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical`
when :attr:`input` is of non-complex dtype to be compatible with this change.
.. math::
@ -2203,10 +2206,63 @@ Keyword args:
Example::
>>> torch.conj(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))
>>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))
tensor([-1 - 1j, -2 - 2j, 3 + 3j])
""".format(**common_args))
add_docstr(torch.conj,
r"""
conj(input) -> Tensor
Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype,
this function just returns :attr:`input`.
.. note::
:func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized
at any time using :func:`torch.resolve_conj`.
.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj`
when :attr:`input` is of non-complex dtype to be compatible with this change.
Args:
{input}
Example::
>>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
>>> x.is_conj()
False
>>> y = torch.conj(x)
>>> y.is_conj()
True
""")
add_docstr(torch.resolve_conj,
r"""
resolve_conj(input) -> Tensor
Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`,
else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`.
Args:
{input}
Example::
>>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
>>> y = x.conj()
>>> y.is_conj()
True
>>> z = y.resolve_conj()
>>> z
tensor([-1 - 1j, -2 - 2j, 3 + 3j])
>>> z.is_conj()
False
""")
add_docstr(torch.copysign,
r"""
copysign(input, other, *, out=None) -> Tensor
@ -4201,6 +4257,15 @@ Args:
{input}
""".format(**common_args))
add_docstr(torch.is_conj, r"""
is_conj(input) -> (bool)
Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`.
Args:
{input}
""".format(**common_args))
add_docstr(torch.is_nonzero, r"""
is_nonzero(input) -> (bool)

View File

@ -545,7 +545,7 @@ def _get_analytical_vJu_backward_mode(inputs, outputs, nondet_tol, check_grad_dt
# the error checking logic from slow mode
vJ = vJ.T.squeeze(0)
if vJ.is_complex(): # C -> R
tv = torch.view_as_real(vJ)
tv = torch.view_as_real(vJ.resolve_conj())
tr = tv.select(-1, 0)
ti = tv.select(-1, 1)
jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
@ -894,7 +894,11 @@ def _real_and_imag_output(fn):
outs = _as_tuple(fn(*inputs))
return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)
return wrapped_fn
return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)
# TODO(@anjali411): remove this workaround once neg bit is added.
def torch_imag(x):
return x.resolve_conj().imag
return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch_imag)
def _real_and_imag_input(fn, complex_inp_indices):
# returns new functions that take real inputs instead of complex inputs and compute fn(x + 0 * 1j)
@ -935,7 +939,7 @@ def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, e
if complex_inp_indices:
real_fn, imag_fn = _real_and_imag_input(func, complex_inp_indices)
imag_inputs = [inp.imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs]
imag_inputs = [inp.resolve_conj().imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs]
imag_func_out = imag_fn(*imag_inputs)
diff_imag_func_out = _differentiable_outputs(imag_func_out)
gradcheck_fn(imag_fn, imag_func_out, imag_inputs, diff_imag_func_out, eps,

View File

@ -248,7 +248,7 @@ Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & expone
Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) {
Tensor cond;
if (exponent.is_complex()) {
auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
auto is_real_exp = at::logical_and(at::imag(exponent.resolve_conj()) == 0, at::real(exponent) >= 0);
cond = at::logical_and(self == 0, is_real_exp);
} else {
cond = at::logical_and(self == 0, exponent >= 0);
@ -264,7 +264,7 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp
if (base.equal(0.0)) {
auto cond = [](auto exp) {
if (exp.is_complex()) {
return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
return at::logical_and(at::imag(exp.resolve_conj()) == 0, at::real(exp) >= 0);
} else {
return exp >=0;
}
@ -2154,7 +2154,7 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
if (self.is_complex() && gu.defined()) {
Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1);
at::real(L).zero_();
at::imag(L).mul_(sigma_inv);
at::imag(L.resolve_conj()).mul_(sigma_inv);
Tensor imag_term = at::matmul(u * L.unsqueeze(-2), vh);
return u_term + sigma_term + v_term + imag_term;
}
@ -2195,7 +2195,7 @@ Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const T
}
// path for torch.linalg.eig with always a complex tensor of eigenvalues
else {
is_imag_eigvals_zero = (at::imag(D) == 0.0).min().item<bool>();
is_imag_eigvals_zero = (at::imag(D.resolve_conj()) == 0.0).min().item<bool>();
// insert an additional dimension to be compatible with torch.eig.
// Recall that it produces 2D tensors.
// We extract only the real parts as there is no support for

View File

@ -83,23 +83,6 @@ void nnc_aten_sgn(
} catch (...) {
}
}
void nnc_aten_conj(
int64_t bufs_num,
void** buf_data,
int64_t* buf_ranks,
int64_t* buf_dims,
int8_t* buf_dtypes,
int64_t args_num,
int64_t* extra_args) {
std::vector<at::Tensor> tensors =
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
at::Tensor& r = tensors[0];
const at::Tensor& self = tensors[1];
try {
at::conj_out(r, self);
} catch (...) {
}
}
void nnc_aten_acos(
int64_t bufs_num,
void** buf_data,
@ -2758,9 +2741,6 @@ const static RegisterNNCExternalFunction nnc_angle(
"nnc_aten_angle",
nnc_aten_angle);
const static RegisterNNCExternalFunction nnc_sgn("nnc_aten_sgn", nnc_aten_sgn);
const static RegisterNNCExternalFunction nnc_conj(
"nnc_aten_conj",
nnc_aten_conj);
const static RegisterNNCExternalFunction nnc_acos(
"nnc_aten_acos",
nnc_aten_acos);

View File

@ -228,6 +228,8 @@ def get_ignored_functions() -> Set[Callable]:
Tensor.to_sparse_csr,
Tensor._reduce_ex_internal,
Tensor._fix_weakref,
Tensor._conj,
Tensor._conj_physical,
}
@ -349,6 +351,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.polar: lambda abs, ang: -1,
torch.linalg.cond: lambda input, ord=None: -1,
torch.conj: lambda input, out=None: -1,
torch.conj_physical: lambda input, out=None: -1,
torch.resolve_conj: lambda input, out=None: -1,
torch.constant_pad_nd: lambda input, pad, value=0: -1,
torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
@ -500,6 +504,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.linalg.inv: lambda input, out=None: -1,
torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
torch.is_complex: lambda input: -1,
torch.is_conj: lambda input: -1,
torch.is_distributed: lambda input: -1,
torch.is_inference: lambda input: -1,
torch.is_floating_point: lambda input: -1,
@ -797,6 +802,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.real: lambda input, out=None: -1,
torch.vdot: lambda input, other, out=None: -1,
torch.view_as_real: lambda input: -1,
torch._view_as_real_physical: lambda input: -1,
torch.view_as_complex: lambda input: -1,
torch.reciprocal: lambda input, out=None: -1,
torch.relu: lambda input, inplace=False: -1,

View File

@ -134,6 +134,8 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e
# Compares complex tensors' real and imaginary parts separately.
# (see NOTE Test Framework Tensor "Equality")
if a.is_complex():
a = a.resolve_conj()
b = b.resolve_conj()
if equal_nan == "relaxed":
a = a.clone()
b = b.clone()

View File

@ -10,7 +10,6 @@ import random
import torch
import numpy as np
from torch._six import inf
from torch.autograd import Variable
import collections.abc
from typing import List, Sequence, Tuple, Dict, Any, Union
@ -231,6 +230,7 @@ class OpInfo(object):
# function around gradcheck (testing._internal.common_utils.gradcheck)
inplace_variant=_NOTHING, # explicitly pass the inplace variant of the operator if required
method_variant=_NOTHING, # explicitly pass the method variant of the operator if required
test_conjugated_samples=True,
):
# Validates the dtypes are generated from the dispatch-related functions
@ -300,6 +300,8 @@ class OpInfo(object):
if aliases is not None:
self.aliases = tuple(AliasInfo(a) for a in aliases) # type: ignore[assignment]
self.test_conjugated_samples = test_conjugated_samples
def __call__(self, *args, **kwargs):
"""Calls the function variant of the operator."""
return self.op(*args, **kwargs)
@ -326,6 +328,37 @@ class OpInfo(object):
"""
return self.operator_variant
def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
"""Returns an iterable of SampleInputs but with the tensor input or first
tensor in a sequence input conjugated.
"""
# TODO: Remove the try/except once all operators have sample_inputs_func with
# **kwargs in their signature.
try:
samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
except TypeError:
samples = self.sample_inputs_func(self, device, dtype, requires_grad)
conj_samples = list(samples)
def conjugate(tensor):
_requires_grad = tensor.requires_grad
with torch.no_grad():
tensor = tensor.conj()
return tensor.requires_grad_(_requires_grad)
for i in range(len(samples)):
sample = conj_samples[i]
# Note: it is assumed that the input here is either a tensor or tensorlist
if isinstance(sample.input, torch.Tensor):
sample.input = conjugate(sample.input)
else:
with torch.no_grad():
sample.input[0] = conjugate(sample.input[0])
return tuple(conj_samples)
def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
"""Returns an iterable of SampleInputs.
@ -339,6 +372,13 @@ class OpInfo(object):
samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
except TypeError:
samples = self.sample_inputs_func(self, device, dtype, requires_grad)
if 'include_conjugated_inputs' in kwargs and kwargs.get('include_conjugated_inputs'):
conj_samples = self.conjugate_sample_inputs(device, dtype, requires_grad, **kwargs)
samples_list = list(samples)
samples_list.extend(conj_samples)
samples = tuple(samples_list)
return samples
# Returns True if the test should be skipped and False otherwise
@ -1001,21 +1041,21 @@ def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs):
return tuple(sample_inputs)
def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
test_cases = [((S, S), (S, S), (S, S)),
((S, S), (S, 1), (1, S)),
((1,), (S, S, 1), (1, S)),
((), (), ()),
((S, S), (), ()),
((), (S, S, 1), (1, S)),
test_cases = [(((S, S), (S, S), (S, S)), False),
(((S, S), (S, 1), (1, S)), False),
(((1,), (S, S, 1), (1, S)), True),
(((), (), ()), False),
(((S, S), (), ()), True),
(((), (S, S, 1), (1, S)), True)
]
sample_inputs = []
for input_args in test_cases:
for input_args, broadcasts_input in test_cases:
args = tuple(make_tensor(arg, device, dtype, requires_grad=requires_grad) if isinstance(arg, tuple) else arg
for arg in input_args)
sample_inputs.append(SampleInput(args[0], args=args[1:]))
sample_inputs.append(SampleInput(args[0], args=args[1:], broadcasts_input=broadcasts_input))
sample_inputs.append(SampleInput(args[0], args=args[1:], kwargs=dict(value=3.14)))
sample_inputs.append(SampleInput(args[0], args=args[1:], kwargs=dict(value=3.14), broadcasts_input=broadcasts_input))
return tuple(sample_inputs)
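# Why broadcasts_input matters (sketch): when the first argument is itself
# broadcast by the op, the in-place variant must fail because self cannot
# change shape:
#   t = torch.zeros(())
#   t.addcmul_(torch.ones(2, 2), torch.ones(2, 2))  # RuntimeError expected
# Flagging such samples lets the test suite assert that failure instead of
# treating it as a bug.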
@@ -1057,7 +1097,8 @@ def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
args=(
make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)))
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
broadcasts_input=True)
if dtype.is_complex:
alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j
@@ -1078,7 +1119,8 @@ def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
args=(
make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
kwargs=dict(beta=beta, alpha=alpha))
kwargs=dict(beta=beta, alpha=alpha),
broadcasts_input=True)
return (input1, input2, input3, input4)
@@ -2196,6 +2238,32 @@ def np_unary_ufunc_integer_promotion_wrapper(fn):
return wrapped_fn
def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
nd_tensor = make_tensor((S, S + 1, S + 2), device, dtype, low=None, high=None,
requires_grad=requires_grad)
tensor = make_tensor((31,), device, dtype, low=None, high=None,
requires_grad=requires_grad)
if self.ndimensional:
return [
SampleInput(nd_tensor, kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(s=(8,))),
SampleInput(tensor),
*(SampleInput(nd_tensor, kwargs=dict(dim=dim))
for dim in [-1, -2, -3, (0, -1)]),
]
else:
return [
SampleInput(nd_tensor, kwargs=dict(n=10, dim=1, norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(n=7)),
SampleInput(tensor),
*(SampleInput(nd_tensor, kwargs=dict(dim=dim))
for dim in [-1, -2, -3]),
]
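# Sketch of the calls these samples exercise: n-dimensional transforms accept
# s/dim/norm, e.g. torch.fft.fftn(nd_tensor, s=(3, 10), dim=(1, 2), norm='ortho'),
# while one-dimensional transforms accept n/dim/norm, e.g.
# torch.fft.fft(nd_tensor, n=10, dim=1, norm='ortho').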
# Metadata class for Fast Fourier Transforms in torch.fft.
class SpectralFuncInfo(OpInfo):
@@ -2207,6 +2275,7 @@ class SpectralFuncInfo(OpInfo):
ref=None, # Reference implementation (probably in np.fft namespace)
dtypes=floating_and_complex_types(),
ndimensional: bool, # Whether dim argument can be a tuple
sample_inputs_func=sample_inputs_spectral_ops,
decorators=None,
**kwargs):
decorators = list(decorators) if decorators is not None else []
@@ -2220,39 +2289,12 @@ class SpectralFuncInfo(OpInfo):
super().__init__(name=name,
dtypes=dtypes,
decorators=decorators,
sample_inputs_func=sample_inputs_func,
**kwargs)
self.ref = ref if ref is not None else _getattr_qual(np, name)
self.ndimensional = ndimensional
def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
nd_tensor = make_tensor((S, S + 1, S + 2), device, dtype, low=None, high=None,
requires_grad=requires_grad)
tensor = make_tensor((31,), device, dtype, low=None, high=None,
requires_grad=requires_grad)
if self.ndimensional:
return [
SampleInput(nd_tensor, kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(s=(8,))),
SampleInput(tensor),
*(SampleInput(nd_tensor, kwargs=dict(dim=dim))
for dim in [-1, -2, -3, (0, -1)]),
]
else:
return [
SampleInput(nd_tensor, kwargs=dict(n=10, dim=1, norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
SampleInput(nd_tensor, kwargs=dict(n=7)),
SampleInput(tensor),
*(SampleInput(nd_tensor, kwargs=dict(dim=dim))
for dim in [-1, -2, -3]),
]
class ShapeFuncInfo(OpInfo):
"""Early version of a specialized OpInfo for Shape manipulating operations like tile and roll"""
def __init__(self,
@@ -2737,13 +2779,13 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
test_cases = (
((2, 2), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, False),
((2, 2), 0, 5, 1e-3, requires_grad, (1,), 0, 1, 0.1, requires_grad, False),
((), 1e-3, 1e-3 + 1, 0, True, (), 0.1, 1.1, 0, False, False),
((), 1e-3, 1e-3 + 1, 0, requires_grad, (), 0.1, 1.1, 0, False, False),
((2, 2), 0, 5, 1e-3, requires_grad, (), 0.1, 1.1, 1, False, False),
)
tests_require_resizing = (
((1,), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, True),
((2, 1, 2), 0, 5, 1e-3, requires_grad, (1, 2, 1), 0, 1, 0.1, requires_grad, True),
((), 1e-3, 1e-3 + 1, 0, True, (1, S, 1), 0, 1, 0.1, requires_grad, True),
((1,), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, requires_grad),
((2, 1, 2), 0, 5, 1e-3, requires_grad, (1, 2, 1), 0, 1, 0.1, requires_grad, requires_grad),
((), 1e-3, 1e-3 + 1, 0, requires_grad, (1, S, 1), 0, 1, 0.1, requires_grad, requires_grad),
)
cases = test_cases + tests_require_resizing
samples = list(SampleInput(make_tensor(shape_b, low=low_b, high=high_b,
@@ -2757,7 +2799,7 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
high_e, additive_e, e_grad, broadcasts_input in cases)
tensor_scalar_inputs = (
((2, 2), 0, 5, 1e-3, requires_grad, (3.14,)),
((), 1e-3, 1e-3 + 1, 0, True, (3.14,))
((), 1e-3, 1e-3 + 1, 0, requires_grad, (3.14,))
)
more_samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device,
high=high, low=low,
@@ -2768,8 +2810,8 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
elif dtype in [torch.complex64, torch.complex128]:
args_tuple = (
((2, 2), 0, 5, requires_grad, (3.14,)),
((), 0, 1, True, (3.14,)),
((), 0, 1, True, (3.14j,))
((), 0, 1, requires_grad, (3.14,)),
((), 0, 1, requires_grad, (3.14j,))
)
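# Note on the True -> requires_grad changes above (sketch): samples should
# honor the caller's requires_grad flag rather than hard-coding it, so e.g.
#   sample_inputs_pow(op_info, 'cpu', torch.float64, requires_grad=False)
# is expected to return no tensors with requires_grad=True.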
samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device,
high=high, low=low,
@@ -4298,7 +4340,7 @@ op_db: List[OpInfo] = [
SkipInfo('TestGradients', 'test_method_grad',
device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
SkipInfo('TestGradients', 'test_forward_mode_AD',
dtypes=[torch.cdouble], active_if=IS_WINDOWS),
dtypes=[torch.cdouble]),
)),
OpInfo('add',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
@@ -4670,28 +4712,30 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
),
UnaryUfuncInfo('conj',
ref=np.conj,
dtypes=all_types_and_complex_and(torch.bool,
torch.bfloat16, torch.half),
supports_forward_ad=True,
supports_out=False),
UnaryUfuncInfo('conj_physical',
ref=np.conj,
dtypes=all_types_and_complex_and(torch.bool,
torch.bfloat16, torch.half),
supports_forward_ad=True,
skips=(
# File "test_unary_ufuncs.py", line 289, in test_reference_numerics
# if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype):
# KeyError: <class 'numpy.intc'>
# Following error in Windows CI
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
dtypes=[torch.int],
active_if=IS_WINDOWS),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
dtypes=[torch.int],
active_if=IS_WINDOWS),
# TODO fix the formula for complex forward AD
SkipInfo('TestGradients', 'test_forward_mode_AD'),
SkipInfo('TestCommon', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
)),
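# Semantics under test (sketch): for complex dtypes conj() returns a lazy view
# with the conj bit set, while conj_physical() writes conjugated values to new
# memory with the bit unset:
#   t = torch.randn(3, dtype=torch.cfloat)
#   assert t.conj().is_conj() and not t.conj_physical().is_conj()
#   assert torch.equal(t.conj().resolve_conj(), t.conj_physical())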
OpInfo('resolve_conj',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_view_as_real,
supports_forward_ad=True,
supports_out=False,
),
OpInfo('view_as_real',
dtypes=complex_types(),
supports_forward_ad=True,
sample_inputs_func=sample_inputs_view_as_real,
test_conjugated_samples=False,
),
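# Why test_conjugated_samples=False here (sketch): view_as_real rejects
# conjugate views, so the conjugated-sample pass would fail by construction:
#   torch.view_as_real(torch.randn(3, dtype=torch.cfloat).conj())  # raises RuntimeError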
OpInfo('view_as_complex',
dtypes=floating_types_and(torch.half),
@@ -5120,7 +5164,9 @@ op_db: List[OpInfo] = [
ref=np.imag,
dtypes=complex_types(),
supports_out=False,
supports_autograd=False,
supports_forward_ad=True,
# TODO(@anjali411): Test this once neg bit is added.
test_conjugated_samples=False,
skips=(
# Skip since real and imag don't have out variants.
SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
@@ -5251,8 +5297,6 @@ op_db: List[OpInfo] = [
supports_autograd=False,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
skips=(
# skip because `linalg_lstsq` is not differentiable
SkipInfo('TestGradients', 'test_fn_grad'),
SkipInfo('TestCommon', 'test_variant_consistency_jit'),
)),
OpInfo('linalg.matrix_power',
@@ -5471,7 +5515,8 @@ op_db: List[OpInfo] = [
# "RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when
# calling cublasGemmStridedBatchedExFix."
SkipInfo('TestOpInfo', 'test_supported_backward',
device_type='cuda', dtypes=(torch.bfloat16,)),)),
device_type='cuda', dtypes=(torch.bfloat16,)),
SkipInfo('TestCommon', 'test_conj_view', device_type='cpu'),),),
OpInfo('max',
op=torch.max,
variant_test_name='binary',
@@ -5699,7 +5744,9 @@ op_db: List[OpInfo] = [
assert_autodiffed=True),
OpInfo('float_power',
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
sample_inputs_func=sample_inputs_pow),
sample_inputs_func=sample_inputs_pow,
skips=(
SkipInfo('TestCommon', 'test_conj_view', device_type='cuda'),),),
OpInfo('prod',
dtypes=all_types_and_complex_and(torch.bool),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
@@ -5739,7 +5786,6 @@ op_db: List[OpInfo] = [
ref=np.real,
dtypes=complex_types(),
supports_out=False,
supports_autograd=False,
skips=(
# Skip since real and imag don't have out variants.
SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
@@ -7068,23 +7114,26 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
def maybe_non_contig(tensor):
return tensor if not non_contiguous else make_non_contiguous(tensor)
def conjugate(tensor):
return tensor.conj()
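# conj() returns self for non-complex dtypes, so unconditionally conjugating
# every generated input only changes behavior for complex-dtype test runs.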
if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
return arg
elif isinstance(arg, tuple) and len(arg) == 0:
var = torch.randn((), dtype=dtype, device=device)
var = conjugate(torch.randn((), dtype=dtype, device=device))
var.requires_grad = requires_grad
return var
elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
return Variable(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device)), requires_grad=requires_grad)
return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)
# double check casting
elif isinstance(arg, non_differentiable):
if isinstance(arg.tensor, torch.Tensor):
if arg.tensor.dtype == torch.float:
return maybe_non_contig(arg.tensor.to(dtype=torch.double, device=device))
if arg.tensor.dtype == torch.cfloat:
return maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device))
return maybe_non_contig(arg.tensor.to(device=device))
return maybe_non_contig(arg.tensor.to(device=device))
return conjugate(maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device)))
return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
elif isinstance(arg, torch.Tensor):
if arg.dtype == torch.float:
arg = arg.double()
@@ -7094,7 +7143,7 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ",
"which is not supported for now")
# NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
v = maybe_non_contig(arg).detach().to(device=device).clone()
v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
return v
elif callable(arg):