mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Conjugate View (#54987)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54987 Based off of ezyang (https://github.com/pytorch/pytorch/pull/44799) and bdhirsh (https://github.com/pytorch/pytorch/pull/43702) 's prototype: Here's a summary of the changes in this PR: This PR adds a new dispatch key called Conjugate. This enables us to make conjugate operation a view and leverage the specialized library functions that fast path with the hermitian operation (conj + transpose). 1. Conjugate operation will now return a view with conj bit (1) for complex tensors and returns self for non-complex tensors as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical` which would return the real tensor agnostic of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor. 2. NEW API: a) `.conj()` -- now returning a view. b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated value in its memory. c) `.conj_physical_()`, and `out=` variant d) `.resolve_conj()` -- materializes the conjugation. returns self if the conj bit is unset, else returns a new tensor with conjugated values and conj bit set to 0. e) `.resolve_conj_()` in-place version of (d) f) `view_as_real_physical` -- as described in (1), it's functionally same as `view_as_real`, just that it doesn't error out on conjugated tensors. g) `view_as_real` -- existing function, but now errors out on conjugated tensors. 3. Conjugate Fallback a) Vast majority of PyTorch functions would currently use this fallback when they are called on a conjugated tensor. 
b) This fallback is well equipped to handle the following cases: - functional operation e.g., `torch.sin(input)` - Mutable inputs and in-place operations e.g., `tensor.add_(2)` - out= operation e.g., `torch.sin(input, out=out)` - Tensorlist input args - NOTE: Meta tensors don't work with conjugate fallback. 4. Autograd a) `resolve_conj()` is an identity function w.r.t. autograd b) Everything else works as expected. 5. Testing: a) All method_tests run with conjugate view tensors. b) OpInfo tests that run with conjugate views - test_variant_consistency_eager/jit - gradcheck, gradgradcheck - test_conj_views (that only run for `torch.cfloat` dtype) NOTE: functions like `empty_like`, `zeros_like`, `randn_like`, `clone` don't propagate the conjugate bit. Follow up work: 1. conjugate view RFC 2. Add neg bit to re-enable view operation on conjugated tensors 3. Update linalg functions to call into specialized functions that fast path with the hermitian operation. Test Plan: Imported from OSS Reviewed By: VitalyFedyunin Differential Revision: D28227315 Pulled By: anjali411 fbshipit-source-id: acab9402b9d6a970c6d512809b627a290c8def5f
This commit is contained in:
parent
19985d6f84
commit
3607478ecd
|
|
@ -1031,7 +1031,6 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
|
|||
|
||||
m.impl("sum.dim_IntList", sum_batching_rule);
|
||||
m.impl("is_complex", native::is_complex);
|
||||
m.impl("conj", native::conj);
|
||||
|
||||
// inplace operations
|
||||
m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
|
||||
|
|
@ -1085,7 +1084,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
|
|||
UNARY_POINTWISE(ceil);
|
||||
UNARY_POINTWISE(cos);
|
||||
UNARY_POINTWISE(cosh);
|
||||
UNARY_POINTWISE(_conj);
|
||||
UNARY_POINTWISE(conj_physical);
|
||||
UNARY_POINTWISE(digamma);
|
||||
UNARY_POINTWISE(exp);
|
||||
UNARY_POINTWISE(expm1);
|
||||
|
|
@ -1181,6 +1180,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
|
|||
TRIVIAL_OP(imag)
|
||||
TRIVIAL_OP(real);
|
||||
TRIVIAL_OP(view_as_real);
|
||||
TRIVIAL_OP(_view_as_real_physical);
|
||||
TRIVIAL_OP(conj);
|
||||
TRIVIAL_OP(_conj);
|
||||
TRIVIAL_OP(resolve_conj);
|
||||
m.impl("view_as_complex", view_as_complex_batching_rule);
|
||||
#undef TRIVIAL
|
||||
|
||||
|
|
|
|||
152
aten/src/ATen/ConjugateFallback.cpp
Normal file
152
aten/src/ATen/ConjugateFallback.cpp
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <torch/library.h>
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
|
||||
namespace at {
|
||||
|
||||
// Boxed fallback kernel registered for the Conjugate dispatch key.
//
// For operators without a conjugate-aware kernel, this rewrites the tensor
// arguments on `stack` so that none of them carries the conj bit, redispatches
// to the keys after Conjugate, and finally restores the conj bit on every
// argument the operator was allowed to mutate.
//
//   op            - operator being dispatched; its schema supplies alias info.
//   dispatch_keys - key set of the current call; masked before redispatch.
//   stack         - boxed argument stack; the last `arguments.size()` entries
//                   are this operator's inputs and are rewritten in place.
//
// Raises (via TORCH_CHECK) when the schema mixes mutable and non-mutable
// aliasing arguments, and (via TORCH_CHECK_NOT_IMPLEMENTED) when a conjugated
// meta tensor is passed as a plain Tensor argument.
void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
  // Situations to handle:
  //  1. Out-of-place operation.  Easy: materialize all inputs and
  //     call it a day.
  //  2. Inplace operation.  Desugar x.add_(2) into x.conj_().add_(2).conj_().
  //     Materialize other inputs as in (1).
  //  3. out= operation.  Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
  //     Materialize other inputs as in (1).
  //
  // It is important to be able to tell if we READ from an argument and if we
  // WRITE to an argument.  Conservative approach is to assume that we always
  // READ from an argument, but in out-of-place operations you can skip
  // conjugating inputs on entry that never get used.  In current schema we
  // can't easily tell if inplace situation has happened, so don't do it.

  const auto& arguments = op.schema().arguments();
  const auto num_arguments = arguments.size();
  const auto stack_start = stack->size() - num_arguments;

  // Every key strictly after Conjugate: redispatching with this mask skips
  // this fallback and lands on the next kernel in line.
  const auto keys_after_conjugate =
      dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate);

  c10::optional<bool> is_write;
  for (const auto i : c10::irange(num_arguments)) {
    const auto& alias_info = arguments[i].alias_info();
    // Three possible states:
    // 1. alias_info has no value --> out-of-place operation
    // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
    // 3. alias_info does have a value, alias_info->is_write=False --> view operation
    if (alias_info.has_value()) {
      if (is_write.has_value()) {
        TORCH_CHECK(*is_write == alias_info->isWrite(),
          "Unsupported operator for conjugate fallback: ", op.schema().name(),
          ". Conjugate fallback doesn't work for operators with a mix of "
          "mutable and non-mutable inputs that alias with outputs, "
          "this must be implemented manually.  "
          "If you got this error on a core op, please report a bug to PyTorch.");
      } else {
        is_write = alias_info->isWrite();
      }
    }
  }

  if (is_write.has_value() && !*is_write) {
    // We assume that view operators automatically handle conjugation
    // correctly by propagating the Conjugate dispatch key in key_set.
    // This is not necessarily always right, so you should test these cases.
    op.redispatchBoxed(keys_after_conjugate, stack);
    return;
  }

  // Mutable inputs to be tracked separately: these get their conj bit put
  // back once the underlying kernel has run.
  std::vector<Tensor> mutable_inputs;

  for (const auto i : c10::irange(num_arguments)) {
    auto& ivalue = (*stack)[stack_start + i];
    if (!(ivalue.isTensor() || ivalue.isTensorList())) {
      continue;
    }
    const auto& argument = arguments[i];
    bool mut_arg = false;
    if (argument.alias_info()) {
      // View operations were already filtered above, so only in-place/out= operations should get here.
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
      mut_arg = true;
    }
    if (ivalue.isTensor()) {
      auto* impl = ivalue.unsafeToTensorImpl();
      if (!impl->is_conj()) {
        continue;
      }

      auto tensor = std::move(ivalue).toTensor();
      TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), "Conjugate Fallback does not support meta tensors.");
      if (mut_arg) {
        // TODO: This is a waste if the argument is write only
        tensor._set_conj(false);
        at::conj_physical_(tensor);
        mutable_inputs.emplace_back(tensor);
      } else {
        tensor = at::resolve_conj(tensor);
      }
      (*stack)[stack_start + i] = std::move(tensor);
    } else if (ivalue.isTensorList()) {
      auto tensors = std::move(ivalue).toTensorList();
      if (mut_arg) {
        for (const auto j : c10::irange(tensors.size())) {
          Tensor t = tensors[j];
          // Only elements that actually carry the conj bit may be rewritten.
          // Touching an unconjugated element would leave it with the conj bit
          // set after the restore loop below, silently changing its value.
          if (!t.is_conj()) {
            continue;
          }
          t._set_conj(false);
          at::conj_physical_(t);
          mutable_inputs.emplace_back(std::move(t));
        }
      } else {
        for (const auto j : c10::irange(tensors.size())) {
          tensors[j] = at::resolve_conj(tensors[j]);
        }
      }
      (*stack)[stack_start + i] = std::move(tensors);
    }
  }

  op.redispatchBoxed(keys_after_conjugate, stack);

  // Re-encode the results written into mutable arguments as "conjugated
  // memory + conj bit", matching how those tensors looked on entry.
  for (auto& mutable_input : mutable_inputs) {
    at::conj_physical_(mutable_input);
    mutable_input._set_conj(true);
  }
}
|
||||
|
||||
// Registers conjugateFallback as the boxed fallback kernel for the Conjugate
// dispatch key under the wildcard (`_`) namespace.
TORCH_LIBRARY_IMPL(_, Conjugate, m) {
  auto boxed_fallback = torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>();
  m.fallback(std::move(boxed_fallback));
}
|
||||
|
||||
// aten operators that get a fallthrough kernel on the Conjugate key, so calls
// to them are not routed through the boxed conjugate fallback.
TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
  // Registered in exactly this order; each name receives its own
  // freshly constructed fallthrough CppFunction.
  static constexpr const char* kConjugateFallthroughOps[] = {
      "requires_grad_",
      "set_.source_Storage_storage_offset",
      "set_.source_Tensor",
      "set_",
      "copy_",
      "clone",
      "conj",
      "_conj",
      "_conj_physical",
      "conj_physical",
      "conj_physical_",
      "resolve_conj",
      "empty_like",
      "empty.memory_format",
      "empty.out",
      "empty_strided",
      "full_like",
      "stride.int",
      "stride.Dimname",
      "size.int",
      "size.Dimname",
      "is_complex",
      "_view_as_real_physical",
      "view_as_real",
      "imag",
      "real",
      "view",
      "reshape",
  };

  for (const char* op_name : kConjugateFallthroughOps) {
    m.impl(op_name, torch::CppFunction::makeFallthrough());
  }
}
|
||||
|
||||
} // namespace at
|
||||
|
|
@ -117,7 +117,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
|
|||
m.impl("clamp_min_.Tensor", CppFunction::makeFallthrough());
|
||||
m.impl("clone", CppFunction::makeFallthrough());
|
||||
m.impl("conj", CppFunction::makeFallthrough());
|
||||
m.impl("conj.out", CppFunction::makeFallthrough());
|
||||
m.impl("contiguous", CppFunction::makeFallthrough());
|
||||
m.impl("copy_", CppFunction::makeFallthrough());
|
||||
m.impl("cos", CppFunction::makeFallthrough());
|
||||
|
|
|
|||
|
|
@ -238,6 +238,9 @@ _(aten, coalesce) \
|
|||
_(aten, combinations) \
|
||||
_(aten, _conj) \
|
||||
_(aten, conj) \
|
||||
_(aten, conj_physical) \
|
||||
_(aten, conj_physical_) \
|
||||
_(aten, resolve_conj) \
|
||||
_(aten, complex) \
|
||||
_(aten, copysign) \
|
||||
_(aten, polar) \
|
||||
|
|
@ -764,6 +767,7 @@ _(aten, zeros_like) \
|
|||
_(aten, real) \
|
||||
_(aten, imag) \
|
||||
_(aten, view_as_real) \
|
||||
_(aten, _view_as_real_physical) \
|
||||
_(aten, view_as_complex) \
|
||||
/* nothing */
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@
|
|||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
// WARNING: this header contains non-inline functions and should only be
|
||||
// included from ONE cpp file
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
// View tensor with new dtype, storage offset, sizes and strides
|
||||
|
|
@ -9,8 +12,9 @@ inline Tensor view_tensor(
|
|||
const Tensor &tensor, ScalarType dtype,
|
||||
int64_t offset, IntArrayRef sizes, IntArrayRef strides) {
|
||||
Storage storage = tensor.storage();
|
||||
auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
|
||||
auto new_tensor = detail::make_tensor<TensorImpl>(
|
||||
c10::TensorImpl::VIEW, std::move(storage), tensor.key_set(), scalarTypeToTypeMeta(dtype));
|
||||
c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
|
||||
auto * impl = new_tensor.unsafeGetTensorImpl();
|
||||
impl->set_storage_offset(offset);
|
||||
impl->set_sizes_and_strides(sizes, strides);
|
||||
|
|
@ -30,7 +34,12 @@ inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) {
|
|||
// with corresponding real dtype containing the complex values
|
||||
// in the last two dimensions
|
||||
// Checked entry point for viewing a complex tensor as real: rejects
// non-complex inputs and inputs whose conj bit is still set, then defers to
// native::_view_as_real_physical for the actual view.
Tensor view_as_real(const Tensor& self) {
  const bool input_is_complex = self.is_complex();
  TORCH_CHECK(input_is_complex, "view_as_real is only supported for complex tensors");

  const bool input_is_conjugated = self.is_conj();
  TORCH_CHECK(!input_is_conjugated, "view_as_real doesn't work on unresolved conjugated tensors.  To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");

  return native::_view_as_real_physical(self);
}
|
||||
|
||||
Tensor _view_as_real_physical(const Tensor& self) {
|
||||
TORCH_CHECK(self.is_complex(), "view_as_real_physical is only supported for complex tensors");
|
||||
auto old_sizes = self.sizes();
|
||||
DimVector new_sizes(old_sizes.size() + 1);
|
||||
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
|
||||
|
|
@ -39,7 +48,8 @@ Tensor view_as_real(const Tensor& self) {
|
|||
auto new_strides = computeStrideForViewAsReal(self.strides());
|
||||
auto new_storage_offset = 2 * self.storage_offset();
|
||||
const auto float_type = c10::toValueType(self.scalar_type());
|
||||
return view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides);
|
||||
auto real_tensor = view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides);
|
||||
return real_tensor;
|
||||
}
|
||||
|
||||
inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) {
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@ Tensor fft_r2c(c10::string_view function_name,
|
|||
|
||||
if (!forward) {
|
||||
// FIXME: _fft_r2c doesn't support native r2c IFFT
|
||||
return out.defined() ? at::conj_out(out, ret) : at::conj(ret);
|
||||
return out.defined() ? at::conj_physical_out(out, ret) : at::conj(ret);
|
||||
} else {
|
||||
return ret;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -349,6 +349,8 @@ Tensor empty_like(
|
|||
namedinference::propagate_names(result, self.names());
|
||||
}
|
||||
|
||||
// never propagate Conjugate key
|
||||
result._set_conj(false);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,10 @@ bool is_signed(const Tensor &self) {
|
|||
return self.is_signed();
|
||||
}
|
||||
|
||||
// Native implementation of `is_conj`: forwards to Tensor::is_conj() and
// returns its result unchanged.
bool is_conj(const Tensor& self) {
|
||||
return self.is_conj();
|
||||
}
|
||||
|
||||
// Native implementation of `is_sparse`: forwards to Tensor::is_sparse() and
// returns its result unchanged.
bool is_sparse(const Tensor& self) {
|
||||
return self.is_sparse();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ namespace at {
|
|||
namespace meta {
|
||||
|
||||
// Unary float operations always produce floating point
|
||||
// outputs, even if their inputs are integer
|
||||
// outputs for floating point and integral types
|
||||
// For complex inputs, the output type should be the same as input type.
|
||||
#define CREATE_UNARY_FLOAT_META_FUNC(func) \
|
||||
TORCH_META_FUNC(func) (const Tensor& self) { \
|
||||
build_unary_float_op(maybe_get_output(), self); \
|
||||
|
|
@ -363,7 +364,8 @@ Tensor angle(const Tensor& self) {
|
|||
|
||||
Tensor real(const Tensor& self) {
|
||||
if (self.is_complex()) {
|
||||
auto real_tensor = at::view_as_real(self);
|
||||
// real is never affected by conjugate bit, safe to use physical version
|
||||
auto real_tensor = at::_view_as_real_physical(self);
|
||||
return at::select(real_tensor, real_tensor.dim() - 1, 0);
|
||||
} else {
|
||||
TORCH_CHECK(false, "real is not implemented for tensors with non-complex dtypes.");
|
||||
|
|
@ -379,17 +381,44 @@ Tensor imag(const Tensor& self) {
|
|||
}
|
||||
}
|
||||
|
||||
Tensor& conj_out(const Tensor& self, Tensor& result) {
|
||||
return unary_op_impl_out(result, self, conj_stub);
|
||||
Tensor& conj_physical_out(const Tensor& self, Tensor& result) {
|
||||
return unary_op_impl_out(result, self, conj_physical_stub);
|
||||
}
|
||||
|
||||
Tensor _conj(const Tensor& self) { return unary_op_impl(self, at::conj_out); }
|
||||
// Produces a new tensor holding the conjugate of `self`.
// When `self` already has its conj bit set, the result is a clone of
// `self.conj()`; otherwise the conj_physical_out kernel is run through
// unary_op_impl.
Tensor _conj_physical(const Tensor& self) {
  return self.is_conj()
      ? self.conj().clone()
      : unary_op_impl(self, at::conj_physical_out);
}
|
||||
|
||||
// Out-of-place physical conjugation. Non-complex tensors are returned as-is;
// complex tensors are handed to at::_conj_physical.
Tensor conj_physical(const Tensor& self) {
  if (self.is_complex()) {
    return at::_conj_physical(self);
  }
  return self;
}
|
||||
|
||||
// In-place physical conjugation. Returns `self` untouched for non-complex
// dtypes; otherwise runs conj_physical_stub with `self` as both input and
// output.
Tensor& conj_physical_(Tensor& self) {
  if (self.is_complex()) {
    return unary_op_impl_out(self, self, conj_physical_stub);
  }
  return self;
}
|
||||
|
||||
// Returns `self` when its conj bit is unset; otherwise returns a clone made
// with self's suggested memory format.
Tensor resolve_conj(const Tensor& self) {
  if (self.is_conj()) {
    // The clone ends up in `copy_()`, which is where the conjugation itself
    // gets applied.
    return self.clone(self.suggest_memory_format());
  }
  return self;
}
|
||||
|
||||
// Lazy conjugation: builds an alias of `self` whose conj bit is the opposite
// of self's, copies the names across, and returns that alias.
Tensor _conj(const Tensor& self) {
  Tensor flipped = self.alias();
  const bool conj_bit_was_set = self.is_conj();
  flipped._set_conj(!conj_bit_was_set);
  namedinference::propagate_names(flipped, self);
  return flipped;
}
|
||||
|
||||
Tensor conj(const Tensor& self) {
|
||||
if (!self.is_complex()) {
|
||||
return self;
|
||||
}
|
||||
return at::_conj(self);
|
||||
// This might look like an infinite recursion but it's not.
|
||||
// This actually calls into `conj()` defined in the Tensor class.
|
||||
return self.conj();
|
||||
}
|
||||
|
||||
// special_exp2, alias for exp2
|
||||
|
|
@ -689,7 +718,7 @@ DEFINE_DISPATCH(abs_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-va
|
|||
DEFINE_DISPATCH(angle_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(real_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(imag_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(conj_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(conj_physical_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(acos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(acosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(asinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ DECLARE_DISPATCH(unary_fn, abs_stub);
|
|||
DECLARE_DISPATCH(unary_fn, angle_stub);
|
||||
DECLARE_DISPATCH(unary_fn, real_stub);
|
||||
DECLARE_DISPATCH(unary_fn, imag_stub);
|
||||
DECLARE_DISPATCH(unary_fn, conj_stub);
|
||||
DECLARE_DISPATCH(unary_fn, conj_physical_stub);
|
||||
DECLARE_DISPATCH(unary_fn, acos_stub);
|
||||
DECLARE_DISPATCH(unary_fn, acosh_stub);
|
||||
DECLARE_DISPATCH(unary_fn, asinh_stub);
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cpu/Loops.h>
|
||||
#include <c10/util/TypeCast.h>
|
||||
#include <ATen/native/cpu/zmath.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -13,6 +14,15 @@ namespace {
|
|||
static void copy_kernel(TensorIterator& iter, bool non_blocking) {
|
||||
ScalarType dtype = iter.dtype(0);
|
||||
if (dtype == iter.dtype(1)) {
|
||||
// TODO: as the majority of these operations can be done treating
|
||||
// their datatypes as opaque bit patterns, we don't actually need
|
||||
// separate instantiations per dtype; we only need a separate
|
||||
// instantiation per dtype size. This would probably save us a
|
||||
// little bit of code size here
|
||||
// TODO: not sure if optimizer is able to compile two levels of
|
||||
// conditionals into a single jump table. We should have a
|
||||
// single jump table here; might be worth just writing out the
|
||||
// dispatch statement by hand instead of using AT_DISPATCH
|
||||
if (dtype == ScalarType::Half) {
|
||||
cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
|
||||
} else if (dtype == ScalarType::ComplexHalf) {
|
||||
|
|
@ -25,12 +35,21 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
|
|||
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
|
||||
});
|
||||
} else if (isComplexType(dtype)) {
|
||||
if (iter.tensor(0).is_conj() == iter.tensor(1).is_conj()) {
|
||||
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
|
||||
cpu_kernel_vec(
|
||||
iter,
|
||||
[=](scalar_t a) -> scalar_t { return a; },
|
||||
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_COMPLEX_TYPES(dtype, "conj_kernel", [&] {
|
||||
cpu_kernel_vec(
|
||||
iter,
|
||||
[=](scalar_t a) -> scalar_t { return conj_impl(a); },
|
||||
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.conj(); });
|
||||
});
|
||||
}
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(
|
||||
ScalarType::Bool, ScalarType::BFloat16,dtype, "copy_kernel", [&] {
|
||||
|
|
@ -62,6 +81,9 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
|
|||
return c10::static_cast_with_inter_type<dest_t, scalar_t>::apply(src); });
|
||||
});
|
||||
});
|
||||
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
|
||||
iter.tensor(0).conj_physical_();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -193,6 +193,7 @@ static void imag_kernel(TensorIteratorBase& iter) {
|
|||
});
|
||||
}
|
||||
|
||||
// NB: Ignores the negative bit on tensors
|
||||
static void conj_kernel(TensorIteratorBase& iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
||||
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cpu", [&]() {
|
||||
|
|
@ -716,7 +717,7 @@ REGISTER_DISPATCH(real_stub, &CPU_CAPABILITY::real_kernel);
|
|||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
REGISTER_DISPATCH(imag_stub, &CPU_CAPABILITY::imag_kernel);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
REGISTER_DISPATCH(conj_stub, &CPU_CAPABILITY::conj_kernel);
|
||||
REGISTER_DISPATCH(conj_physical_stub, &CPU_CAPABILITY::conj_kernel);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
REGISTER_DISPATCH(exp2_stub, &CPU_CAPABILITY::exp2_kernel);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
|
|||
// We can memcpy the memory if both tensors have the same type AND both
|
||||
// tensors are contiguous after dimension coalescing and reordering.
|
||||
bool same_type = iter.dtype(0) == iter.dtype(1);
|
||||
bool memcpy_eligible = same_type && iter.is_contiguous();
|
||||
bool same_conj = iter.tensor(0).is_conj() == iter.tensor(1).is_conj();
|
||||
bool memcpy_eligible = same_type && same_conj && iter.is_contiguous();
|
||||
|
||||
Device dst_device = iter.device(0);
|
||||
Device src_device = iter.device(1);
|
||||
|
|
@ -70,6 +71,12 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
|
|||
AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
|
||||
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
|
||||
});
|
||||
} else {
|
||||
if (!same_conj && same_type) {
|
||||
AT_DISPATCH_COMPLEX_TYPES(
|
||||
dtype, "copy_conj_", [&] {
|
||||
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return std::conj(x); });
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
||||
kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
|
||||
|
|
@ -77,6 +84,7 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
|
|||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (src_device != dst_device) {
|
||||
// dst waits on src barrier (dst already waits on dst). We cannot
|
||||
|
|
@ -152,6 +160,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
|
|||
src_contig = iter.tensor(1).expand_as(dst).contiguous();
|
||||
}
|
||||
|
||||
// propagate the correct conjugate bit
|
||||
dst_contig._set_conj(dst.is_conj());
|
||||
src_contig._set_conj(iter.tensor(1).is_conj());
|
||||
|
||||
// perform a same-dtype copy on contiguous tensors
|
||||
TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
|
||||
TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
|
||||
|
|
@ -201,6 +213,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
|
|||
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
#endif
|
||||
}
|
||||
|
||||
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
|
||||
iter.tensor(0).conj_physical_();
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(copy_stub, ©_kernel_cuda);
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@ __host__ __device__ static inline c10::complex<T> conj_wrapper(c10::complex<T> v
|
|||
return std::conj(v);
|
||||
}
|
||||
|
||||
// NB: Ignores the negative bit on tensors
|
||||
void conj_kernel_cuda(TensorIteratorBase& iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
||||
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cuda", [&]() {
|
||||
|
|
@ -92,6 +93,6 @@ void conj_kernel_cuda(TensorIteratorBase& iter) {
|
|||
REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
|
||||
REGISTER_DISPATCH(real_stub, &real_kernel_cuda);
|
||||
REGISTER_DISPATCH(imag_stub, &imag_kernel_cuda);
|
||||
REGISTER_DISPATCH(conj_stub, &conj_kernel_cuda);
|
||||
REGISTER_DISPATCH(conj_physical_stub, &conj_kernel_cuda);
|
||||
|
||||
}} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -271,6 +271,11 @@
|
|||
dispatch:
|
||||
CPU, CUDA: view_as_real
|
||||
|
||||
- func: _view_as_real_physical(Tensor(a) self) -> Tensor(a)
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA: _view_as_real_physical
|
||||
|
||||
- func: view_as_complex(Tensor(a) self) -> Tensor(a)
|
||||
variants: function
|
||||
dispatch:
|
||||
|
|
@ -298,21 +303,36 @@
|
|||
device_check: NoCheck # TensorIterator
|
||||
variants: function
|
||||
|
||||
- func: conj(Tensor(a) self) -> Tensor(a)
|
||||
device_check: NoCheck # TensorIterator
|
||||
- func: _conj(Tensor(a) self) -> Tensor(a)
|
||||
variants: function, method
|
||||
|
||||
- func: conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
dispatch:
|
||||
CPU, CUDA: conj_out
|
||||
SparseCPU, SparseCUDA: conj_out_sparse
|
||||
|
||||
- func: _conj(Tensor self) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _conj
|
||||
|
||||
- func: conj(Tensor(a) self) -> Tensor(a)
|
||||
variants: function, method
|
||||
manual_cpp_binding: True
|
||||
|
||||
- func: _conj_physical(Tensor self) -> Tensor
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _conj_physical
|
||||
|
||||
- func: conj_physical(Tensor self) -> Tensor
|
||||
variants: function, method
|
||||
|
||||
- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: conj_physical_out
|
||||
SparseCPU, SparseCUDA: conj_physical_out_sparse
|
||||
|
||||
- func: conj_physical_(Tensor(a!) self) -> Tensor(a!)
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: conj_physical_
|
||||
|
||||
- func: resolve_conj(Tensor(a) self) -> Tensor(a)
|
||||
variants: function, method
|
||||
|
||||
- func: acos(Tensor self) -> Tensor
|
||||
device_check: NoCheck # TensorIterator
|
||||
variants: function, method
|
||||
|
|
@ -2249,6 +2269,11 @@
|
|||
device_guard: False
|
||||
manual_cpp_binding: True
|
||||
|
||||
- func: is_conj(Tensor self) -> bool
|
||||
variants: function, method
|
||||
device_guard: False
|
||||
manual_cpp_binding: True
|
||||
|
||||
- func: isreal(Tensor self) -> Tensor
|
||||
variants: function, method
|
||||
|
||||
|
|
|
|||
|
|
@ -1645,15 +1645,15 @@ Tensor& bmm_out_sparse_cpu(const SparseTensor& self, const Tensor& mat2, Tensor&
|
|||
return result;
|
||||
}
|
||||
|
||||
Tensor conj_sparse(const Tensor& input) {
|
||||
if (!input.is_complex()) {
|
||||
return input;
|
||||
}
|
||||
Tensor result = at::native::empty_like(input);
|
||||
return conj_out_sparse(input, result);
|
||||
}
|
||||
// Tensor conj_physical_sparse(const Tensor& input) {
|
||||
// if (!input.is_complex()) {
|
||||
// return input;
|
||||
// }
|
||||
// Tensor result = at::native::empty_like(input);
|
||||
// return conj_physical_out_sparse(input, result);
|
||||
// }
|
||||
|
||||
Tensor& conj_out_sparse(const Tensor& input, Tensor& result) {
|
||||
Tensor& conj_physical_out_sparse(const Tensor& input, Tensor& result) {
|
||||
TORCH_INTERNAL_ASSERT(input.is_sparse());
|
||||
if (input.numel() == 0) {
|
||||
return result;
|
||||
|
|
@ -1665,7 +1665,7 @@ Tensor& conj_out_sparse(const Tensor& input, Tensor& result) {
|
|||
return result;
|
||||
}
|
||||
Tensor result_values = result._values();
|
||||
at::conj_out(result_values, input._values());
|
||||
at::conj_physical_out(result_values, input._values());
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -216,4 +216,12 @@ inline bool is_inference(const Tensor& tensor) {
|
|||
return tensor.is_inference();
|
||||
}
|
||||
|
||||
// Free-function form of Tensor::is_conj(): returns whatever
// `tensor.is_conj()` reports.
inline bool is_conj(const Tensor& tensor) {
|
||||
return tensor.is_conj();
|
||||
}
|
||||
|
||||
// Free-function form of Tensor::conj(): returns whatever `tensor.conj()`
// returns.
inline Tensor conj(const Tensor& tensor) {
|
||||
return tensor.conj();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -136,6 +136,17 @@ class TORCH_API Tensor {
|
|||
}
|
||||
}
|
||||
|
||||
// Conjugate of this tensor.
//  - non-complex: returns *this
//  - complex and sparse: returns conj_physical()
//  - complex otherwise: returns _conj()
Tensor conj() const {
  if (!this->is_complex()) {
    return *this;
  }
  return this->is_sparse() ? this->conj_physical() : this->_conj();
}
|
||||
|
||||
/// Should be used if *this can reasonably be expected to be contiguous and
|
||||
/// performance is important.
|
||||
/// Compared to contiguous, it saves a reference count
|
||||
|
|
@ -363,6 +374,18 @@ class TORCH_API Tensor {
|
|||
return !at::impl::variable_excluded_from_dispatch();
|
||||
}
|
||||
|
||||
// Returns the conj state reported by the underlying TensorImpl.
inline bool is_conj() const {
|
||||
return impl_->is_conj();
|
||||
}
|
||||
|
||||
// Sets the conjugate bit of a tensor by forwarding to TensorImpl::_set_conj.
|
||||
// NOTE: The conjugate bit is supposed to be a read-only field. Only change it if you are extremely sure
|
||||
// that's what you want. Changing it might lead to incorrect behavior, since conjugation is
|
||||
// a lazy operation and we rely on this bit to determine whether a conjugation needs to be materialized.
|
||||
inline void _set_conj(bool conjugate) const {
|
||||
impl_->_set_conj(conjugate);
|
||||
}
|
||||
|
||||
/// Returns a `Tensor`'s layout.
|
||||
Layout layout() const noexcept {
|
||||
return impl_->layout();
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const char* toString(DispatchKey t) {
|
|||
case DispatchKey::PrivateUse3:
|
||||
return "PrivateUse3";
|
||||
|
||||
case DispatchKey::Conjugate:
|
||||
return "Conjugate";
|
||||
case DispatchKey::Meta:
|
||||
return "Meta";
|
||||
|
||||
|
|
|
|||
|
|
@ -142,6 +142,11 @@ enum class DispatchKey : uint8_t {
|
|||
// constituent parts.
|
||||
Named,
|
||||
|
||||
// The Conjugate dispatch key is set for any tensors that need to perform
|
||||
// conjugation
|
||||
// This is implemented at a dispatch level right before any backends run
|
||||
Conjugate,
|
||||
|
||||
// See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
|
||||
// is to insert code after the "autograd subsystem" runs, so this key should
|
||||
// be directly after ADInplaceOrView and all of the autograd keys.
|
||||
|
|
|
|||
|
|
@ -938,6 +938,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
*/
|
||||
const at::Tensor& grad() const;
|
||||
|
||||
/**
|
||||
* Whether or not the tensor is flagged as conjugated (i.e. its imaginary
* part should be negated when read). The flag is stored as the presence of
* DispatchKey::Conjugate in key_set_.
|
||||
*/
|
||||
inline bool is_conj() const {
|
||||
return key_set_.has(DispatchKey::Conjugate);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set whether or not to take the conjugate of the tensor (flip the imaginary
|
||||
* bit).
|
||||
*/
|
||||
void _set_conj(bool value) {
|
||||
if (value) {
|
||||
key_set_ = key_set_.add(DispatchKey::Conjugate);
|
||||
TORCH_INTERNAL_ASSERT(isComplexType(typeMetaToScalarType(dtype())));
|
||||
} else {
|
||||
key_set_ = key_set_.remove(DispatchKey::Conjugate);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the accumulated gradient of a tensor. This gradient is computed
|
||||
* using forward mode AD.
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ For reference, here’s a full list of view ops in PyTorch:
|
|||
- :attr:`~torch.Tensor.real`
|
||||
- :attr:`~torch.Tensor.imag`
|
||||
- :meth:`~torch.Tensor.view_as_real`
|
||||
- :meth:`~torch.Tensor.view_as_imag`
|
||||
- :meth:`~torch.Tensor.unflatten`
|
||||
- :meth:`~torch.Tensor.unfold`
|
||||
- :meth:`~torch.Tensor.unsqueeze`
|
||||
|
|
|
|||
|
|
@ -276,6 +276,9 @@ Tensor class reference
|
|||
Tensor.contiguous
|
||||
Tensor.copy_
|
||||
Tensor.conj
|
||||
Tensor.conj_physical
|
||||
Tensor.conj_physical_
|
||||
Tensor.resolve_conj
|
||||
Tensor.copysign
|
||||
Tensor.copysign_
|
||||
Tensor.cos
|
||||
|
|
@ -410,6 +413,7 @@ Tensor class reference
|
|||
Tensor.isnan
|
||||
Tensor.is_contiguous
|
||||
Tensor.is_complex
|
||||
Tensor.is_conj
|
||||
Tensor.is_floating_point
|
||||
Tensor.is_inference
|
||||
Tensor.is_leaf
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ Tensors
|
|||
is_tensor
|
||||
is_storage
|
||||
is_complex
|
||||
is_conj
|
||||
is_floating_point
|
||||
is_nonzero
|
||||
set_default_dtype
|
||||
|
|
@ -86,6 +87,7 @@ Indexing, Slicing, Joining, Mutating Ops
|
|||
:nosignatures:
|
||||
|
||||
cat
|
||||
conj
|
||||
chunk
|
||||
dsplit
|
||||
column_stack
|
||||
|
|
@ -293,7 +295,7 @@ Pointwise Ops
|
|||
ceil
|
||||
clamp
|
||||
clip
|
||||
conj
|
||||
conj_physical
|
||||
copysign
|
||||
cos
|
||||
cosh
|
||||
|
|
@ -511,6 +513,7 @@ Other Operations
|
|||
vander
|
||||
view_as_real
|
||||
view_as_complex
|
||||
resolve_conj
|
||||
|
||||
|
||||
BLAS and LAPACK Operations
|
||||
|
|
|
|||
|
|
@ -91,6 +91,9 @@ allow_list = [
|
|||
("aten::linalg_vector_norm", datetime.date(2021, 5, 15)),
|
||||
("aten::repeat_interleave", datetime.date(2021, 6, 26)),
|
||||
("aten::one_hot", datetime.date(2021, 6, 15)),
|
||||
("aten::conj", datetime.date(2021, 8, 1)),
|
||||
("aten::_conj", datetime.date(2021, 8, 1)),
|
||||
("aten::conj.out", datetime.date(2021, 8, 1)),
|
||||
]
|
||||
|
||||
def allow_listed(schema, allow_list):
|
||||
|
|
|
|||
|
|
@ -4279,7 +4279,7 @@ class TestAutograd(TestCase):
|
|||
gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
|
||||
|
||||
def basic_mul(x):
|
||||
return torch.view_as_real(x * 1j)
|
||||
return torch.view_as_real(torch.resolve_conj(x * 1j))
|
||||
gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)
|
||||
|
||||
# Test for one input and one output being complex
|
||||
|
|
@ -5459,7 +5459,7 @@ complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone'
|
|||
'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
|
||||
'chunk', 'split', 'split_with_sizes', 'zero_',
|
||||
'__radd__', 'mul', '__rmul__', 'diagonal', 'fill_', 'sub', 'narrow',
|
||||
'swapaxes', 'swapdims', 'tensor_split'] + separate_complex_tests
|
||||
'swapaxes', 'swapdims', 'tensor_split', 'select', 'clone'] + separate_complex_tests
|
||||
|
||||
# deny list for batched grad computation
|
||||
EXCLUDE_BATCHED_GRAD_TESTS = set([
|
||||
|
|
@ -5591,7 +5591,9 @@ def add_test(
|
|||
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
|
||||
if not isinstance(output_variable, tuple):
|
||||
output_variable = (output_variable,)
|
||||
|
||||
inplace_self_variable = deepcopy(self_variable)
|
||||
self.assertEqual(inplace_self_variable, self_variable)
|
||||
inplace_self_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
|
||||
for i in (inplace_self_variable,))
|
||||
inplace_args_variable = deepcopy(args_variable)
|
||||
|
|
|
|||
|
|
@ -1946,6 +1946,7 @@ known_failures = [
|
|||
|
||||
# If your OpInfo test causes this test to fail, add it here
|
||||
skip_ops = [
|
||||
'conj'
|
||||
]
|
||||
|
||||
def get_name(op):
|
||||
|
|
|
|||
109
test/test_ops.py
109
test/test_ops.py
|
|
@ -17,7 +17,7 @@ from torch.testing._internal.common_jit import JitCommonTestCase, check_against_
|
|||
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
|
||||
check_alias_annotation
|
||||
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
# Get names of all the operators which have entry in `method_tests` (legacy testing infra)
|
||||
method_tested_operators = set(map(lambda test_details: test_details[0], method_tests()))
|
||||
|
|
@ -110,7 +110,9 @@ class TestGradients(TestCase):
|
|||
return variant.__wrapped__ is op.get_inplace()
|
||||
return variant is op.get_inplace()
|
||||
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
||||
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs)
|
||||
|
||||
for sample in samples:
|
||||
if sample.broadcasts_input and is_inplace(variant):
|
||||
continue
|
||||
|
|
@ -280,7 +282,8 @@ class TestCommon(JitCommonTestCase):
|
|||
_requires_grad = (op.supports_autograd and
|
||||
(dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type)))
|
||||
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
||||
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
|
||||
|
||||
def _test_consistency_helper(samples, variants):
|
||||
for sample in samples:
|
||||
|
|
@ -381,7 +384,8 @@ class TestCommon(JitCommonTestCase):
|
|||
_requires_grad = op.supports_autograd and (dtype.is_floating_point or
|
||||
op.supports_complex_autograd(torch.device(device).type))
|
||||
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
||||
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
|
||||
|
||||
for sample in samples:
|
||||
# Acquires variants to test
|
||||
|
|
@ -734,6 +738,103 @@ class TestCommon(JitCommonTestCase):
|
|||
with self.assertRaises(RuntimeError, msg=msg_fail):
|
||||
op_out(out=out)
|
||||
|
||||
# Tests that
|
||||
# 1. The operator's output for physically conjugated tensors and conjugate view tensors
|
||||
# produces the same value
|
||||
# 2. The gradients are same in both cases mentioned in (1)
|
||||
# 3. If the operator's inplace variant is supported, tests that the inplace operation
|
||||
# produces the correct value when called on a conjugate view tensor and that the output
|
||||
# has its conj bit set to true
|
||||
# This test only runs for C -> R and C -> C functions
|
||||
# TODO: add tests for `R->C` functions
|
||||
# Note: This test runs for functions that take both tensors and tensorlists as input.
|
||||
@ops(op_db, allowed_dtypes=(torch.cfloat,))
|
||||
def test_conj_view(self, device, dtype, op):
|
||||
if not op.test_conjugated_samples:
|
||||
self.skipTest("Operation doesn't support conjugated inputs.")
|
||||
_requires_grad = (op.supports_autograd and op.supports_complex_autograd(torch.device(device).type))
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
||||
inplace_variant = op.inplace_variant
|
||||
|
||||
# helper function to physically conjugate the tensor
|
||||
def conjugate_physical(input):
|
||||
if isinstance(input, torch.Tensor):
|
||||
tensor_requires_grad = input.requires_grad
|
||||
with torch.no_grad():
|
||||
input = input.conj_physical()
|
||||
return input.requires_grad_(tensor_requires_grad)
|
||||
|
||||
if isinstance(input, Sequence):
|
||||
out = list(map(clone_input_helper, input))
|
||||
out[0] = conjugate_physical(out[0])
|
||||
return tuple(out)
|
||||
|
||||
# helper function to clone and conjugate the input if it's a tensor
|
||||
# else clone the sequence and conjugate the first element in the sequence
|
||||
# If a requires_grad argument is provided the tensor being conjugated will
|
||||
# have its requires_grad set to that value.
|
||||
def clone_conj_input_helper(input, **kwargs):
|
||||
if isinstance(input, torch.Tensor):
|
||||
requires_grad = kwargs.get('requires_grad', input.requires_grad)
|
||||
with torch.no_grad():
|
||||
input = input.clone()
|
||||
# Note: .conj() is not called under no_grad mode since it's not allowed to modify a
|
||||
# view created in no_grad mode. Here it's ok to do so, so as a workaround we call conj
|
||||
# before resetting the requires_grad field for input
|
||||
input = input.conj()
|
||||
assert input.is_leaf
|
||||
return input.requires_grad_(requires_grad)
|
||||
|
||||
if isinstance(input, Sequence):
|
||||
out = list(map(clone_input_helper, input))
|
||||
out[0] = clone_conj_input_helper(out[0])
|
||||
return tuple(out)
|
||||
|
||||
for sample in samples:
|
||||
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
|
||||
cloned1 = clone_conj_input_helper(sample.input)
|
||||
sample.input = conjugate_physical(sample.input)
|
||||
|
||||
# Computes function forward value with a physically conjugated tensor and
|
||||
# a conj view tensor and verifies that the output in both case are equal.
|
||||
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
|
||||
forward_with_conjview = op(cloned1, *sample.args, **sample.kwargs)
|
||||
self.assertEqual(expected_forward, forward_with_conjview)
|
||||
|
||||
# If the op has an inplace variant, and the input doesn't require broadcasting
|
||||
# and has the same dtype as output, verify that the inplace operation on a conjugated
|
||||
# input produces correct output, and the output tensor has the conj bit set to True
|
||||
if inplace_variant is not None and not sample.broadcasts_input:
|
||||
cloned2 = clone_conj_input_helper(tensor, requires_grad=False)
|
||||
if (isinstance(expected_forward, torch.Tensor) and
|
||||
expected_forward.dtype is tensor.dtype):
|
||||
inplace_forward = inplace_variant(cloned2, *sample.args, **sample.kwargs)
|
||||
self.assertTrue(inplace_forward.is_conj())
|
||||
self.assertEqual(inplace_forward, expected_forward)
|
||||
|
||||
# TODO: backward consistency only supported for single tensor outputs
|
||||
# TODO: backward consistency only checked on sample.input, not all
|
||||
# tensor inputs
|
||||
# TODO: update to handle checking grads of all tensor inputs as
|
||||
# derived from each tensor output
|
||||
if isinstance(expected_forward, torch.Tensor) and expected_forward.requires_grad:
|
||||
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
|
||||
expected_forward.sum().backward(retain_graph=True)
|
||||
forward_with_conjview.sum().backward(retain_graph=True)
|
||||
if tensor.grad is not None:
|
||||
cloned1_tensor = cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
|
||||
self.assertEqual(tensor.grad, cloned1_tensor.grad)
|
||||
|
||||
tensor.grad, cloned1_tensor.grad = None, None
|
||||
|
||||
# a repeat of the above test, with a random grad, if output is complex valued
|
||||
if (expected_forward.is_complex()):
|
||||
grad = torch.randn_like(expected_forward)
|
||||
expected_forward.backward(grad.conj_physical())
|
||||
forward_with_conjview.backward(grad.conj())
|
||||
|
||||
self.assertEqual(tensor.grad, cloned1_tensor.grad)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestOpInfo, globals())
|
||||
instantiate_device_type_tests(TestGradients, globals())
|
||||
|
|
|
|||
|
|
@ -346,6 +346,14 @@ class TestViewOps(TestCase):
|
|||
self.assertEqual(a[5:].real, a.real[5:])
|
||||
self.assertEqual(a[5:].imag, a.imag[5:])
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
@dtypes(*torch.testing.get_all_complex_dtypes())
|
||||
def test_conj_view(self, device, dtype) -> None:
|
||||
t = _make_tensor((4, 5,), dtype, device)
|
||||
v = t.conj()
|
||||
self.assertTrue(self.is_view_of(t, v))
|
||||
self.assertEqual(v, torch.from_numpy(t.cpu().numpy().conj()).to(device=device))
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
@dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
|
||||
@suppress_warnings
|
||||
|
|
|
|||
|
|
@ -380,16 +380,23 @@
|
|||
|
||||
- name: complex(Tensor real, Tensor imag) -> Tensor
|
||||
real: at::real(grad)
|
||||
imag: at::imag(grad)
|
||||
imag: at::imag(grad.resolve_conj())
|
||||
result: at::complex(real_t, imag_t)
|
||||
|
||||
- name: polar(Tensor abs, Tensor angle) -> Tensor
|
||||
abs, angle: polar_backward(grad, result)
|
||||
|
||||
- name: _conj(Tensor self) -> Tensor
|
||||
- name: _conj(Tensor(a) self) -> Tensor(a)
|
||||
self: grad.conj()
|
||||
result: self_t.conj()
|
||||
|
||||
- name: _conj_physical(Tensor self) -> Tensor
|
||||
self: grad.conj_physical()
|
||||
result: self_t.conj_physical()
|
||||
|
||||
- name: conj_physical_(Tensor(a!) self) -> Tensor(a!)
|
||||
self: grad.conj_physical()
|
||||
|
||||
- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
self: copysign_tensor_self_backward(grad, self, result)
|
||||
other: zeros_like(other)
|
||||
|
|
@ -1306,12 +1313,16 @@
|
|||
- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
|
||||
output_differentiability: [False]
|
||||
|
||||
# If the conj bit for self is set, this is effectively a fused conj() and view_as_real
|
||||
- name: _view_as_real_physical(Tensor(a) self) -> Tensor(a)
|
||||
self: "self.is_conj() ? at::view_as_complex(grad.contiguous()) : at::view_as_complex(grad.contiguous()).conj()"
|
||||
|
||||
- name: view_as_real(Tensor(a) self) -> Tensor(a)
|
||||
self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1
|
||||
result: at::view_as_real(self_t)
|
||||
|
||||
- name: view_as_complex(Tensor(a) self) -> Tensor(a)
|
||||
self: at::view_as_real(grad.contiguous()) # [gx, gy]
|
||||
self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy]
|
||||
result: at::view_as_complex(self_t)
|
||||
|
||||
- name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from .gen_trace_type import (
|
|||
#
|
||||
# A map: function name => name of the argument that all outputs are view of
|
||||
|
||||
VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_real', 'view_as_complex']
|
||||
VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_complex', '_view_as_real_physical', 'view_as_real', '_conj']
|
||||
|
||||
VIEW_FUNCTIONS = {
|
||||
'numpy_T': 'self',
|
||||
|
|
|
|||
|
|
@ -100,7 +100,8 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
|
|||
'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
|
||||
'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
|
||||
'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
|
||||
'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm',
|
||||
'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_view_as_real_physical', '_conj_physical',
|
||||
'conj_physical_'
|
||||
}
|
||||
|
||||
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import yaml
|
|||
from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
|
||||
SavedAttribute, ForwardDerivative)
|
||||
from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, BaseCType, VectorCType,
|
||||
intArrayRefT, tensorOptionsT, typeAndSizeT, intT,
|
||||
intArrayRefT, tensorOptionsT, typeAndSizeT, intT, boolT,
|
||||
tensorGeometryT, scalarTypeT, SpecialArgName,
|
||||
OptionalCType, stringT)
|
||||
from tools.codegen.api import cpp
|
||||
|
|
@ -498,6 +498,11 @@ def saved_variables(
|
|||
'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
|
||||
'expr': stride_expr,
|
||||
}),
|
||||
# replace self.is_conj() with self_conjugate
|
||||
(r'{}.is_conj\(\)', {
|
||||
'suffix': '_conjugate',
|
||||
'nctype': lambda name: NamedCType(name, BaseCType(boolT)),
|
||||
})
|
||||
]
|
||||
|
||||
# find which arguments need to be saved
|
||||
|
|
|
|||
|
|
@ -86,8 +86,10 @@ class Tensor(torch._C._TensorBase):
|
|||
self.requires_grad,
|
||||
self._backward_hooks)
|
||||
else:
|
||||
new_tensor = self.new()
|
||||
new_tensor = self.new_empty([])
|
||||
new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
|
||||
if self.is_conj():
|
||||
new_tensor = new_tensor.conj_physical()
|
||||
new_tensor.requires_grad = self.requires_grad
|
||||
if self.grad is not None:
|
||||
new_tensor.grad = self.grad.__deepcopy__(memo)
|
||||
|
|
|
|||
|
|
@ -893,6 +893,27 @@ conj() -> Tensor
|
|||
See :func:`torch.conj`
|
||||
""")
|
||||
|
||||
add_docstr_all('conj_physical',
|
||||
r"""
|
||||
conj_physical() -> Tensor
|
||||
|
||||
See :func:`torch.conj_physical`
|
||||
""")
|
||||
|
||||
add_docstr_all('conj_physical_',
|
||||
r"""
|
||||
conj_physical_() -> Tensor
|
||||
|
||||
In-place version of :meth:`~Tensor.conj_physical`
|
||||
""")
|
||||
|
||||
add_docstr_all('resolve_conj',
|
||||
r"""
|
||||
resolve_conj() -> Tensor
|
||||
|
||||
See :func:`torch.resolve_conj`
|
||||
""")
|
||||
|
||||
add_docstr_all('copysign',
|
||||
r"""
|
||||
copysign(other) -> Tensor
|
||||
|
|
@ -1977,6 +1998,13 @@ is_inference() -> bool
|
|||
See :func:`torch.is_inference`
|
||||
""")
|
||||
|
||||
add_docstr_all('is_conj',
|
||||
r"""
|
||||
is_conj() -> bool
|
||||
|
||||
Returns True if the conjugate bit of :attr:`self` is set to true.
|
||||
""")
|
||||
|
||||
add_docstr_all('is_signed',
|
||||
r"""
|
||||
is_signed() -> bool
|
||||
|
|
|
|||
|
|
@ -153,11 +153,12 @@ class _Formatter(object):
|
|||
def _scalar_str(self, formatter1, formatter2=None):
|
||||
if formatter2 is not None:
|
||||
real_str = _scalar_str(self.real, formatter1)
|
||||
imag_str = _scalar_str(self.imag, formatter2) + "j"
|
||||
if self.imag < 0:
|
||||
return real_str + imag_str.lstrip()
|
||||
imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
|
||||
# handles negative numbers, +0.0, -0.0
|
||||
if imag_str[0] == '+' or imag_str[0] == '-':
|
||||
return real_str + imag_str
|
||||
else:
|
||||
return real_str + "+" + imag_str.lstrip()
|
||||
return real_str + "+" + imag_str
|
||||
else:
|
||||
return formatter1.format(self.item())
|
||||
|
||||
|
|
@ -174,11 +175,12 @@ def _vector_str(self, indent, summarize, formatter1, formatter2=None):
|
|||
def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
|
||||
if formatter2 is not None:
|
||||
real_str = formatter1.format(val.real)
|
||||
imag_str = formatter2.format(val.imag) + "j"
|
||||
if val.imag < 0:
|
||||
return real_str + imag_str.lstrip()
|
||||
imag_str = (formatter2.format(val.imag) + "j").lstrip()
|
||||
# handles negative numbers, +0.0, -0.0
|
||||
if imag_str[0] == '+' or imag_str[0] == '-':
|
||||
return real_str + imag_str
|
||||
else:
|
||||
return real_str + "+" + imag_str.lstrip()
|
||||
return real_str + "+" + imag_str
|
||||
else:
|
||||
return formatter1.format(val)
|
||||
|
||||
|
|
@ -235,6 +237,7 @@ def _tensor_str(self, indent):
|
|||
self = self.float()
|
||||
|
||||
if self.dtype.is_complex:
|
||||
self = self.resolve_conj()
|
||||
real_formatter = _Formatter(get_summarized_data(self.real) if summarize else self.real)
|
||||
imag_formatter = _Formatter(get_summarized_data(self.imag) if summarize else self.imag)
|
||||
return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter)
|
||||
|
|
|
|||
|
|
@ -2181,15 +2181,18 @@ Example::
|
|||
tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128)
|
||||
""")
|
||||
|
||||
add_docstr(torch.conj,
|
||||
add_docstr(torch.conj_physical,
|
||||
r"""
|
||||
conj(input, *, out=None) -> Tensor
|
||||
conj_physical(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the element-wise conjugate of the given :attr:`input` tensor. If :attr:`input` has a non-complex dtype,
|
||||
this function just returns :attr:`input`.
|
||||
Computes the element-wise conjugate of the given :attr:`input` tensor.
|
||||
If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`.
|
||||
|
||||
.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
|
||||
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj`
|
||||
.. note::
|
||||
This performs the conjugate operation regardless of whether the conjugate bit is set or not.
|
||||
|
||||
.. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of
|
||||
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical`
|
||||
when :attr:`input` is of non-complex dtype to be compatible with this change.
|
||||
|
||||
.. math::
|
||||
|
|
@ -2203,10 +2206,63 @@ Keyword args:
|
|||
|
||||
Example::
|
||||
|
||||
>>> torch.conj(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))
|
||||
>>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))
|
||||
tensor([-1 - 1j, -2 - 2j, 3 + 3j])
|
||||
""".format(**common_args))
|
||||
|
||||
# The docstring is passed through .format(**common_args) so that the `{input}`
# placeholder under "Args:" is expanded instead of being shown literally.
add_docstr(torch.conj,
           r"""
conj(input) -> Tensor

Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype,
this function just returns :attr:`input`.

.. note::
    :func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized
    at any time using :func:`torch.resolve_conj`.

.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
             non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj`
             when :attr:`input` is of non-complex dtype to be compatible with this change.

Args:
    {input}

Example::

    >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
    >>> x.is_conj()
    False
    >>> y = torch.conj(x)
    >>> y.is_conj()
    True

""".format(**common_args))
|
||||
|
||||
# The docstring is passed through .format(**common_args) so that the `{input}`
# placeholder under "Args:" is expanded instead of being shown literally.
add_docstr(torch.resolve_conj,
           r"""
resolve_conj(input) -> Tensor

Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`,
else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`.

Args:
    {input}

Example::

    >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
    >>> y = x.conj()
    >>> y.is_conj()
    True
    >>> z = y.resolve_conj()
    >>> z
    tensor([-1 - 1j, -2 - 2j, 3 + 3j])
    >>> z.is_conj()
    False

""".format(**common_args))
|
||||
|
||||
add_docstr(torch.copysign,
|
||||
r"""
|
||||
copysign(input, other, *, out=None) -> Tensor
|
||||
|
|
@ -4201,6 +4257,15 @@ Args:
|
|||
{input}
|
||||
""".format(**common_args))
|
||||
|
||||
add_docstr(torch.is_conj, r"""
|
||||
is_conj(input) -> (bool)
|
||||
|
||||
Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`.
|
||||
|
||||
Args:
|
||||
{input}
|
||||
""".format(**common_args))
|
||||
|
||||
add_docstr(torch.is_nonzero, r"""
|
||||
is_nonzero(input) -> (bool)
|
||||
|
||||
|
|
|
|||
|
|
@ -545,7 +545,7 @@ def _get_analytical_vJu_backward_mode(inputs, outputs, nondet_tol, check_grad_dt
|
|||
# the error checking logic from slow mode
|
||||
vJ = vJ.T.squeeze(0)
|
||||
if vJ.is_complex(): # C -> R
|
||||
tv = torch.view_as_real(vJ)
|
||||
tv = torch.view_as_real(vJ.resolve_conj())
|
||||
tr = tv.select(-1, 0)
|
||||
ti = tv.select(-1, 1)
|
||||
jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
|
||||
|
|
@ -894,7 +894,11 @@ def _real_and_imag_output(fn):
|
|||
outs = _as_tuple(fn(*inputs))
|
||||
return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)
|
||||
return wrapped_fn
|
||||
return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)
|
||||
|
||||
# TODO(@anjali411): remove this workaround once neg bit is added.
|
||||
def torch_imag(x):
|
||||
return x.resolve_conj().imag
|
||||
return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch_imag)
|
||||
|
||||
def _real_and_imag_input(fn, complex_inp_indices):
|
||||
# returns new functions that take real inputs instead of complex inputs and compute fn(x + 0 * 1j)
|
||||
|
|
@ -935,7 +939,7 @@ def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, e
|
|||
if complex_inp_indices:
|
||||
real_fn, imag_fn = _real_and_imag_input(func, complex_inp_indices)
|
||||
|
||||
imag_inputs = [inp.imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs]
|
||||
imag_inputs = [inp.resolve_conj().imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs]
|
||||
imag_func_out = imag_fn(*imag_inputs)
|
||||
diff_imag_func_out = _differentiable_outputs(imag_func_out)
|
||||
gradcheck_fn(imag_fn, imag_func_out, imag_inputs, diff_imag_func_out, eps,
|
||||
|
|
|
|||
|
|
@ -248,7 +248,7 @@ Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & expone
|
|||
Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) {
|
||||
Tensor cond;
|
||||
if (exponent.is_complex()) {
|
||||
auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
|
||||
auto is_real_exp = at::logical_and(at::imag(exponent.resolve_conj()) == 0, at::real(exponent) >= 0);
|
||||
cond = at::logical_and(self == 0, is_real_exp);
|
||||
} else {
|
||||
cond = at::logical_and(self == 0, exponent >= 0);
|
||||
|
|
@ -264,7 +264,7 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp
|
|||
if (base.equal(0.0)) {
|
||||
auto cond = [](auto exp) {
|
||||
if (exp.is_complex()) {
|
||||
return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
|
||||
return at::logical_and(at::imag(exp.resolve_conj()) == 0, at::real(exp) >= 0);
|
||||
} else {
|
||||
return exp >=0;
|
||||
}
|
||||
|
|
@ -2154,7 +2154,7 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
|
|||
if (self.is_complex() && gu.defined()) {
|
||||
Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1);
|
||||
at::real(L).zero_();
|
||||
at::imag(L).mul_(sigma_inv);
|
||||
at::imag(L.resolve_conj()).mul_(sigma_inv);
|
||||
Tensor imag_term = at::matmul(u * L.unsqueeze(-2), vh);
|
||||
return u_term + sigma_term + v_term + imag_term;
|
||||
}
|
||||
|
|
@ -2195,7 +2195,7 @@ Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const T
|
|||
}
|
||||
// path for torch.linalg.eig with always a complex tensor of eigenvalues
|
||||
else {
|
||||
is_imag_eigvals_zero = (at::imag(D) == 0.0).min().item<bool>();
|
||||
is_imag_eigvals_zero = (at::imag(D.resolve_conj()) == 0.0).min().item<bool>();
|
||||
// insert an additional dimension to be compatible with torch.eig.
|
||||
// Recall that it produces 2D tensors.
|
||||
// We extract only the real parts as there is no support for
|
||||
|
|
|
|||
|
|
@ -83,23 +83,6 @@ void nnc_aten_sgn(
|
|||
} catch (...) {
|
||||
}
|
||||
}
|
||||
void nnc_aten_conj(
|
||||
int64_t bufs_num,
|
||||
void** buf_data,
|
||||
int64_t* buf_ranks,
|
||||
int64_t* buf_dims,
|
||||
int8_t* buf_dtypes,
|
||||
int64_t args_num,
|
||||
int64_t* extra_args) {
|
||||
std::vector<at::Tensor> tensors =
|
||||
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
||||
at::Tensor& r = tensors[0];
|
||||
const at::Tensor& self = tensors[1];
|
||||
try {
|
||||
at::conj_out(r, self);
|
||||
} catch (...) {
|
||||
}
|
||||
}
|
||||
void nnc_aten_acos(
|
||||
int64_t bufs_num,
|
||||
void** buf_data,
|
||||
|
|
@ -2758,9 +2741,6 @@ const static RegisterNNCExternalFunction nnc_angle(
|
|||
"nnc_aten_angle",
|
||||
nnc_aten_angle);
|
||||
const static RegisterNNCExternalFunction nnc_sgn("nnc_aten_sgn", nnc_aten_sgn);
|
||||
const static RegisterNNCExternalFunction nnc_conj(
|
||||
"nnc_aten_conj",
|
||||
nnc_aten_conj);
|
||||
const static RegisterNNCExternalFunction nnc_acos(
|
||||
"nnc_aten_acos",
|
||||
nnc_aten_acos);
|
||||
|
|
|
|||
|
|
@ -228,6 +228,8 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
Tensor.to_sparse_csr,
|
||||
Tensor._reduce_ex_internal,
|
||||
Tensor._fix_weakref,
|
||||
Tensor._conj,
|
||||
Tensor._conj_physical,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -349,6 +351,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.polar: lambda abs, ang: -1,
|
||||
torch.linalg.cond: lambda input, ord=None: -1,
|
||||
torch.conj: lambda input, out=None: -1,
|
||||
torch.conj_physical: lambda input, out=None: -1,
|
||||
torch.resolve_conj: lambda input, out=None: -1,
|
||||
torch.constant_pad_nd: lambda input, pad, value=0: -1,
|
||||
torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
|
||||
torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
|
||||
|
|
@ -500,6 +504,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.linalg.inv: lambda input, out=None: -1,
|
||||
torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
|
||||
torch.is_complex: lambda input: -1,
|
||||
torch.is_conj: lambda input: -1,
|
||||
torch.is_distributed: lambda input: -1,
|
||||
torch.is_inference: lambda input: -1,
|
||||
torch.is_floating_point: lambda input: -1,
|
||||
|
|
@ -797,6 +802,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.real: lambda input, out=None: -1,
|
||||
torch.vdot: lambda input, other, out=None: -1,
|
||||
torch.view_as_real: lambda input: -1,
|
||||
torch._view_as_real_physical: lambda input: -1,
|
||||
torch.view_as_complex: lambda input: -1,
|
||||
torch.reciprocal: lambda input, out=None: -1,
|
||||
torch.relu: lambda input, inplace=False: -1,
|
||||
|
|
|
|||
|
|
@ -134,6 +134,8 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e
|
|||
# Compares complex tensors' real and imaginary parts separately.
|
||||
# (see NOTE Test Framework Tensor "Equality")
|
||||
if a.is_complex():
|
||||
a = a.resolve_conj()
|
||||
b = b.resolve_conj()
|
||||
if equal_nan == "relaxed":
|
||||
a = a.clone()
|
||||
b = b.clone()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import random
|
|||
import torch
|
||||
import numpy as np
|
||||
from torch._six import inf
|
||||
from torch.autograd import Variable
|
||||
import collections.abc
|
||||
|
||||
from typing import List, Sequence, Tuple, Dict, Any, Union
|
||||
|
|
@ -231,6 +230,7 @@ class OpInfo(object):
|
|||
# function around gradcheck (testing._internal.common_utils.gradcheck)
|
||||
inplace_variant=_NOTHING, # explicitly pass the inplace variant of the operator if required
|
||||
method_variant=_NOTHING, # explicitly pass the method variant of the operator if required
|
||||
test_conjugated_samples=True,
|
||||
):
|
||||
|
||||
# Validates the dtypes are generated from the dispatch-related functions
|
||||
|
|
@ -300,6 +300,8 @@ class OpInfo(object):
|
|||
if aliases is not None:
|
||||
self.aliases = tuple(AliasInfo(a) for a in aliases) # type: ignore[assignment]
|
||||
|
||||
self.test_conjugated_samples = test_conjugated_samples
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calls the function variant of the operator."""
|
||||
return self.op(*args, **kwargs)
|
||||
|
|
@ -326,6 +328,37 @@ class OpInfo(object):
|
|||
"""
|
||||
return self.operator_variant
|
||||
|
||||
def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
    """Returns a tuple of SampleInputs but with the tensor input or first
    tensor in a sequence input conjugated.

    The samples are produced by a fresh call to ``self.sample_inputs_func``
    and are modified in place: ``sample.input`` (or its first element when
    the input is a sequence) is replaced by ``tensor.conj()``. The
    ``requires_grad`` flag of the replaced tensor is carried over to the
    conjugated one.

    Args:
        device: forwarded unchanged to ``self.sample_inputs_func``.
        dtype: forwarded unchanged to ``self.sample_inputs_func``.
        requires_grad: forwarded unchanged to ``self.sample_inputs_func``.
        **kwargs: forwarded to ``self.sample_inputs_func`` when it accepts them.

    Returns:
        tuple: the generated samples, after conjugation.
    """

    # TODO: Remove the try/except once all operators have sample_inputs_func with
    #       **kwargs in their signature.
    try:
        samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
    except TypeError:
        samples = self.sample_inputs_func(self, device, dtype, requires_grad)

    # Materialize exactly once; everything below works on this list so that
    # sample functions returning a generator (which has no len()) also work.
    conj_samples = list(samples)

    def conjugate(tensor):
        # Take the conjugate without recording it in the autograd graph, then
        # restore the original requires_grad flag on the result.
        _requires_grad = tensor.requires_grad
        with torch.no_grad():
            tensor = tensor.conj()
        return tensor.requires_grad_(_requires_grad)

    for sample in conj_samples:
        # Note: it is assumed that the input here is either a tensor or tensorlist
        if isinstance(sample.input, torch.Tensor):
            sample.input = conjugate(sample.input)
        elif isinstance(sample.input, tuple):
            # Tuples do not support item assignment, so build a new one.
            if len(sample.input) > 0:
                sample.input = (conjugate(sample.input[0]),) + tuple(sample.input[1:])
        elif len(sample.input) > 0:
            # conjugate() already runs under no_grad; no extra guard is needed.
            sample.input[0] = conjugate(sample.input[0])

    return tuple(conj_samples)
|
||||
|
||||
def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
|
||||
"""Returns an iterable of SampleInputs.
|
||||
|
||||
|
|
@ -339,6 +372,13 @@ class OpInfo(object):
|
|||
samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
|
||||
except TypeError:
|
||||
samples = self.sample_inputs_func(self, device, dtype, requires_grad)
|
||||
|
||||
if 'include_conjugated_inputs' in kwargs and kwargs.get('include_conjugated_inputs'):
|
||||
conj_samples = self.conjugate_sample_inputs(device, dtype, requires_grad, **kwargs)
|
||||
samples_list = list(samples)
|
||||
samples_list.extend(conj_samples)
|
||||
samples = tuple(samples_list)
|
||||
|
||||
return samples
|
||||
|
||||
# Returns True if the test should be skipped and False otherwise
|
||||
|
|
@ -1001,21 +1041,21 @@ def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs):
|
|||
return tuple(sample_inputs)
|
||||
|
||||
def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
|
||||
test_cases = [((S, S), (S, S), (S, S)),
|
||||
((S, S), (S, 1), (1, S)),
|
||||
((1,), (S, S, 1), (1, S)),
|
||||
((), (), ()),
|
||||
((S, S), (), ()),
|
||||
((), (S, S, 1), (1, S)),
|
||||
test_cases = [(((S, S), (S, S), (S, S)), False),
|
||||
(((S, S), (S, 1), (1, S)), False),
|
||||
(((1,), (S, S, 1), (1, S)), True),
|
||||
(((), (), ()), False),
|
||||
(((S, S), (), ()), True),
|
||||
(((), (S, S, 1), (1, S)), True)
|
||||
]
|
||||
|
||||
sample_inputs = []
|
||||
for input_args in test_cases:
|
||||
for input_args, broadcasts_input in test_cases:
|
||||
args = tuple(make_tensor(arg, device, dtype, requires_grad=requires_grad) if isinstance(arg, tuple) else arg
|
||||
for arg in input_args)
|
||||
sample_inputs.append(SampleInput(args[0], args=args[1:]))
|
||||
sample_inputs.append(SampleInput(args[0], args=args[1:], broadcasts_input=broadcasts_input))
|
||||
|
||||
sample_inputs.append(SampleInput(args[0], args=args[1:], kwargs=dict(value=3.14)))
|
||||
sample_inputs.append(SampleInput(args[0], args=args[1:], kwargs=dict(value=3.14), broadcasts_input=broadcasts_input))
|
||||
|
||||
return tuple(sample_inputs)
|
||||
|
||||
|
|
@ -1057,7 +1097,8 @@ def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
|
|||
make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
|
||||
args=(
|
||||
make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
|
||||
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)))
|
||||
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
|
||||
broadcasts_input=True)
|
||||
|
||||
if dtype.is_complex:
|
||||
alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j
|
||||
|
|
@ -1078,7 +1119,8 @@ def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
|
|||
args=(
|
||||
make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
|
||||
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
|
||||
kwargs=dict(beta=beta, alpha=alpha))
|
||||
kwargs=dict(beta=beta, alpha=alpha),
|
||||
broadcasts_input=True)
|
||||
|
||||
return (input1, input2, input3, input4)
|
||||
|
||||
|
|
@ -2196,36 +2238,7 @@ def np_unary_ufunc_integer_promotion_wrapper(fn):
|
|||
|
||||
return wrapped_fn
|
||||
|
||||
|
||||
# Metadata class for Fast Fourier Transforms in torch.fft.
|
||||
class SpectralFuncInfo(OpInfo):
|
||||
"""Operator information for torch.fft transforms. """
|
||||
|
||||
def __init__(self,
|
||||
name, # the string name of the function
|
||||
*,
|
||||
ref=None, # Reference implementation (probably in np.fft namespace)
|
||||
dtypes=floating_and_complex_types(),
|
||||
ndimensional: bool, # Whether dim argument can be a tuple
|
||||
decorators=None,
|
||||
**kwargs):
|
||||
decorators = list(decorators) if decorators is not None else []
|
||||
decorators += [
|
||||
skipCPUIfNoMkl,
|
||||
skipCUDAIfRocm,
|
||||
# gradgrad is quite slow
|
||||
DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'),
|
||||
]
|
||||
|
||||
super().__init__(name=name,
|
||||
dtypes=dtypes,
|
||||
decorators=decorators,
|
||||
**kwargs)
|
||||
self.ref = ref if ref is not None else _getattr_qual(np, name)
|
||||
self.ndimensional = ndimensional
|
||||
|
||||
|
||||
def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
|
||||
def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
|
||||
nd_tensor = make_tensor((S, S + 1, S + 2), device, dtype, low=None, high=None,
|
||||
requires_grad=requires_grad)
|
||||
tensor = make_tensor((31,), device, dtype, low=None, high=None,
|
||||
|
|
@ -2252,6 +2265,35 @@ class SpectralFuncInfo(OpInfo):
|
|||
for dim in [-1, -2, -3]),
|
||||
]
|
||||
|
||||
# Metadata class for Fast Fourier Transforms in torch.fft.
class SpectralFuncInfo(OpInfo):
    """OpInfo specialization describing a torch.fft transform.

    On top of what OpInfo stores, an instance records:
      * ``ref`` - the reference implementation; when none is given it is
        looked up on ``np`` from ``name`` via ``_getattr_qual``.
      * ``ndimensional`` - whether the ``dim`` argument can be a tuple.
    """

    def __init__(self,
                 name,  # the string name of the function
                 *,
                 ref=None,  # Reference implementation (probably in np.fft namespace)
                 dtypes=floating_and_complex_types(),
                 ndimensional: bool,  # Whether dim argument can be a tuple
                 sample_inputs_func=sample_inputs_spectral_ops,
                 decorators=None,
                 **kwargs):
        # Work on a private copy so the caller's decorator sequence is never mutated.
        all_decorators = [] if decorators is None else list(decorators)
        # Decorators applied to every spectral op, appended after the caller's own.
        all_decorators.extend((
            skipCPUIfNoMkl,
            skipCUDAIfRocm,
            # gradgrad is quite slow
            DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'),
        ))

        super().__init__(name=name,
                         dtypes=dtypes,
                         decorators=all_decorators,
                         sample_inputs_func=sample_inputs_func,
                         **kwargs)

        if ref is None:
            # No explicit reference: resolve one from numpy using the op's name.
            ref = _getattr_qual(np, name)
        self.ref = ref
        self.ndimensional = ndimensional
|
||||
|
||||
|
||||
class ShapeFuncInfo(OpInfo):
|
||||
"""Early version of a specialized OpInfo for Shape manipulating operations like tile and roll"""
|
||||
|
|
@ -2737,13 +2779,13 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
|
|||
test_cases = (
|
||||
((2, 2), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, False),
|
||||
((2, 2), 0, 5, 1e-3, requires_grad, (1,), 0, 1, 0.1, requires_grad, False),
|
||||
((), 1e-3, 1e-3 + 1, 0, True, (), 0.1, 1.1, 0, False, False),
|
||||
((), 1e-3, 1e-3 + 1, 0, requires_grad, (), 0.1, 1.1, 0, False, False),
|
||||
((2, 2), 0, 5, 1e-3, requires_grad, (), 0.1, 1.1, 1, False, False),
|
||||
)
|
||||
tests_require_resizing = (
|
||||
((1,), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, True),
|
||||
((2, 1, 2), 0, 5, 1e-3, requires_grad, (1, 2, 1), 0, 1, 0.1, requires_grad, True),
|
||||
((), 1e-3, 1e-3 + 1, 0, True, (1, S, 1), 0, 1, 0.1, requires_grad, True),
|
||||
((1,), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, requires_grad),
|
||||
((2, 1, 2), 0, 5, 1e-3, requires_grad, (1, 2, 1), 0, 1, 0.1, requires_grad, requires_grad),
|
||||
((), 1e-3, 1e-3 + 1, 0, requires_grad, (1, S, 1), 0, 1, 0.1, requires_grad, requires_grad),
|
||||
)
|
||||
cases = test_cases + tests_require_resizing
|
||||
samples = list(SampleInput(make_tensor(shape_b, low=low_b, high=high_b,
|
||||
|
|
@ -2757,7 +2799,7 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
|
|||
high_e, additive_e, e_grad, broadcasts_input in cases)
|
||||
tensor_scalar_inputs = (
|
||||
((2, 2), 0, 5, 1e-3, requires_grad, (3.14,)),
|
||||
((), 1e-3, 1e-3 + 1, 0, True, (3.14,))
|
||||
((), 1e-3, 1e-3 + 1, 0, requires_grad, (3.14,))
|
||||
)
|
||||
more_samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device,
|
||||
high=high, low=low,
|
||||
|
|
@ -2768,8 +2810,8 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
|
|||
elif dtype in [torch.complex64, torch.complex128]:
|
||||
args_tuple = (
|
||||
((2, 2), 0, 5, requires_grad, (3.14,)),
|
||||
((), 0, 1, True, (3.14,)),
|
||||
((), 0, 1, True, (3.14j,))
|
||||
((), 0, 1, requires_grad, (3.14,)),
|
||||
((), 0, 1, requires_grad, (3.14j,))
|
||||
)
|
||||
samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device,
|
||||
high=high, low=low,
|
||||
|
|
@ -4298,7 +4340,7 @@ op_db: List[OpInfo] = [
|
|||
SkipInfo('TestGradients', 'test_method_grad',
|
||||
device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
|
||||
SkipInfo('TestGradients', 'test_forward_mode_AD',
|
||||
dtypes=[torch.cdouble], active_if=IS_WINDOWS),
|
||||
dtypes=[torch.cdouble]),
|
||||
)),
|
||||
OpInfo('add',
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
|
||||
|
|
@ -4670,28 +4712,30 @@ op_db: List[OpInfo] = [
|
|||
supports_forward_ad=True,
|
||||
),
|
||||
UnaryUfuncInfo('conj',
|
||||
ref=np.conj,
|
||||
dtypes=all_types_and_complex_and(torch.bool,
|
||||
torch.bfloat16, torch.half),
|
||||
supports_forward_ad=True,
|
||||
supports_out=False),
|
||||
UnaryUfuncInfo('conj_physical',
|
||||
ref=np.conj,
|
||||
dtypes=all_types_and_complex_and(torch.bool,
|
||||
torch.bfloat16, torch.half),
|
||||
supports_forward_ad=True,
|
||||
skips=(
|
||||
# File "test_unary_ufuncs.py", line 289, in test_reference_numerics
|
||||
# if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype):
|
||||
# KeyError: <class 'numpy.intc'>
|
||||
# Following error in Windows CI
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
|
||||
dtypes=[torch.int],
|
||||
active_if=IS_WINDOWS),
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
|
||||
dtypes=[torch.int],
|
||||
active_if=IS_WINDOWS),
|
||||
# TODO fix the formula for complex forward AD
|
||||
SkipInfo('TestGradients', 'test_forward_mode_AD'),
|
||||
SkipInfo('TestCommon', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
|
||||
)),
|
||||
OpInfo('resolve_conj',
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
|
||||
sample_inputs_func=sample_inputs_view_as_real,
|
||||
supports_forward_ad=True,
|
||||
supports_out=False,
|
||||
),
|
||||
OpInfo('view_as_real',
|
||||
dtypes=complex_types(),
|
||||
supports_forward_ad=True,
|
||||
sample_inputs_func=sample_inputs_view_as_real,
|
||||
test_conjugated_samples=False,
|
||||
),
|
||||
OpInfo('view_as_complex',
|
||||
dtypes=floating_types_and(torch.half),
|
||||
|
|
@ -5120,7 +5164,9 @@ op_db: List[OpInfo] = [
|
|||
ref=np.imag,
|
||||
dtypes=complex_types(),
|
||||
supports_out=False,
|
||||
supports_autograd=False,
|
||||
supports_forward_ad=True,
|
||||
# TODO(@anjali411): Test this once neg bit is added.
|
||||
test_conjugated_samples=False,
|
||||
skips=(
|
||||
# Skip since real and imag don't have out variants.
|
||||
SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
|
||||
|
|
@ -5251,8 +5297,6 @@ op_db: List[OpInfo] = [
|
|||
supports_autograd=False,
|
||||
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
|
||||
skips=(
|
||||
# skip because `linalg_lstsq` is not differentiable
|
||||
SkipInfo('TestGradients', 'test_fn_grad'),
|
||||
SkipInfo('TestCommon', 'test_variant_consistency_jit'),
|
||||
)),
|
||||
OpInfo('linalg.matrix_power',
|
||||
|
|
@ -5471,7 +5515,8 @@ op_db: List[OpInfo] = [
|
|||
# "RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when
|
||||
# calling cublasGemmStridedBatchedExFix."
|
||||
SkipInfo('TestOpInfo', 'test_supported_backward',
|
||||
device_type='cuda', dtypes=(torch.bfloat16,)),)),
|
||||
device_type='cuda', dtypes=(torch.bfloat16,)),
|
||||
SkipInfo('TestCommon', 'test_conj_view', device_type='cpu'),),),
|
||||
OpInfo('max',
|
||||
op=torch.max,
|
||||
variant_test_name='binary',
|
||||
|
|
@ -5699,7 +5744,9 @@ op_db: List[OpInfo] = [
|
|||
assert_autodiffed=True),
|
||||
OpInfo('float_power',
|
||||
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
|
||||
sample_inputs_func=sample_inputs_pow),
|
||||
sample_inputs_func=sample_inputs_pow,
|
||||
skips=(
|
||||
SkipInfo('TestCommon', 'test_conj_view', device_type='cuda'),),),
|
||||
OpInfo('prod',
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
|
|
@ -5739,7 +5786,6 @@ op_db: List[OpInfo] = [
|
|||
ref=np.real,
|
||||
dtypes=complex_types(),
|
||||
supports_out=False,
|
||||
supports_autograd=False,
|
||||
skips=(
|
||||
# Skip since real and imag don't have out variants.
|
||||
SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
|
||||
|
|
@ -7068,23 +7114,26 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
|
|||
def maybe_non_contig(tensor):
|
||||
return tensor if not non_contiguous else make_non_contiguous(tensor)
|
||||
|
||||
def conjugate(tensor):
|
||||
return tensor.conj()
|
||||
|
||||
if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
|
||||
return arg
|
||||
elif isinstance(arg, tuple) and len(arg) == 0:
|
||||
var = torch.randn((), dtype=dtype, device=device)
|
||||
var = conjugate(torch.randn((), dtype=dtype, device=device))
|
||||
var.requires_grad = requires_grad
|
||||
return var
|
||||
elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
|
||||
return Variable(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device)), requires_grad=requires_grad)
|
||||
return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)
|
||||
# double check casting
|
||||
elif isinstance(arg, non_differentiable):
|
||||
if isinstance(arg.tensor, torch.Tensor):
|
||||
if arg.tensor.dtype == torch.float:
|
||||
return maybe_non_contig(arg.tensor.to(dtype=torch.double, device=device))
|
||||
if arg.tensor.dtype == torch.cfloat:
|
||||
return maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device))
|
||||
return maybe_non_contig(arg.tensor.to(device=device))
|
||||
return maybe_non_contig(arg.tensor.to(device=device))
|
||||
return conjugate(maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device)))
|
||||
return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
|
||||
return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
|
||||
elif isinstance(arg, torch.Tensor):
|
||||
if arg.dtype == torch.float:
|
||||
arg = arg.double()
|
||||
|
|
@ -7094,7 +7143,7 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
|
|||
raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ",
|
||||
"which is not supported for now")
|
||||
# NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
|
||||
v = maybe_non_contig(arg).detach().to(device=device).clone()
|
||||
v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
|
||||
v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
|
||||
return v
|
||||
elif callable(arg):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user