diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index d6d110b55ad..82c59560421 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -1031,7 +1031,6 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("sum.dim_IntList", sum_batching_rule);
   m.impl("is_complex", native::is_complex);
-  m.impl("conj", native::conj);
 
   // inplace operations
   m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
@@ -1085,7 +1084,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   UNARY_POINTWISE(ceil);
   UNARY_POINTWISE(cos);
   UNARY_POINTWISE(cosh);
-  UNARY_POINTWISE(_conj);
+  UNARY_POINTWISE(conj_physical);
   UNARY_POINTWISE(digamma);
   UNARY_POINTWISE(exp);
   UNARY_POINTWISE(expm1);
@@ -1181,6 +1180,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   TRIVIAL_OP(imag)
   TRIVIAL_OP(real);
   TRIVIAL_OP(view_as_real);
+  TRIVIAL_OP(_view_as_real_physical);
+  TRIVIAL_OP(conj);
+  TRIVIAL_OP(_conj);
+  TRIVIAL_OP(resolve_conj);
   m.impl("view_as_complex", view_as_complex_batching_rule);
 #undef TRIVIAL
diff --git a/aten/src/ATen/ConjugateFallback.cpp b/aten/src/ATen/ConjugateFallback.cpp
new file mode 100644
index 00000000000..e6d96fb133e
--- /dev/null
+++ b/aten/src/ATen/ConjugateFallback.cpp
@@ -0,0 +1,152 @@
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/native/UnaryOps.h>
+#include <ATen/NativeFunctions.h>
+#include <c10/util/irange.h>
+#include <torch/library.h>
+
+namespace at {
+
+void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+  // Situations to handle:
+  //  1. Out-of-place operation.  Easy: materialize all inputs and
+  //     call it a day.
+  //  2. Inplace operation.  Desugar x.add_(2) into x.conj_().add_(2).conj_().
+  //     Materialize other inputs as in (1).
+  //  3. out= operation.  Desugar add(x, 2, out=y) into y.copy_(add(x, 2)).
+  //     Materialize other inputs as in (1).
+  //
+  // It is important to be able to tell if we READ from an argument and if we
+  // WRITE to an argument. The conservative approach is to assume that we always
+  // READ from an argument, but in out-of-place operations you could skip
+  // conjugating inputs on entry that never get used. With the current schema we
+  // can't easily tell whether we are in the in-place situation, so we don't do it.
+
+  const auto& arguments = op.schema().arguments();
+  const auto num_arguments = arguments.size();
+  const auto stack_start = stack->size() - num_arguments;
+
+  c10::optional<bool> is_write;
+  for (int64_t i = 0; i < num_arguments; ++i) {
+    const auto& alias_info = arguments[i].alias_info();
+    // Three possible states:
+    // 1. alias_info has no value --> out-of-place operation
+    // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
+    // 3. alias_info does have a value, alias_info->is_write=False --> view operation
+    if (alias_info.has_value()) {
+      if (is_write.has_value()) {
+        TORCH_CHECK(*is_write == alias_info->isWrite(),
+          "Unsupported operator for conjugate fallback: ", op.schema().name(),
+          "; the conjugate fallback doesn't work for operators with a mix of "
+          "mutable and non-mutable inputs that alias with outputs; "
+          "this must be implemented manually. "
+          "If you got this error on a core op, please report a bug to PyTorch.");
+      } else {
+        is_write = alias_info->isWrite();
+      }
+    }
+  }
+
+  if (is_write.has_value() && !*is_write) {
+    // We assume that view operators automatically handle conjugation
+    // correctly by propagating the Conjugate dispatch key in key_set.
+    // This is not necessarily always right, so you should test these cases.
+    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
+    return;
+  }
+
+  // Mutable inputs to be tracked separately
+  std::vector<Tensor> mutable_inputs;
+
+  for (int64_t i = 0; i < num_arguments; ++i) {
+    auto& ivalue = (*stack)[stack_start + i];
+    if (!(ivalue.isTensor() || ivalue.isTensorList())) {
+      continue;
+    }
+    const auto& argument = arguments[i];
+    bool mut_arg = false;
+    if (argument.alias_info()) {
+      // View operations were already filtered above, so only in-place/out= operations should get here.
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
+      mut_arg = true;
+    }
+    if (ivalue.isTensor()) {
+      auto* impl = ivalue.unsafeToTensorImpl();
+      if (!impl->is_conj()) {
+        continue;
+      }
+
+      auto tensor = std::move(ivalue).toTensor();
+      TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), "Conjugate Fallback does not support meta tensors.");
+      if (mut_arg) {
+        // TODO: This is a waste if the argument is write only
+        tensor._set_conj(false);
+        at::conj_physical_(tensor);
+        mutable_inputs.emplace_back(tensor);
+      } else {
+        tensor = at::resolve_conj(tensor);
+      }
+      (*stack)[stack_start + i] = std::move(tensor);
+    } else if (ivalue.isTensorList()) {
+      auto tensors = std::move(ivalue).toTensorList();
+      if (mut_arg) {
+        for (const auto j : c10::irange(tensors.size())) {
+          Tensor t = tensors[j];
+          t._set_conj(false);
+          at::conj_physical_(t);
+          mutable_inputs.emplace_back(t);
+        }
+      } else {
+        for (const auto j : c10::irange(tensors.size())) {
+          tensors[j] = at::resolve_conj(tensors[j]);
+        }
+      }
+      (*stack)[stack_start + i] = std::move(tensors);
+    }
+  }
+
+  op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
+
+  for (auto& mutable_input : mutable_inputs) {
+    at::conj_physical_(mutable_input);
+    mutable_input._set_conj(true);
+  }
+}
+
+TORCH_LIBRARY_IMPL(_, Conjugate, m) {
+  m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
+}
+
+TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
+  m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
+  m.impl("set_.source_Storage_storage_offset", torch::CppFunction::makeFallthrough());
+  m.impl("set_.source_Tensor", torch::CppFunction::makeFallthrough());
+  m.impl("set_", torch::CppFunction::makeFallthrough());
+  m.impl("copy_", torch::CppFunction::makeFallthrough());
+  m.impl("clone", torch::CppFunction::makeFallthrough());
+  m.impl("conj", torch::CppFunction::makeFallthrough());
+  m.impl("_conj", torch::CppFunction::makeFallthrough());
+  m.impl("_conj_physical", torch::CppFunction::makeFallthrough());
+  m.impl("conj_physical", torch::CppFunction::makeFallthrough());
+  m.impl("conj_physical_", torch::CppFunction::makeFallthrough());
+  m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
+  m.impl("empty_like", torch::CppFunction::makeFallthrough());
+  m.impl("empty.memory_format", torch::CppFunction::makeFallthrough());
+  m.impl("empty.out", torch::CppFunction::makeFallthrough());
+  m.impl("empty_strided", torch::CppFunction::makeFallthrough());
+  m.impl("full_like", torch::CppFunction::makeFallthrough());
+  m.impl("stride.int", torch::CppFunction::makeFallthrough());
+  m.impl("stride.Dimname", torch::CppFunction::makeFallthrough());
+  m.impl("size.int", torch::CppFunction::makeFallthrough());
+  m.impl("size.Dimname", torch::CppFunction::makeFallthrough());
+  m.impl("is_complex", torch::CppFunction::makeFallthrough());
+  m.impl("_view_as_real_physical", torch::CppFunction::makeFallthrough());
+  m.impl("view_as_real", torch::CppFunction::makeFallthrough());
+  m.impl("imag", torch::CppFunction::makeFallthrough());
+  m.impl("real", torch::CppFunction::makeFallthrough());
+  m.impl("view", torch::CppFunction::makeFallthrough());
+  m.impl("reshape", torch::CppFunction::makeFallthrough());
+}
+
+} // namespace at
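Taken together, the fallback gives conjugate views ordinary-op semantics at a one-time materialization cost. A hypothetical session sketching the observable behavior under this patch (values and output formatting are illustrative, not captured output):

Example::

    >>> x = torch.tensor([1 + 1j, 2 - 2j])
    >>> y = x.conj()        # lazy: only the Conjugate dispatch key is set
    >>> y + 1               # out-of-place: fallback materializes y, then redispatches
    tensor([2.-1.j, 3.+2.j])
    >>> y.add_(1)           # in-place: conjugated physically, mutated, then conjugated back
    tensor([2.-1.j, 3.+2.j])
    >>> y.is_conj()         # the conj bit on the mutable input is restored
    True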
m.impl("view_as_real", torch::CppFunction::makeFallthrough()); + m.impl("imag", torch::CppFunction::makeFallthrough()); + m.impl("real", torch::CppFunction::makeFallthrough()); + m.impl("view", torch::CppFunction::makeFallthrough()); + m.impl("reshape", torch::CppFunction::makeFallthrough()); +} + +} // namespace at diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp index 481a6eda588..ff34fa286e2 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -117,7 +117,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("clamp_min_.Tensor", CppFunction::makeFallthrough()); m.impl("clone", CppFunction::makeFallthrough()); m.impl("conj", CppFunction::makeFallthrough()); - m.impl("conj.out", CppFunction::makeFallthrough()); m.impl("contiguous", CppFunction::makeFallthrough()); m.impl("copy_", CppFunction::makeFallthrough()); m.impl("cos", CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 2209710183d..1a7486a019a 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -238,6 +238,9 @@ _(aten, coalesce) \ _(aten, combinations) \ _(aten, _conj) \ _(aten, conj) \ +_(aten, conj_physical) \ +_(aten, conj_physical_) \ +_(aten, resolve_conj) \ _(aten, complex) \ _(aten, copysign) \ _(aten, polar) \ @@ -764,6 +767,7 @@ _(aten, zeros_like) \ _(aten, real) \ _(aten, imag) \ _(aten, view_as_real) \ +_(aten, _view_as_real_physical) \ _(aten, view_as_complex) \ /* nothing */ diff --git a/aten/src/ATen/native/ComplexHelper.h b/aten/src/ATen/native/ComplexHelper.h index 7fed2dd4771..518862b16e3 100644 --- a/aten/src/ATen/native/ComplexHelper.h +++ b/aten/src/ATen/native/ComplexHelper.h @@ -2,6 +2,9 @@ #include +// WARNING: this header contains non-inline functions and should be only +// included from ONE cpp file + namespace at { namespace native { // View tensor with new dtype, storage offset, sizes and strides @@ -9,8 +12,9 @@ inline Tensor view_tensor( const Tensor &tensor, ScalarType dtype, int64_t offset, IntArrayRef sizes, IntArrayRef strides) { Storage storage = tensor.storage(); + auto key_set = tensor.key_set().remove(DispatchKey::Conjugate); auto new_tensor = detail::make_tensor( - c10::TensorImpl::VIEW, std::move(storage), tensor.key_set(), scalarTypeToTypeMeta(dtype)); + c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype)); auto * impl = new_tensor.unsafeGetTensorImpl(); impl->set_storage_offset(offset); impl->set_sizes_and_strides(sizes, strides); @@ -30,7 +34,12 @@ inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) { // with corresponding real dtype containing the complex values // in the last two dimensions Tensor view_as_real(const Tensor& self) { - TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors"); + TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. 
diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp
index 5d270b8ca92..6294e37ca66 100644
--- a/aten/src/ATen/native/SpectralOps.cpp
+++ b/aten/src/ATen/native/SpectralOps.cpp
@@ -150,7 +150,7 @@ Tensor fft_r2c(c10::string_view function_name,
 
   if (!forward) {
     // FIXME: _fft_r2c doesn't support native r2c IFFT
-    return out.defined() ? at::conj_out(out, ret) : at::conj(ret);
+    return out.defined() ? at::conj_physical_out(out, ret) : at::conj(ret);
   } else {
     return ret;
   }
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index f8186ed30d3..cff7074799b 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -349,6 +349,8 @@ Tensor empty_like(
     namedinference::propagate_names(result, self.names());
   }
 
+  // never propagate the Conjugate key
+  result._set_conj(false);
   return result;
 }
diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp
index db7a244856f..519afdad928 100644
--- a/aten/src/ATen/native/TypeProperties.cpp
+++ b/aten/src/ATen/native/TypeProperties.cpp
@@ -30,6 +30,10 @@ bool is_signed(const Tensor &self) {
   return self.is_signed();
 }
 
+bool is_conj(const Tensor& self) {
+  return self.is_conj();
+}
+
 bool is_sparse(const Tensor& self) {
   return self.is_sparse();
 }
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 07c0fc26383..624881a01db 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -28,7 +28,8 @@ namespace at {
 namespace meta {
 
 // Unary float operations always produce floating point
-// outputs, even if their inputs are integer
+// outputs for floating point and integral inputs.
+// For complex inputs, the output type is the same as the input type.
 #define CREATE_UNARY_FLOAT_META_FUNC(func)            \
   TORCH_META_FUNC(func) (const Tensor& self) {        \
     build_unary_float_op(maybe_get_output(), self);   \
@@ -363,7 +364,8 @@ Tensor angle(const Tensor& self) {
 
 Tensor real(const Tensor& self) {
   if (self.is_complex()) {
-    auto real_tensor = at::view_as_real(self);
+    // real is never affected by the conjugate bit, so it's safe to use the physical version
+    auto real_tensor = at::_view_as_real_physical(self);
     return at::select(real_tensor, real_tensor.dim() - 1, 0);
   } else {
     TORCH_CHECK(false, "real is not implemented for tensors with non-complex dtypes.");
@@ -379,17 +381,44 @@ Tensor imag(const Tensor& self) {
   }
 }
 
-Tensor& conj_out(const Tensor& self, Tensor& result) {
-  return unary_op_impl_out(result, self, conj_stub);
+Tensor& conj_physical_out(const Tensor& self, Tensor& result) {
+  return unary_op_impl_out(result, self, conj_physical_stub);
 }
 
-Tensor _conj(const Tensor& self) { return unary_op_impl(self, at::conj_out); }
+Tensor _conj_physical(const Tensor& self) {
+  if (self.is_conj()) {
+    return self.conj().clone();
+  }
+  return unary_op_impl(self, at::conj_physical_out);
+}
+
+Tensor conj_physical(const Tensor& self) {
+  if (!self.is_complex()) return self;
+  return at::_conj_physical(self);
+}
+
+Tensor& conj_physical_(Tensor& self) {
+  if (!self.is_complex()) return self;
+  return unary_op_impl_out(self, self, conj_physical_stub);
+}
+
+Tensor resolve_conj(const Tensor& self) {
+  if (!self.is_conj()) { return self; }
+  // conjugation is handled in `copy_()`, which clone() ultimately calls into
+  return self.clone(self.suggest_memory_format());
+}
+
+Tensor _conj(const Tensor& self) {
+  Tensor self_ = self.alias();
+  self_._set_conj(!self.is_conj());
+  namedinference::propagate_names(self_, self);
+  return self_;
+}
 
 Tensor conj(const Tensor& self) {
-  if (!self.is_complex()) {
-    return self;
-  }
-  return at::_conj(self);
+  // This might look like an infinite recursion, but it's not:
+  // this actually calls into `conj()` defined in the Tensor class.
+  return self.conj();
 }
 
 // special_exp2, alias for exp2
@@ -689,7 +718,7 @@ DEFINE_DISPATCH(abs_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-va
 DEFINE_DISPATCH(angle_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(real_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(imag_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
-DEFINE_DISPATCH(conj_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(conj_physical_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(acos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(acosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(asinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
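To recap the division of labor set up above (an illustrative sketch): `conj` returns a lazy view with the bit flipped, `_conj` is the view constructor it dispatches to, `conj_physical` always negates the imaginary part in memory, and `resolve_conj` materializes a pending conjugation:

Example::

    >>> x = torch.tensor([1 + 1j])
    >>> x.conj().is_conj()              # lazy view sharing x's storage
    True
    >>> x.conj().conj().is_conj()       # flipping the bit twice cancels out
    False
    >>> x.conj_physical().is_conj()     # eager: new memory, no bit set
    False
    >>> x.conj().resolve_conj().is_conj()
    False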
diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h
index 9b793064326..27654bc9f81 100644
--- a/aten/src/ATen/native/UnaryOps.h
+++ b/aten/src/ATen/native/UnaryOps.h
@@ -14,7 +14,7 @@ DECLARE_DISPATCH(unary_fn, abs_stub);
 DECLARE_DISPATCH(unary_fn, angle_stub);
 DECLARE_DISPATCH(unary_fn, real_stub);
 DECLARE_DISPATCH(unary_fn, imag_stub);
-DECLARE_DISPATCH(unary_fn, conj_stub);
+DECLARE_DISPATCH(unary_fn, conj_physical_stub);
 DECLARE_DISPATCH(unary_fn, acos_stub);
 DECLARE_DISPATCH(unary_fn, acosh_stub);
 DECLARE_DISPATCH(unary_fn, asinh_stub);
diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp
index f0280c58474..4bf27ef85a2 100644
--- a/aten/src/ATen/native/cpu/CopyKernel.cpp
+++ b/aten/src/ATen/native/cpu/CopyKernel.cpp
@@ -5,6 +5,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
+#include <ATen/native/cpu/zmath.h>
 
 namespace at {
 namespace native {
@@ -13,6 +14,15 @@ namespace {
 static void copy_kernel(TensorIterator& iter, bool non_blocking) {
   ScalarType dtype = iter.dtype(0);
   if (dtype == iter.dtype(1)) {
+    // TODO: as the majority of these operations can be done treating
+    // their datatypes as opaque bit patterns, we don't actually need
+    // separate instantiations per dtype; we only need a separate
+    // instantiation per dtype size. This would probably save us a
+    // little bit of code size here
+    // TODO: not sure if the optimizer is able to compile two levels of
+    // conditionals into a single jump table. We should have a
+    // single jump table here; it might be worth just writing out the
+    // dispatch statement by hand instead of using AT_DISPATCH
     if (dtype == ScalarType::Half) {
       cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
     } else if (dtype == ScalarType::ComplexHalf) {
@@ -25,12 +35,21 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
            [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
       });
     } else if (isComplexType(dtype)) {
-      AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
-        cpu_kernel_vec(
-            iter,
-            [=](scalar_t a) -> scalar_t { return a; },
-            [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
-      });
+      if (iter.tensor(0).is_conj() == iter.tensor(1).is_conj()) {
+        AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
+          cpu_kernel_vec(
+              iter,
+              [=](scalar_t a) -> scalar_t { return a; },
+              [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
+        });
+      } else {
+        AT_DISPATCH_COMPLEX_TYPES(dtype, "conj_kernel", [&] {
+          cpu_kernel_vec(
+              iter,
+              [=](scalar_t a) -> scalar_t { return conj_impl(a); },
+              [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.conj(); });
+        });
+      }
     } else {
       AT_DISPATCH_ALL_TYPES_AND2(
           ScalarType::Bool, ScalarType::BFloat16, dtype, "copy_kernel", [&] {
@@ -62,6 +81,9 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
               return c10::static_cast_with_inter_type<dest_t, scalar_t>::apply(src); });
         });
     });
+    if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
+      iter.tensor(0).conj_physical_();
+    }
   }
 }
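With the kernel above, `copy_` is where a pending conjugation usually gets paid for: when exactly one side carries the conj bit, the copy and the conjugation fuse into a single pass. A hedged sketch of the resulting behavior:

Example::

    >>> src = torch.tensor([1 + 1j, 2 + 2j]).conj()
    >>> dst = torch.empty(2, dtype=torch.complex64)
    >>> dst.copy_(src)                  # conjugates while copying
    tensor([1.-1.j, 2.-2.j])
    >>> dst.is_conj()
    False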
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 16ffd5b8c06..19016021904 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -193,6 +193,7 @@ static void imag_kernel(TensorIteratorBase& iter) {
   });
 }
 
+// NB: Ignores the negative bit on tensors
 static void conj_kernel(TensorIteratorBase& iter) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cpu", [&]() {
@@ -716,7 +717,7 @@ REGISTER_DISPATCH(real_stub, &CPU_CAPABILITY::real_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(imag_stub, &CPU_CAPABILITY::imag_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_DISPATCH(conj_stub, &CPU_CAPABILITY::conj_kernel);
+REGISTER_DISPATCH(conj_physical_stub, &CPU_CAPABILITY::conj_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(exp2_stub, &CPU_CAPABILITY::exp2_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu
index 49564ff5a95..ffc4e3edc98 100644
--- a/aten/src/ATen/native/cuda/Copy.cu
+++ b/aten/src/ATen/native/cuda/Copy.cu
@@ -25,7 +25,8 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
   // We can memcpy the memory if both tensors have the same type AND both
   // tensors are contiguous after dimension coalescing and reordering.
   bool same_type = iter.dtype(0) == iter.dtype(1);
-  bool memcpy_eligible = same_type && iter.is_contiguous();
+  bool same_conj = iter.tensor(0).is_conj() == iter.tensor(1).is_conj();
+  bool memcpy_eligible = same_type && same_conj && iter.is_contiguous();
 
   Device dst_device = iter.device(0);
   Device src_device = iter.device(1);
@@ -71,10 +72,17 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-        kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
-          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
-        });
+    if (!same_conj && same_type) {
+      AT_DISPATCH_COMPLEX_TYPES(
+          dtype, "copy_conj_", [&] {
+            gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return std::conj(x); });
+          });
+    } else {
+      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+          kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
+            gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
+          });
+    }
   }
 }
@@ -152,6 +160,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
       src_contig = iter.tensor(1).expand_as(dst).contiguous();
     }
 
+    // propagate the correct conjugate bit
+    dst_contig._set_conj(dst.is_conj());
+    src_contig._set_conj(iter.tensor(1).is_conj());
+
     // perform a same-dtype copy on contiguous tensors
     TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
     TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
@@ -201,6 +213,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
     AT_CUDA_CHECK(cudaStreamSynchronize(stream));
 #endif
   }
+
+  if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
+    iter.tensor(0).conj_physical_();
+  }
 }
 
 REGISTER_DISPATCH(copy_stub, &copy_kernel_cuda);
diff --git a/aten/src/ATen/native/cuda/UnaryComplexKernels.cu b/aten/src/ATen/native/cuda/UnaryComplexKernels.cu
index 7fb9acd37fc..d054eb9091b 100644
--- a/aten/src/ATen/native/cuda/UnaryComplexKernels.cu
+++ b/aten/src/ATen/native/cuda/UnaryComplexKernels.cu
@@ -80,6 +80,7 @@ __host__ __device__ static inline c10::complex<T> conj_wrapper(c10::complex<T> v) {
   return std::conj(v);
 }
 
+// NB: Ignores the negative bit on tensors
 void conj_kernel_cuda(TensorIteratorBase& iter) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cuda", [&]() {
@@ -92,6 +93,6 @@ void conj_kernel_cuda(TensorIteratorBase& iter) {
 REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
 REGISTER_DISPATCH(real_stub, &real_kernel_cuda);
 REGISTER_DISPATCH(imag_stub, &imag_kernel_cuda);
-REGISTER_DISPATCH(conj_stub, &conj_kernel_cuda);
+REGISTER_DISPATCH(conj_physical_stub, &conj_kernel_cuda);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 1d5af7a9971..e453fbc4f3b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -271,6 +271,11 @@
   dispatch:
     CPU, CUDA: view_as_real
 
+- func: _view_as_real_physical(Tensor(a) self) -> Tensor(a)
+  variants: function
+  dispatch:
+    CPU, CUDA: _view_as_real_physical
+
 - func: view_as_complex(Tensor(a) self) -> Tensor(a)
   variants: function
   dispatch:
@@ -298,21 +303,36 @@
   device_check: NoCheck   # TensorIterator
   variants: function
 
-- func: conj(Tensor(a) self) -> Tensor(a)
-  device_check: NoCheck   # TensorIterator
+- func: _conj(Tensor(a) self) -> Tensor(a)
   variants: function, method
-
-- func: conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
-  device_check: NoCheck   # TensorIterator
-  dispatch:
-    CPU, CUDA: conj_out
-    SparseCPU, SparseCUDA: conj_out_sparse
-
-- func: _conj(Tensor self) -> Tensor
-  variants: function
   dispatch:
     CompositeExplicitAutograd: _conj
 
+- func: conj(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+  manual_cpp_binding: True
+
+- func: _conj_physical(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: _conj_physical
+
+- func: conj_physical(Tensor self) -> Tensor
+  variants: function, method
+
+- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: conj_physical_out
+    SparseCPU, SparseCUDA: conj_physical_out_sparse
+
+- func: conj_physical_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: conj_physical_
+
+- func: resolve_conj(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+
 - func: acos(Tensor self) -> Tensor
   device_check: NoCheck   # TensorIterator
   variants: function, method
@@ -2249,6 +2269,11 @@
   device_guard: False
   manual_cpp_binding: True
 
+- func: is_conj(Tensor self) -> bool
+  variants: function, method
+  device_guard: False
+  manual_cpp_binding: True
+
 - func: isreal(Tensor self) -> Tensor
   variants: function, method
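Note that only the physical variant keeps an out= overload in the schema; the lazy `conj` allocates nothing, so it needs none. A usage sketch:

Example::

    >>> x = torch.tensor([1 + 1j])
    >>> out = torch.empty_like(x)
    >>> torch.conj_physical(x, out=out)
    tensor([1.-1.j])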
diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
index 546d897d190..705d6d0bb3f 100644
--- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -1645,15 +1645,15 @@ Tensor& bmm_out_sparse_cpu(const SparseTensor& self, const Tensor& mat2, Tensor&
   return result;
 }
 
-Tensor conj_sparse(const Tensor& input) {
-  if (!input.is_complex()) {
-    return input;
-  }
-  Tensor result = at::native::empty_like(input);
-  return conj_out_sparse(input, result);
-}
+// Tensor conj_physical_sparse(const Tensor& input) {
+//   if (!input.is_complex()) {
+//     return input;
+//   }
+//   Tensor result = at::native::empty_like(input);
+//   return conj_physical_out_sparse(input, result);
+// }
 
-Tensor& conj_out_sparse(const Tensor& input, Tensor& result) {
+Tensor& conj_physical_out_sparse(const Tensor& input, Tensor& result) {
   TORCH_INTERNAL_ASSERT(input.is_sparse());
   if (input.numel() == 0) {
     return result;
@@ -1665,7 +1665,7 @@ Tensor& conj_out_sparse(const Tensor& input, Tensor& result) {
     return result;
   }
   Tensor result_values = result._values();
-  at::conj_out(result_values, input._values());
+  at::conj_physical_out(result_values, input._values());
   return result;
 }
diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h
index 4cfc4758737..18e7c00a0bb 100644
--- a/aten/src/ATen/templates/Functions.h
+++ b/aten/src/ATen/templates/Functions.h
@@ -216,4 +216,12 @@ inline bool is_inference(const Tensor& tensor) {
   return tensor.is_inference();
 }
 
+inline bool is_conj(const Tensor& tensor) {
+  return tensor.is_conj();
+}
+
+inline Tensor conj(const Tensor& tensor) {
+  return tensor.conj();
+}
+
 }
diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h
index 21eeb2a51e9..cf2bef96b33 100644
--- a/aten/src/ATen/templates/TensorBody.h
+++ b/aten/src/ATen/templates/TensorBody.h
@@ -136,6 +136,17 @@ class TORCH_API Tensor {
     }
   }
 
+  Tensor conj() const {
+    if (!this->is_complex()) {
+      return *this;
+    } else {
+      if (this->is_sparse()) {
+        return this->conj_physical();
+      }
+      return this->_conj();
+    }
+  }
+
   /// Should be used if *this can reasonably be expected to be contiguous and
   /// performance is important.
   /// Compared to contiguous, it saves a reference count
@@ -363,6 +374,18 @@ class TORCH_API Tensor {
     return !at::impl::variable_excluded_from_dispatch();
   }
 
+  inline bool is_conj() const {
+    return impl_->is_conj();
+  }
+
+  // sets the conjugate bit of a tensor.
+  // NOTE: the conjugate bit is supposed to be a read-only field. Only change it
+  // if you are extremely sure that's what you want. Changing it might lead to
+  // incorrect behavior, since conjugation is a lazy operation and we rely on
+  // this bit to determine whether a conjugation needs to be materialized.
+  inline void _set_conj(bool conjugate) const {
+    impl_->_set_conj(conjugate);
+  }
+
   /// Returns a `Tensor`'s layout.
   Layout layout() const noexcept {
     return impl_->layout();
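Note the sparse carve-out in `Tensor::conj()` above: sparse tensors take the eager path, so the conj bit never appears on them. An illustrative (unverified) sketch:

Example::

    >>> i = torch.tensor([[0], [0]])
    >>> v = torch.tensor([1 + 1j])
    >>> s = torch.sparse_coo_tensor(i, v, (1, 1))
    >>> s.conj().is_conj()              # falls back to conj_physical() for sparse
    False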
diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp
index eaf2facc224..e22e4845440 100644
--- a/c10/core/DispatchKey.cpp
+++ b/c10/core/DispatchKey.cpp
@@ -65,6 +65,8 @@ const char* toString(DispatchKey t) {
     case DispatchKey::PrivateUse3:
       return "PrivateUse3";
 
+    case DispatchKey::Conjugate:
+      return "Conjugate";
     case DispatchKey::Meta:
       return "Meta";
diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
index b3d31531694..ed597ba81c2 100644
--- a/c10/core/DispatchKey.h
+++ b/c10/core/DispatchKey.h
@@ -142,6 +142,11 @@ enum class DispatchKey : uint8_t {
   // constituent parts.
   Named,
 
+  // The Conjugate dispatch key is set for any tensors that need to perform
+  // conjugation.
+  // This is implemented at a dispatch level, right before any backends run.
+  Conjugate,
+
   // See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
   // is to insert code after the "autograd subsystem" runs, so this key should
   // be directly after ADInplaceOrView and all of the autograd keys.
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index ceac9ea25eb..0ff3411bb91 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -938,6 +938,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    */
   const at::Tensor& grad() const;
 
+  /**
+   * Whether or not the imaginary part of the tensor should be negated
+   */
+  inline bool is_conj() const {
+    return key_set_.has(DispatchKey::Conjugate);
+  }
+
+  /**
+   * Set whether or not to take the conjugate of the tensor (flip the imaginary
+   * bit).
+   */
+  void _set_conj(bool value) {
+    if (value) {
+      key_set_ = key_set_.add(DispatchKey::Conjugate);
+      TORCH_INTERNAL_ASSERT(isComplexType(typeMetaToScalarType(dtype())));
+    } else {
+      key_set_ = key_set_.remove(DispatchKey::Conjugate);
+    }
+  }
+
   /**
    * Return the accumulated gradient of a tensor. This gradient is computed
    * using forward mode AD.
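Since `_set_conj(true)` asserts a complex dtype, the bit is unobservable on real tensors, where `conj()` degenerates to a no-op. A quick sketch:

Example::

    >>> torch.tensor([1., 2.]).conj().is_conj()
    False
    >>> torch.tensor([1 + 1j]).conj().is_conj()
    True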
diff --git a/docs/source/tensor_view.rst b/docs/source/tensor_view.rst
index 059a76e2b28..a75a31836ad 100644
--- a/docs/source/tensor_view.rst
+++ b/docs/source/tensor_view.rst
@@ -64,7 +64,6 @@ For reference, here’s a full list of view ops in PyTorch:
 - :attr:`~torch.Tensor.real`
 - :attr:`~torch.Tensor.imag`
 - :meth:`~torch.Tensor.view_as_real`
-- :meth:`~torch.Tensor.view_as_imag`
 - :meth:`~torch.Tensor.unflatten`
 - :meth:`~torch.Tensor.unfold`
 - :meth:`~torch.Tensor.unsqueeze`
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index f025de09568..8c05a6733d4 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -276,6 +276,9 @@ Tensor class reference
     Tensor.contiguous
     Tensor.copy_
     Tensor.conj
+    Tensor.conj_physical
+    Tensor.conj_physical_
+    Tensor.resolve_conj
     Tensor.copysign
     Tensor.copysign_
     Tensor.cos
@@ -410,6 +413,7 @@ Tensor class reference
     Tensor.isnan
     Tensor.is_contiguous
     Tensor.is_complex
+    Tensor.is_conj
     Tensor.is_floating_point
     Tensor.is_inference
     Tensor.is_leaf
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 51fc9b8fd5b..94b288920ca 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -19,6 +19,7 @@ Tensors
     is_tensor
     is_storage
     is_complex
+    is_conj
     is_floating_point
     is_nonzero
     set_default_dtype
@@ -86,6 +87,7 @@ Indexing, Slicing, Joining, Mutating Ops
    :nosignatures:
 
     cat
+    conj
     chunk
     dsplit
     column_stack
@@ -293,7 +295,7 @@ Pointwise Ops
     ceil
     clamp
     clip
-    conj
+    conj_physical
     copysign
     cos
     cosh
@@ -511,6 +513,7 @@ Other Operations
     vander
     view_as_real
     view_as_complex
+    resolve_conj
 
 
 BLAS and LAPACK Operations
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index c6be5f3fb1a..ee7b55c312f 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -91,6 +91,9 @@ allow_list = [
     ("aten::linalg_vector_norm", datetime.date(2021, 5, 15)),
     ("aten::repeat_interleave", datetime.date(2021, 6, 26)),
     ("aten::one_hot", datetime.date(2021, 6, 15)),
+    ("aten::conj", datetime.date(2021, 8, 1)),
+    ("aten::_conj", datetime.date(2021, 8, 1)),
+    ("aten::conj.out", datetime.date(2021, 8, 1)),
 ]
 
 def allow_listed(schema, allow_list):
diff --git a/test/test_autograd.py b/test/test_autograd.py
index a08c98b58bc..5a385360a55 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4279,7 +4279,7 @@ class TestAutograd(TestCase):
             gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
 
         def basic_mul(x):
-            return torch.view_as_real(x * 1j)
+            return torch.view_as_real(torch.resolve_conj(x * 1j))
         gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)
 
         # Test for one input and one output being complex
@@ -5459,7 +5459,7 @@ complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone',
                 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
                 'chunk', 'split', 'split_with_sizes', 'zero_',
                 '__radd__', 'mul', '__rmul__', 'diagonal', 'fill_', 'sub', 'narrow',
-                'swapaxes', 'swapdims', 'tensor_split'] + separate_complex_tests
+                'swapaxes', 'swapdims', 'tensor_split', 'select', 'clone'] + separate_complex_tests
 
 # deny list for batched grad computation
 EXCLUDE_BATCHED_GRAD_TESTS = set([
@@ -5591,7 +5591,9 @@ def add_test(
                     output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
                     if not isinstance(output_variable, tuple):
                         output_variable = (output_variable,)
+                    inplace_self_variable = deepcopy(self_variable)
+                    self.assertEqual(inplace_self_variable, self_variable)
                     inplace_self_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
                                                        for i in (inplace_self_variable,))
                     inplace_args_variable = deepcopy(args_variable)
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 920e7dbd1ea..1f0c914601f 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -1946,6 +1946,7 @@ known_failures = [
 
 # If your OpInfo test causes this test to fail, add it here
 skip_ops = [
+    'conj'
 ]
 
 def get_name(op):
diff --git a/test/test_ops.py b/test/test_ops.py
index ea160b65ae1..ea84407add7 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -17,7 +17,7 @@ from torch.testing._internal.common_jit import JitCommonTestCase, check_against_
 from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
     check_alias_annotation
 from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
-
+from collections.abc import Sequence
 
 # Get names of all the operators which have entry in `method_tests` (legacy testing infra)
 method_tested_operators = set(map(lambda test_details: test_details[0], method_tests()))
@@ -110,7 +110,9 @@ class TestGradients(TestCase):
                 return variant.__wrapped__ is op.get_inplace()
             return variant is op.get_inplace()
 
-        samples = op.sample_inputs(device, dtype, requires_grad=True)
+        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
+        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs)
+
        for sample in samples:
            if sample.broadcasts_input and is_inplace(variant):
                continue
@@ -280,7 +282,8 @@ class TestCommon(JitCommonTestCase):
        _requires_grad = (op.supports_autograd and
                          (dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type)))

-        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
+        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
+        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)

        def _test_consistency_helper(samples, variants):
            for sample in samples:
@@ -381,7 +384,8 @@ class TestCommon(JitCommonTestCase):
        _requires_grad = op.supports_autograd and (dtype.is_floating_point or
                                                   op.supports_complex_autograd(torch.device(device).type))

-        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
+        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
+        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)

        for sample in samples:
            # Acquires variants to test
@@ -734,6 +738,103 @@ class TestCommon(JitCommonTestCase):
                with self.assertRaises(RuntimeError, msg=msg_fail):
                    op_out(out=out)
+    # Tests that
+    # 1. The operator's output for physically conjugated tensors and conjugate view tensors
+    #    produces the same value
+    # 2. The gradients are the same in both cases mentioned in (1)
+    # 3. If the operator's inplace variant is supported, tests that the inplace operation
+    #    produces the correct value when called on a conjugate view tensor and that the output
+    #    has its conj bit set to true
+    # This test only runs for C -> R and C -> C functions
+    # TODO: add tests for `R->C` functions
+    # Note: This test runs for functions that take both tensors and tensorlists as input.
+    @ops(op_db, allowed_dtypes=(torch.cfloat,))
+    def test_conj_view(self, device, dtype, op):
+        if not op.test_conjugated_samples:
+            self.skipTest("Operation doesn't support conjugated inputs.")
+        _requires_grad = (op.supports_autograd and op.supports_complex_autograd(torch.device(device).type))
+        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
+        inplace_variant = op.inplace_variant
+
+        # helper function to physically conjugate the tensor
+        def conjugate_physical(input):
+            if isinstance(input, torch.Tensor):
+                tensor_requires_grad = input.requires_grad
+                with torch.no_grad():
+                    input = input.conj_physical()
+                return input.requires_grad_(tensor_requires_grad)
+
+            if isinstance(input, Sequence):
+                out = list(map(clone_input_helper, input))
+                out[0] = conjugate_physical(out[0])
+                return tuple(out)
+
+        # helper function to clone and conjugate the input if it's a tensor,
+        # else clone the sequence and conjugate the first element in the sequence.
+        # If a requires_grad argument is provided, the tensor being conjugated will
+        # have its requires_grad set to that value.
+        def clone_conj_input_helper(input, **kwargs):
+            if isinstance(input, torch.Tensor):
+                requires_grad = kwargs.get('requires_grad', input.requires_grad)
+                with torch.no_grad():
+                    input = input.clone()
+                # Note: .conj() is not called under no_grad mode since it's not allowed to modify a
+                # view created in no_grad mode. Here it's ok to do so, so as a workaround we call conj
+                # before resetting the requires_grad field for input
+                input = input.conj()
+                assert input.is_leaf
+                return input.requires_grad_(requires_grad)
+
+            if isinstance(input, Sequence):
+                out = list(map(clone_input_helper, input))
+                out[0] = clone_conj_input_helper(out[0])
+                return tuple(out)
+        for sample in samples:
+            tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
+            cloned1 = clone_conj_input_helper(sample.input)
+            sample.input = conjugate_physical(sample.input)
+
+            # Computes the function's forward value with a physically conjugated tensor and
+            # a conj view tensor and verifies that the outputs are equal in both cases.
+            expected_forward = op(sample.input, *sample.args, **sample.kwargs)
+            forward_with_conjview = op(cloned1, *sample.args, **sample.kwargs)
+            self.assertEqual(expected_forward, forward_with_conjview)
+
+            # If the op has an inplace variant, and the input doesn't require broadcasting
+            # and has the same dtype as output, verify that the inplace operation on a conjugated
+            # input produces correct output, and the output tensor has the conj bit set to True
+            if inplace_variant is not None and not sample.broadcasts_input:
+                cloned2 = clone_conj_input_helper(tensor, requires_grad=False)
+                if (isinstance(expected_forward, torch.Tensor) and
+                        expected_forward.dtype is tensor.dtype):
+                    inplace_forward = inplace_variant(cloned2, *sample.args, **sample.kwargs)
+                    self.assertTrue(inplace_forward.is_conj())
+                    self.assertEqual(inplace_forward, expected_forward)
+
+            # TODO: backward consistency only supported for single tensor outputs
+            # TODO: backward consistency only checked on sample.input, not all
+            #   tensor inputs
+            # TODO: update to handle checking grads of all tensor inputs as
+            #   derived from each tensor output
+            if isinstance(expected_forward, torch.Tensor) and expected_forward.requires_grad:
+                tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
+                expected_forward.sum().backward(retain_graph=True)
+                forward_with_conjview.sum().backward(retain_graph=True)
+                if tensor.grad is not None:
+                    cloned1_tensor = cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
+                    self.assertEqual(tensor.grad, cloned1_tensor.grad)
+
+                    tensor.grad, cloned1_tensor.grad = None, None
+
+                    # a repeat of the above test if output is not complex valued
+                    if (expected_forward.is_complex()):
+                        grad = torch.randn_like(expected_forward)
+                        expected_forward.backward(grad.conj_physical())
+                        forward_with_conjview.backward(grad.conj())
+
+                        self.assertEqual(tensor.grad, cloned1_tensor.grad)
+
 instantiate_device_type_tests(TestOpInfo, globals())
 instantiate_device_type_tests(TestGradients, globals())
diff --git a/test/test_view_ops.py b/test/test_view_ops.py
index 56448ebaf40..6b02c2c9bc6 100644
--- a/test/test_view_ops.py
+++ b/test/test_view_ops.py
@@ -346,6 +346,14 @@ class TestViewOps(TestCase):
             self.assertEqual(a[5:].real, a.real[5:])
             self.assertEqual(a[5:].imag, a.imag[5:])
 
+    @onlyOnCPUAndCUDA
+    @dtypes(*torch.testing.get_all_complex_dtypes())
+    def test_conj_view(self, device, dtype) -> None:
+        t = _make_tensor((4, 5,), dtype, device)
+        v = t.conj()
+        self.assertTrue(self.is_view_of(t, v))
+        self.assertEqual(v, torch.from_numpy(t.cpu().numpy().conj()).to(device=device))
+
     @onlyOnCPUAndCUDA
     @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
     @suppress_warnings
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 5200bd8490f..c76aa4c20cf 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -380,16 +380,23 @@
 - name: complex(Tensor real, Tensor imag) -> Tensor
   real: at::real(grad)
-  imag: at::imag(grad)
+  imag: at::imag(grad.resolve_conj())
   result: at::complex(real_t, imag_t)
 
 - name: polar(Tensor abs, Tensor angle) -> Tensor
   abs, angle: polar_backward(grad, result)
 
-- name: _conj(Tensor self) -> Tensor
+- name: _conj(Tensor(a) self) -> Tensor(a)
   self: grad.conj()
   result: self_t.conj()
 
+- name: _conj_physical(Tensor self) -> Tensor
+  self: grad.conj_physical()
+  result: self_t.conj_physical()
+
+- name: conj_physical_(Tensor(a!) self) -> Tensor(a!)
+  self: grad.conj_physical()
+
 - name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
   self: copysign_tensor_self_backward(grad, self, result)
   other: zeros_like(other)
@@ -1306,12 +1313,16 @@
 - name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
   output_differentiability: [False]
 
+# If the conj bit for self is set, this is effectively a fused conj() and view_as_real
+- name: _view_as_real_physical(Tensor(a) self) -> Tensor(a)
+  self: "self.is_conj() ? at::view_as_complex(grad.contiguous()) : at::view_as_complex(grad.contiguous()).conj()"
+
 - name: view_as_real(Tensor(a) self) -> Tensor(a)
   self: at::view_as_complex(grad.contiguous())  # gx0 + 1j * gx1
   result: at::view_as_real(self_t)
 
 - name: view_as_complex(Tensor(a) self) -> Tensor(a)
-  self: at::view_as_real(grad.contiguous())  # [gx, gy]
+  self: at::view_as_real(grad.contiguous().resolve_conj())  # [gx, gy]
   result: at::view_as_complex(self_t)
 
 - name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py
index 27f4956287f..158044b5fd8 100644
--- a/tools/autograd/gen_inplace_or_view_type.py
+++ b/tools/autograd/gen_inplace_or_view_type.py
@@ -28,7 +28,7 @@ from .gen_trace_type import (
 #
 # A map: function name => name of the argument that all outputs are view of
 
-VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_real', 'view_as_complex']
+VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_complex', '_view_as_real_physical', 'view_as_real', '_conj']
 
 VIEW_FUNCTIONS = {
     'numpy_T': 'self',
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 5a4300635ca..73ef7538a46 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -100,7 +100,8 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
     'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
-    'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm',
+    'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_view_as_real_physical', '_conj_physical',
+    'conj_physical_'
 }
 
 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py
index 575e751f552..59ba25a196e 100644
--- a/tools/autograd/load_derivatives.py
+++ b/tools/autograd/load_derivatives.py
@@ -10,7 +10,7 @@ import yaml
 from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
                                         SavedAttribute, ForwardDerivative)
 from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, BaseCType, VectorCType,
-                                     intArrayRefT, tensorOptionsT, typeAndSizeT, intT,
+                                     intArrayRefT, tensorOptionsT, typeAndSizeT, intT, boolT,
                                      tensorGeometryT, scalarTypeT, SpecialArgName,
                                      OptionalCType, stringT)
 from tools.codegen.api import cpp
@@ -498,6 +498,11 @@ def saved_variables(
             'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
             'expr': stride_expr,
         }),
+        # replace self.is_conj() with self_conjugate
+        (r'{}.is_conj\(\)', {
+            'suffix': '_conjugate',
+            'nctype': lambda name: NamedCType(name, BaseCType(boolT)),
+        })
     ]
 
     # find which arguments need to be saved
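The derivative entries above encode the rule that the gradient flowing through a conjugate view is itself conjugated. A small hedged check of that identity:

Example::

    >>> x = torch.tensor([1 + 1j], requires_grad=True)
    >>> y = x.conj()
    >>> y.backward(torch.tensor([2 + 1j]))
    >>> x.grad                           # conj of the incoming gradient
    tensor([2.-1.j])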
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 8bc5d06d562..4f723ca7450 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -86,8 +86,10 @@ class Tensor(torch._C._TensorBase):
                                 self.requires_grad,
                                 self._backward_hooks)
             else:
-                new_tensor = self.new()
+                new_tensor = self.new_empty([])
                 new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
+                if self.is_conj():
+                    new_tensor = new_tensor.conj_physical()
                 new_tensor.requires_grad = self.requires_grad
             if self.grad is not None:
                 new_tensor.grad = self.grad.__deepcopy__(memo)
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index c9975fda86a..38d5e992f60 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -893,6 +893,27 @@ conj() -> Tensor
 See :func:`torch.conj`
 """)
 
+add_docstr_all('conj_physical',
+               r"""
+conj_physical() -> Tensor
+
+See :func:`torch.conj_physical`
+""")
+
+add_docstr_all('conj_physical_',
+               r"""
+conj_physical_() -> Tensor
+
+In-place version of :meth:`~Tensor.conj_physical`
+""")
+
+add_docstr_all('resolve_conj',
+               r"""
+resolve_conj() -> Tensor
+
+See :func:`torch.resolve_conj`
+""")
+
 add_docstr_all('copysign',
                r"""
 copysign(other) -> Tensor
@@ -1977,6 +1998,13 @@ is_inference() -> bool
 See :func:`torch.is_inference`
 """)
 
+add_docstr_all('is_conj',
+               r"""
+is_conj() -> bool
+
+Returns True if the conjugate bit of :attr:`self` is set to true.
+""")
+
 add_docstr_all('is_signed',
                r"""
 is_signed() -> bool
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
index c1c95bc77e5..ec736f5ff99 100644
--- a/torch/_tensor_str.py
+++ b/torch/_tensor_str.py
@@ -153,11 +153,12 @@ class _Formatter(object):
 def _scalar_str(self, formatter1, formatter2=None):
     if formatter2 is not None:
         real_str = _scalar_str(self.real, formatter1)
-        imag_str = _scalar_str(self.imag, formatter2) + "j"
-        if self.imag < 0:
-            return real_str + imag_str.lstrip()
+        imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
+        # handles negative numbers, +0.0, -0.0
+        if imag_str[0] == '+' or imag_str[0] == '-':
+            return real_str + imag_str
         else:
-            return real_str + "+" + imag_str.lstrip()
+            return real_str + "+" + imag_str
     else:
         return formatter1.format(self.item())
 
@@ -174,11 +175,12 @@ def _vector_str(self, indent, summarize, formatter1, formatter2=None):
     def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
         if formatter2 is not None:
             real_str = formatter1.format(val.real)
-            imag_str = formatter2.format(val.imag) + "j"
-            if val.imag < 0:
-                return real_str + imag_str.lstrip()
+            imag_str = (formatter2.format(val.imag) + "j").lstrip()
+            # handles negative numbers, +0.0, -0.0
+            if imag_str[0] == '+' or imag_str[0] == '-':
+                return real_str + imag_str
             else:
-                return real_str + "+" + imag_str.lstrip()
+                return real_str + "+" + imag_str
         else:
             return formatter1.format(val)
 
@@ -235,6 +237,7 @@ def _tensor_str(self, indent):
         self = self.float()
 
     if self.dtype.is_complex:
+        self = self.resolve_conj()
         real_formatter = _Formatter(get_summarized_data(self.real) if summarize else self.real)
         imag_formatter = _Formatter(get_summarized_data(self.imag) if summarize else self.imag)
         return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter)
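Printing resolves the bit so that users always see mathematical values while the tensor itself stays lazy; roughly:

Example::

    >>> x = torch.tensor([1 + 1j]).conj()
    >>> print(x)                        # resolve_conj() applied for display only
    tensor([1.-1.j])
    >>> x.is_conj()
    True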
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 96094dcf275..e0f09bd764c 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -2181,15 +2181,18 @@ Example::
     tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128)
 """)
 
-add_docstr(torch.conj,
+add_docstr(torch.conj_physical,
            r"""
-conj(input, *, out=None) -> Tensor
+conj_physical(input, *, out=None) -> Tensor
 
-Computes the element-wise conjugate of the given :attr:`input` tensor. If :attr:`input` has a non-complex dtype,
-this function just returns :attr:`input`.
+Computes the element-wise conjugate of the given :attr:`input` tensor.
+If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`.
 
-.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
-             non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj`
+.. note::
+   This performs the conjugate operation regardless of whether the conjugate bit is set.
+
+.. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of
+             non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical`
              when :attr:`input` is of non-complex dtype to be compatible with this change.
 
 .. math::
@@ -2203,10 +2206,63 @@ Keyword args:
 
 Example::
 
-    >>> torch.conj(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))
+    >>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))
     tensor([-1 - 1j, -2 - 2j, 3 + 3j])
 """.format(**common_args))
 
+add_docstr(torch.conj,
+           r"""
+conj(input) -> Tensor
+
+Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype,
+this function just returns :attr:`input`.
+
+.. note::
+   :func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized
+   at any time using :func:`torch.resolve_conj`.
+
+.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
+             non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj`
+             when :attr:`input` is of non-complex dtype to be compatible with this change.
+
+Args:
+    {input}
+
+Example::
+
+    >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
+    >>> x.is_conj()
+    False
+    >>> y = torch.conj(x)
+    >>> y.is_conj()
+    True
+
+""".format(**common_args))
+
+add_docstr(torch.resolve_conj,
+           r"""
+resolve_conj(input) -> Tensor
+
+Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`,
+else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`.
+
+Args:
+    {input}
+
+Example::
+
+    >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
+    >>> y = x.conj()
+    >>> y.is_conj()
+    True
+    >>> z = y.resolve_conj()
+    >>> z
+    tensor([-1 - 1j, -2 - 2j, 3 + 3j])
+    >>> z.is_conj()
+    False
+
+""".format(**common_args))
+
 add_docstr(torch.copysign,
            r"""
 copysign(input, other, *, out=None) -> Tensor
@@ -4201,6 +4257,15 @@ Args:
     {input}
 """.format(**common_args))
 
+add_docstr(torch.is_conj, r"""
+is_conj(input) -> (bool)
+
+Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`.
+
+Args:
+    {input}
+""".format(**common_args))
+
 add_docstr(torch.is_nonzero, r"""
 is_nonzero(input) -> (bool)
diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index b6ccd9be115..db6318fdceb 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -545,7 +545,7 @@ def _get_analytical_vJu_backward_mode(inputs, outputs, nondet_tol, check_grad_dt
             # the error checking logic from slow mode
             vJ = vJ.T.squeeze(0)
             if vJ.is_complex():  # C -> R
-                tv = torch.view_as_real(vJ)
+                tv = torch.view_as_real(vJ.resolve_conj())
                 tr = tv.select(-1, 0)
                 ti = tv.select(-1, 1)
                 jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
@@ -894,7 +894,11 @@ def _real_and_imag_output(fn):
             outs = _as_tuple(fn(*inputs))
             return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)
         return wrapped_fn
-    return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)
+
+    # TODO(@anjali411): remove this workaround once the neg bit is added.
+    def torch_imag(x):
+        return x.resolve_conj().imag
+    return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch_imag)
 
 def _real_and_imag_input(fn, complex_inp_indices):
     # returns new functions that take real inputs instead of complex inputs and compute fn(x + 0 * 1j)
@@ -935,7 +939,7 @@ def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, e
 
     if complex_inp_indices:
         real_fn, imag_fn = _real_and_imag_input(func, complex_inp_indices)
-        imag_inputs = [inp.imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs]
+        imag_inputs = [inp.resolve_conj().imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs]
         imag_func_out = imag_fn(*imag_inputs)
         diff_imag_func_out = _differentiable_outputs(imag_func_out)
         gradcheck_fn(imag_fn, imag_func_out, imag_inputs, diff_imag_func_out, eps,
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 5836b3da12a..d083ceb1008 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -248,7 +248,7 @@ Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & expone
 Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) {
   Tensor cond;
   if (exponent.is_complex()) {
-    auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
+    auto is_real_exp = at::logical_and(at::imag(exponent.resolve_conj()) == 0, at::real(exponent) >= 0);
     cond = at::logical_and(self == 0, is_real_exp);
   } else {
     cond = at::logical_and(self == 0, exponent >= 0);
@@ -264,7 +264,7 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp
   if (base.equal(0.0)) {
     auto cond = [](auto exp) {
       if (exp.is_complex()) {
-        return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
+        return at::logical_and(at::imag(exp.resolve_conj()) == 0, at::real(exp) >= 0);
       } else {
         return exp >= 0;
       }
@@ -2154,7 +2154,7 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
   if (self.is_complex() && gu.defined()) {
     Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1);
     at::real(L).zero_();
-    at::imag(L).mul_(sigma_inv);
+    at::imag(L.resolve_conj()).mul_(sigma_inv);
     Tensor imag_term = at::matmul(u * L.unsqueeze(-2), vh);
     return u_term + sigma_term + v_term + imag_term;
   }
@@ -2195,7 +2195,7 @@ Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const T
   }
   // path for torch.linalg.eig with always a complex tensor of eigenvalues
   else {
-    is_imag_eigvals_zero = (at::imag(D) == 0.0).min().item<bool>();
+    is_imag_eigvals_zero = (at::imag(D.resolve_conj()) == 0.0).min().item<bool>();
     // insert an additional dimension to be compatible with torch.eig.
     // Recall that it produces 2D tensors.
     // We extract only the real parts as there is no support for
diff --git a/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp b/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp
index 2446cacf587..b2c034a2703 100644
--- a/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp
@@ -83,23 +83,6 @@ void nnc_aten_sgn(
   } catch (...) {
   }
 }
-void nnc_aten_conj(
-    int64_t bufs_num,
-    void** buf_data,
-    int64_t* buf_ranks,
-    int64_t* buf_dims,
-    int8_t* buf_dtypes,
-    int64_t args_num,
-    int64_t* extra_args) {
-  std::vector<at::Tensor> tensors =
-      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
-  at::Tensor& r = tensors[0];
-  const at::Tensor& self = tensors[1];
-  try {
-    at::conj_out(r, self);
-  } catch (...) {
-  }
-}
 void nnc_aten_acos(
     int64_t bufs_num,
     void** buf_data,
@@ -2758,9 +2741,6 @@ const static RegisterNNCExternalFunction nnc_angle(
     "nnc_aten_angle",
     nnc_aten_angle);
 const static RegisterNNCExternalFunction nnc_sgn("nnc_aten_sgn", nnc_aten_sgn);
-const static RegisterNNCExternalFunction nnc_conj(
-    "nnc_aten_conj",
-    nnc_aten_conj);
 const static RegisterNNCExternalFunction nnc_acos(
     "nnc_aten_acos",
     nnc_aten_acos);
diff --git a/torch/overrides.py b/torch/overrides.py
index 3414d866db3..750c5ff3b5d 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -228,6 +228,8 @@ def get_ignored_functions() -> Set[Callable]:
         Tensor.to_sparse_csr,
         Tensor._reduce_ex_internal,
         Tensor._fix_weakref,
+        Tensor._conj,
+        Tensor._conj_physical,
     }
 
 
@@ -349,6 +351,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.polar: lambda abs, ang: -1,
        torch.linalg.cond: lambda input, ord=None: -1,
        torch.conj: lambda input, out=None: -1,
+       torch.conj_physical: lambda input, out=None: -1,
+       torch.resolve_conj: lambda input, out=None: -1,
        torch.constant_pad_nd: lambda input, pad, value=0: -1,
        torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
        torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
@@ -500,6 +504,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.linalg.inv: lambda input, out=None: -1,
        torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
        torch.is_complex: lambda input: -1,
+       torch.is_conj: lambda input: -1,
        torch.is_distributed: lambda input: -1,
        torch.is_inference: lambda input: -1,
        torch.is_floating_point: lambda input: -1,
@@ -797,6 +802,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.real: lambda input, out=None: -1,
        torch.vdot: lambda input, other, out=None: -1,
        torch.view_as_real: lambda input: -1,
+       torch._view_as_real_physical: lambda input: -1,
        torch.view_as_complex: lambda input: -1,
        torch.reciprocal: lambda input, out=None: -1,
        torch.relu: lambda input, inplace=False: -1,
diff --git a/torch/testing/_core.py b/torch/testing/_core.py
index 02f165ce839..1522c16aa6c 100644
--- a/torch/testing/_core.py
+++ b/torch/testing/_core.py
@@ -134,6 +134,8 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e
    # Compares complex tensors' real and imaginary parts separately.
    # (see NOTE Test Framework Tensor "Equality")
    if a.is_complex():
+        a = a.resolve_conj()
+        b = b.resolve_conj()
        if equal_nan == "relaxed":
            a = a.clone()
            b = b.clone()
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 0e46ae266da..c829e3300f7 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -10,7 +10,6 @@ import random
import torch
import numpy as np
from torch._six import inf
-from torch.autograd import Variable
import collections.abc
from typing import List, Sequence, Tuple, Dict, Any, Union
@@ -231,6 +230,7 @@ class OpInfo(object):
                 #   function around gradcheck (testing._internal.common_utils.gradcheck)
                 inplace_variant=_NOTHING,  # explicitly pass the inplace variant of the operator if required
                 method_variant=_NOTHING,  # explicitly pass the method variant of the operator if required
+                 test_conjugated_samples=True,
                 ):
        # Validates the dtypes are generated from the dispatch-related functions
@@ -300,6 +300,8 @@ class OpInfo(object):
        if aliases is not None:
            self.aliases = tuple(AliasInfo(a) for a in aliases)  # type: ignore[assignment]
+        self.test_conjugated_samples = test_conjugated_samples
+
    def __call__(self, *args, **kwargs):
        """Calls the function variant of the operator."""
        return self.op(*args, **kwargs)
@@ -326,6 +328,37 @@ class OpInfo(object):
        """
        return self.operator_variant
+    def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
+        """Returns an iterable of SampleInputs but with the tensor input or first
+        tensor in a sequence input conjugated.
+        """
+
+        # TODO: Remove the try/except once all operators have sample_inputs_func with
+        # **kwargs in their signature.
+        try:
+            samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
+        except TypeError:
+            samples = self.sample_inputs_func(self, device, dtype, requires_grad)
+
+        conj_samples = list(samples)
+
+        def conjugate(tensor):
+            _requires_grad = tensor.requires_grad
+            with torch.no_grad():
+                tensor = tensor.conj()
+            return tensor.requires_grad_(_requires_grad)
+
+        for i in range(len(samples)):
+            sample = conj_samples[i]
+            # Note: it is assumed that the input here is either a tensor or tensorlist
+            if isinstance(sample.input, torch.Tensor):
+                sample.input = conjugate(sample.input)
+            else:
+                with torch.no_grad():
+                    sample.input[0] = conjugate(sample.input[0])
+
+        return tuple(conj_samples)
+
    def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
        """Returns an iterable of SampleInputs.
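The conjugate_sample_inputs helper above leans on the lazy conjugation this patch introduces: conj() only flips a bit on a view, so producing a conjugated twin of every sample is cheap. A minimal sketch of what the inner conjugate() helper does, assuming the new conj-bit APIs (is_conj, resolve_conj, conj_physical):

    import torch

    x = torch.randn(3, dtype=torch.cfloat, requires_grad=True)

    # Mirror the helper: take the conj view outside autograd, then make it a leaf again.
    with torch.no_grad():
        y = x.conj()
    assert y.is_conj()          # bit set, no copy made
    y.requires_grad_(True)

    # resolve_conj() materializes the bit; conj_physical() conjugates eagerly.
    z = y.resolve_conj()
    assert not z.is_conj()
    assert torch.equal(z.detach(), torch.conj_physical(x.detach()))

On real-valued tensors conj() is effectively a no-op (the bit is never set), so the helper can run over every dtype an OpInfo declares.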
@@ -339,6 +372,13 @@ class OpInfo(object):
            samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
        except TypeError:
            samples = self.sample_inputs_func(self, device, dtype, requires_grad)
+
+        if 'include_conjugated_inputs' in kwargs and kwargs.get('include_conjugated_inputs'):
+            conj_samples = self.conjugate_sample_inputs(device, dtype, requires_grad, **kwargs)
+            samples_list = list(samples)
+            samples_list.extend(conj_samples)
+            samples = tuple(samples_list)
+
        return samples

    # Returns True if the test should be skipped and False otherwise
@@ -1001,21 +1041,21 @@ def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs):
    return tuple(sample_inputs)

def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
-    test_cases = [((S, S), (S, S), (S, S)),
-                  ((S, S), (S, 1), (1, S)),
-                  ((1,), (S, S, 1), (1, S)),
-                  ((), (), ()),
-                  ((S, S), (), ()),
-                  ((), (S, S, 1), (1, S)),
+    test_cases = [(((S, S), (S, S), (S, S)), False),
+                  (((S, S), (S, 1), (1, S)), False),
+                  (((1,), (S, S, 1), (1, S)), True),
+                  (((), (), ()), False),
+                  (((S, S), (), ()), True),
+                  (((), (S, S, 1), (1, S)), True)
                  ]

    sample_inputs = []
-    for input_args in test_cases:
+    for input_args, broadcasts_input in test_cases:
        args = tuple(make_tensor(arg, device, dtype, requires_grad=requires_grad) if isinstance(arg, tuple) else arg
                     for arg in input_args)
-        sample_inputs.append(SampleInput(args[0], args=args[1:]))
+        sample_inputs.append(SampleInput(args[0], args=args[1:], broadcasts_input=broadcasts_input))

-        sample_inputs.append(SampleInput(args[0], args=args[1:], kwargs=dict(value=3.14)))
+        sample_inputs.append(SampleInput(args[0], args=args[1:], kwargs=dict(value=3.14), broadcasts_input=broadcasts_input))

    return tuple(sample_inputs)
@@ -1057,7 +1097,8 @@ def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
        make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
        args=(
            make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
-            make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)))
+            make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
+        broadcasts_input=True)

    if dtype.is_complex:
        alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j
@@ -1078,7 +1119,8 @@ def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
            args=(
                make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
                make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
-            kwargs=dict(beta=beta, alpha=alpha))
+            kwargs=dict(beta=beta, alpha=alpha),
+            broadcasts_input=True)

    return (input1, input2, input3, input4)
@@ -2196,6 +2238,32 @@ def np_unary_ufunc_integer_promotion_wrapper(fn):
    return wrapped_fn

+def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
+    nd_tensor = make_tensor((S, S + 1, S + 2), device, dtype, low=None, high=None,
+                            requires_grad=requires_grad)
+    tensor = make_tensor((31,), device, dtype, low=None, high=None,
+                         requires_grad=requires_grad)
+
+    if self.ndimensional:
+        return [
+            SampleInput(nd_tensor, kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
+            SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
+            SampleInput(nd_tensor, kwargs=dict(s=(8,))),
+            SampleInput(tensor),
+
+            *(SampleInput(nd_tensor, kwargs=dict(dim=dim))
+              for dim in [-1, -2, -3, (0, -1)]),
+        ]
+    else:
+        return [
+            SampleInput(nd_tensor, kwargs=dict(n=10, dim=1, norm='ortho')),
+            SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
+            SampleInput(nd_tensor, kwargs=dict(n=7)),
+            SampleInput(tensor),
+
+            *(SampleInput(nd_tensor, kwargs=dict(dim=dim))
+              for dim in [-1, -2, -3]),
+        ]

# Metadata class for Fast Fourier Transforms in torch.fft.
class SpectralFuncInfo(OpInfo):
@@ -2207,6 +2275,7 @@ class SpectralFuncInfo(OpInfo):
                 ref=None,  # Reference implementation (probably in np.fft namespace)
                 dtypes=floating_and_complex_types(),
                 ndimensional: bool,  # Whether dim argument can be a tuple
+                 sample_inputs_func=sample_inputs_spectral_ops,
                 decorators=None,
                 **kwargs):
        decorators = list(decorators) if decorators is not None else []
@@ -2220,39 +2289,12 @@ class SpectralFuncInfo(OpInfo):
        super().__init__(name=name,
                         dtypes=dtypes,
                         decorators=decorators,
+                         sample_inputs_func=sample_inputs_func,
                         **kwargs)
        self.ref = ref if ref is not None else _getattr_qual(np, name)
        self.ndimensional = ndimensional

-    def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
-        nd_tensor = make_tensor((S, S + 1, S + 2), device, dtype, low=None, high=None,
-                                requires_grad=requires_grad)
-        tensor = make_tensor((31,), device, dtype, low=None, high=None,
-                             requires_grad=requires_grad)
-
-        if self.ndimensional:
-            return [
-                SampleInput(nd_tensor, kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
-                SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
-                SampleInput(nd_tensor, kwargs=dict(s=(8,))),
-                SampleInput(tensor),
-
-                *(SampleInput(nd_tensor, kwargs=dict(dim=dim))
-                  for dim in [-1, -2, -3, (0, -1)]),
-            ]
-        else:
-            return [
-                SampleInput(nd_tensor, kwargs=dict(n=10, dim=1, norm='ortho')),
-                SampleInput(nd_tensor, kwargs=dict(norm='ortho')),
-                SampleInput(nd_tensor, kwargs=dict(n=7)),
-                SampleInput(tensor),
-
-                *(SampleInput(nd_tensor, kwargs=dict(dim=dim))
-                  for dim in [-1, -2, -3]),
-            ]
-
-
class ShapeFuncInfo(OpInfo):
    """Early version of a specialized OpInfo for Shape manipulating operations like tile and roll"""
    def __init__(self,
@@ -2737,13 +2779,13 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
    test_cases = (
        ((2, 2), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, False),
        ((2, 2), 0, 5, 1e-3, requires_grad, (1,), 0, 1, 0.1, requires_grad, False),
-        ((), 1e-3, 1e-3 + 1, 0, True, (), 0.1, 1.1, 0, False, False),
+        ((), 1e-3, 1e-3 + 1, 0, requires_grad, (), 0.1, 1.1, 0, False, False),
        ((2, 2), 0, 5, 1e-3, requires_grad, (), 0.1, 1.1, 1, False, False),
    )
    tests_require_resizing = (
-        ((1,), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, True),
-        ((2, 1, 2), 0, 5, 1e-3, requires_grad, (1, 2, 1), 0, 1, 0.1, requires_grad, True),
-        ((), 1e-3, 1e-3 + 1, 0, True, (1, S, 1), 0, 1, 0.1, requires_grad, True),
+        ((1,), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, requires_grad),
+        ((2, 1, 2), 0, 5, 1e-3, requires_grad, (1, 2, 1), 0, 1, 0.1, requires_grad, requires_grad),
+        ((), 1e-3, 1e-3 + 1, 0, requires_grad, (1, S, 1), 0, 1, 0.1, requires_grad, requires_grad),
    )
    cases = test_cases + tests_require_resizing
    samples = list(SampleInput(make_tensor(shape_b, low=low_b, high=high_b,
@@ -2757,7 +2799,7 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
                               high_e, additive_e, e_grad, broadcasts_input in cases)
    tensor_scalar_inputs = (
        ((2, 2), 0, 5, 1e-3, requires_grad, (3.14,)),
-        ((), 1e-3, 1e-3 + 1, 0, True, (3.14,))
+        ((), 1e-3, 1e-3 + 1, 0, requires_grad, (3.14,))
    )
    more_samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device, high=high, low=low,
@@ -2768,8 +2810,8 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
    elif dtype in [torch.complex64, torch.complex128]:
        args_tuple = (
            ((2, 2), 0, 5, requires_grad, (3.14,)),
-            ((), 0, 1, True, (3.14,)),
-            ((), 0, 1, True, (3.14j,))
+            ((), 0, 1, requires_grad, (3.14,)),
+            ((), 0, 1, requires_grad, (3.14j,))
        )
        samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device, high=high, low=low,
@@ -4298,7 +4340,7 @@ op_db: List[OpInfo] = [
                   SkipInfo('TestGradients', 'test_method_grad',
                            device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
                   SkipInfo('TestGradients', 'test_forward_mode_AD',
-                            dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                            dtypes=[torch.cdouble]),
               )),
    OpInfo('add',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
@@ -4670,28 +4712,30 @@ op_db: List[OpInfo] = [
           supports_forward_ad=True,
           ),
    UnaryUfuncInfo('conj',
+                   ref=np.conj,
+                   dtypes=all_types_and_complex_and(torch.bool,
+                                                    torch.bfloat16, torch.half),
+                   supports_forward_ad=True,
+                   supports_out=False),
+    UnaryUfuncInfo('conj_physical',
                   ref=np.conj,
                   dtypes=all_types_and_complex_and(torch.bool,
                                                    torch.bfloat16, torch.half),
                   supports_forward_ad=True,
                   skips=(
-                      # File "test_unary_ufuncs.py", line 289, in test_reference_numerics
-                      #  if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype):
-                      # KeyError:
-                      # Following error in Windows CI
-                      SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
-                               dtypes=[torch.int],
-                               active_if=IS_WINDOWS),
-                      SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
-                               dtypes=[torch.int],
-                               active_if=IS_WINDOWS),
-                      # TODO fix the formula for complex forward AD
-                      SkipInfo('TestGradients', 'test_forward_mode_AD'),
+                      SkipInfo('TestCommon', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
                  )),
+    OpInfo('resolve_conj',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_view_as_real,
+           supports_forward_ad=True,
+           supports_out=False,
+           ),
    OpInfo('view_as_real',
           dtypes=complex_types(),
           supports_forward_ad=True,
           sample_inputs_func=sample_inputs_view_as_real,
+           test_conjugated_samples=False,
           ),
    OpInfo('view_as_complex',
           dtypes=floating_types_and(torch.half),
@@ -5120,7 +5164,9 @@ op_db: List[OpInfo] = [
                   ref=np.imag,
                   dtypes=complex_types(),
                   supports_out=False,
-                   supports_autograd=False,
+                   supports_forward_ad=True,
+                   # TODO(@anjali411): Test this once neg bit is added.
+                   test_conjugated_samples=False,
                   skips=(
                       # Skip since real and imag don't have out variants.
                       SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
@@ -5251,8 +5297,6 @@ op_db: List[OpInfo] = [
           supports_autograd=False,
           decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
           skips=(
-               # skip because `linalg_lstsq` is not differentiable
-               SkipInfo('TestGradients', 'test_fn_grad'),
               SkipInfo('TestCommon', 'test_variant_consistency_jit'),
           )),
    OpInfo('linalg.matrix_power',
@@ -5471,7 +5515,8 @@ op_db: List[OpInfo] = [
                      # "RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when
                      # calling cublasGemmStridedBatchedExFix."
                      SkipInfo('TestOpInfo', 'test_supported_backward',
-                               device_type='cuda', dtypes=(torch.bfloat16,)),)),
+                               device_type='cuda', dtypes=(torch.bfloat16,)),
+                      SkipInfo('TestCommon', 'test_conj_view', device_type='cpu'),),),
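For context on the test_conj_view skips added here and below: with the include_conjugated_inputs flag added earlier, a test can ask an OpInfo for its usual samples plus conjugated twins. A hypothetical sketch of such a lookup (the choice of 'mm' is illustrative only):

    import torch
    from torch.testing._internal.common_methods_invocations import op_db

    op = next(o for o in op_db if o.name == 'mm')
    samples = op.sample_inputs('cpu', torch.complex64, requires_grad=False,
                               include_conjugated_inputs=True)
    # The appended twins carry the conjugate bit on their (first) tensor input.
    assert any(s.input.is_conj() for s in samples)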
    OpInfo('max',
           op=torch.max,
           variant_test_name='binary',
@@ -5699,7 +5744,9 @@ op_db: List[OpInfo] = [
           assert_autodiffed=True),
    OpInfo('float_power',
           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
-           sample_inputs_func=sample_inputs_pow),
+           sample_inputs_func=sample_inputs_pow,
+           skips=(
+               SkipInfo('TestCommon', 'test_conj_view', device_type='cuda'),),),
    OpInfo('prod',
           dtypes=all_types_and_complex_and(torch.bool),
           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
@@ -5739,7 +5786,6 @@ op_db: List[OpInfo] = [
                   ref=np.real,
                   dtypes=complex_types(),
                   supports_out=False,
-                   supports_autograd=False,
                   skips=(
                       # Skip since real and imag don't have out variants.
                       SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
@@ -7068,23 +7114,26 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
    def maybe_non_contig(tensor):
        return tensor if not non_contiguous else make_non_contiguous(tensor)

+    def conjugate(tensor):
+        return tensor.conj()
+
    if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
        return arg
    elif isinstance(arg, tuple) and len(arg) == 0:
-        var = torch.randn((), dtype=dtype, device=device)
+        var = conjugate(torch.randn((), dtype=dtype, device=device))
        var.requires_grad = requires_grad
        return var
    elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
-        return Variable(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device)), requires_grad=requires_grad)
+        return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)  # double check casting
    elif isinstance(arg, non_differentiable):
        if isinstance(arg.tensor, torch.Tensor):
            if arg.tensor.dtype == torch.float:
                return maybe_non_contig(arg.tensor.to(dtype=torch.double, device=device))
            if arg.tensor.dtype == torch.cfloat:
-                return maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device))
-            return maybe_non_contig(arg.tensor.to(device=device))
-        return maybe_non_contig(arg.tensor.to(device=device))
+                return conjugate(maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device)))
+            return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
+        return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
    elif isinstance(arg, torch.Tensor):
        if arg.dtype == torch.float:
            arg = arg.double()
@@ -7094,7 +7143,7 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
            raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ",
                               "which is not supported for now")
        # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
-        v = maybe_non_contig(arg).detach().to(device=device).clone()
+        v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
        v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
        return v
    elif callable(arg):