mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Conjugate View (#54987)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54987 Based off of ezyang (https://github.com/pytorch/pytorch/pull/44799) and bdhirsh (https://github.com/pytorch/pytorch/pull/43702)'s prototype: Here's a summary of the changes in this PR: This PR adds a new dispatch key called Conjugate. This enables us to make the conjugate operation a view and leverage the specialized library functions that fast path with the hermitian operation (conj + transpose). 1. Conjugate operation will now return a view with conj bit (1) for complex tensors and return self for non-complex tensors as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical` which would return the real tensor agnostic of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor. 2. NEW API: a) `.conj()` -- now returning a view. b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated value in its memory. c) `.conj_physical_()`, and `out=` variant d) `.resolve_conj()` -- materializes the conjugation. Returns self if the conj bit is unset, else returns a new tensor with conjugated values and conj bit set to 0. e) `.resolve_conj_()` in-place version of (d) f) `view_as_real_physical` -- as described in (1), it's functionally the same as `view_as_real`, just that it doesn't error out on conjugated tensors. g) `view_as_real` -- existing function, but now errors out on conjugated tensors. 3. Conjugate Fallback a) Vast majority of PyTorch functions would currently use this fallback when they are called on a conjugated tensor. 
b) This fallback is well equipped to handle the following cases: - functional operation e.g., `torch.sin(input)` - Mutable inputs and in-place operations e.g., `tensor.add_(2)` - `out=` operation e.g., `torch.sin(input, out=out)` - Tensorlist input args - NOTE: Meta tensors don't work with conjugate fallback. 4. Autograd a) `resolve_conj()` is an identity function w.r.t. autograd b) Everything else works as expected. 5. Testing: a) All method_tests run with conjugate view tensors. b) OpInfo tests that run with conjugate views - test_variant_consistency_eager/jit - gradcheck, gradgradcheck - test_conj_views (that only runs for `torch.cfloat` dtype) NOTE: functions like `empty_like`, `zeros_like`, `randn_like`, `clone` don't propagate the conjugate bit. Follow up work: 1. conjugate view RFC 2. Add neg bit to re-enable view operation on conjugated tensors 3. Update linalg functions to call into specialized functions that fast path with the hermitian operation. Test Plan: Imported from OSS Reviewed By: VitalyFedyunin Differential Revision: D28227315 Pulled By: anjali411 fbshipit-source-id: acab9402b9d6a970c6d512809b627a290c8def5f
This commit is contained in:
parent
19985d6f84
commit
3607478ecd
|
|
@ -1031,7 +1031,6 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
|
||||||
|
|
||||||
m.impl("sum.dim_IntList", sum_batching_rule);
|
m.impl("sum.dim_IntList", sum_batching_rule);
|
||||||
m.impl("is_complex", native::is_complex);
|
m.impl("is_complex", native::is_complex);
|
||||||
m.impl("conj", native::conj);
|
|
||||||
|
|
||||||
// inplace operations
|
// inplace operations
|
||||||
m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
|
m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
|
||||||
|
|
@ -1085,7 +1084,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
|
||||||
UNARY_POINTWISE(ceil);
|
UNARY_POINTWISE(ceil);
|
||||||
UNARY_POINTWISE(cos);
|
UNARY_POINTWISE(cos);
|
||||||
UNARY_POINTWISE(cosh);
|
UNARY_POINTWISE(cosh);
|
||||||
UNARY_POINTWISE(_conj);
|
UNARY_POINTWISE(conj_physical);
|
||||||
UNARY_POINTWISE(digamma);
|
UNARY_POINTWISE(digamma);
|
||||||
UNARY_POINTWISE(exp);
|
UNARY_POINTWISE(exp);
|
||||||
UNARY_POINTWISE(expm1);
|
UNARY_POINTWISE(expm1);
|
||||||
|
|
@ -1181,6 +1180,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
|
||||||
TRIVIAL_OP(imag)
|
TRIVIAL_OP(imag)
|
||||||
TRIVIAL_OP(real);
|
TRIVIAL_OP(real);
|
||||||
TRIVIAL_OP(view_as_real);
|
TRIVIAL_OP(view_as_real);
|
||||||
|
TRIVIAL_OP(_view_as_real_physical);
|
||||||
|
TRIVIAL_OP(conj);
|
||||||
|
TRIVIAL_OP(_conj);
|
||||||
|
TRIVIAL_OP(resolve_conj);
|
||||||
m.impl("view_as_complex", view_as_complex_batching_rule);
|
m.impl("view_as_complex", view_as_complex_batching_rule);
|
||||||
#undef TRIVIAL
|
#undef TRIVIAL
|
||||||
|
|
||||||
|
|
|
||||||
152
aten/src/ATen/ConjugateFallback.cpp
Normal file
152
aten/src/ATen/ConjugateFallback.cpp
Normal file
|
|
@ -0,0 +1,152 @@
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/core/op_registration/op_registration.h>
|
||||||
|
#include <torch/library.h>
|
||||||
|
#include <ATen/core/dispatch/Dispatcher.h>
|
||||||
|
#include <ATen/native/UnaryOps.h>
|
||||||
|
#include <ATen/NativeFunctions.h>
|
||||||
|
|
||||||
|
namespace at {
|
||||||
|
|
||||||
|
void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
|
||||||
|
// Situations to handle:
|
||||||
|
// 1. Out-of-place operation. Easy: materialize all inputs and
|
||||||
|
// call it a day.
|
||||||
|
// 2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
|
||||||
|
// Materialize other inputs as in (1).
|
||||||
|
// 3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
|
||||||
|
// Materialize other inputs as in (1).
|
||||||
|
//
|
||||||
|
// It is important to be able to tell if we READ from an argument and if we
|
||||||
|
// WRITE to an argument. Conservative approach is to assume that we always
|
||||||
|
// READ from an argument, but in out-of-place operations you can skip
|
||||||
|
// conjugating inputs on entry that never get used. In current schema we
|
||||||
|
// can't easily tell if inplace situation has happened, so don't do it.
|
||||||
|
|
||||||
|
const auto& arguments = op.schema().arguments();
|
||||||
|
const auto num_arguments = arguments.size();
|
||||||
|
const auto stack_start = stack->size() - num_arguments;
|
||||||
|
|
||||||
|
c10::optional<bool> is_write;
|
||||||
|
for (int64_t i = 0; i < num_arguments; ++i) {
|
||||||
|
const auto& alias_info = arguments[i].alias_info();
|
||||||
|
// Three possible states:
|
||||||
|
// 1. alias_info has no value --> out-of-place operation
|
||||||
|
// 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
|
||||||
|
// 3. alias_info does have a value, alias_info->is_write=False --> view operation
|
||||||
|
if (alias_info.has_value()) {
|
||||||
|
if (is_write.has_value()) {
|
||||||
|
TORCH_CHECK(*is_write == alias_info->isWrite(),
|
||||||
|
"Unsupported operator for conjugate fallback: ", op.schema().name(),
|
||||||
|
"Conjugate fallback doesn't work for operators with a mix "
|
||||||
|
"mutable and non-mutable inputs that alias with outputs, "
|
||||||
|
"this must be implemented manually. "
|
||||||
|
"If you got this error on a core op, please report a bug to PyTorch.");
|
||||||
|
} else {
|
||||||
|
is_write = alias_info->isWrite();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_write.has_value() && !*is_write) {
|
||||||
|
// We assume that view operators automatically handle conjugation
|
||||||
|
// correctly by propagating the Conjugate dispatch key in key_set.
|
||||||
|
// This is not necessarily always right, so you should test these cases.
|
||||||
|
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutable inputs to be tracked separately
|
||||||
|
std::vector<Tensor> mutable_inputs;
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < num_arguments; ++i) {
|
||||||
|
auto& ivalue = (*stack)[stack_start + i];
|
||||||
|
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const auto& argument = arguments[i];
|
||||||
|
bool mut_arg = false;
|
||||||
|
if (argument.alias_info()) {
|
||||||
|
// View operations were already filtered above, so only in-place/out= operations should get here.
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
|
||||||
|
mut_arg = true;
|
||||||
|
}
|
||||||
|
if (ivalue.isTensor()) {
|
||||||
|
auto* impl = ivalue.unsafeToTensorImpl();
|
||||||
|
if (!impl->is_conj()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tensor = std::move(ivalue).toTensor();
|
||||||
|
TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), "Conjugate Fallback does not support meta tensors.");
|
||||||
|
if (mut_arg) {
|
||||||
|
// TODO: This is a waste if the argument is write only
|
||||||
|
tensor._set_conj(false);
|
||||||
|
at::conj_physical_(tensor);
|
||||||
|
mutable_inputs.emplace_back(tensor);
|
||||||
|
} else {
|
||||||
|
tensor = at::resolve_conj(tensor);
|
||||||
|
}
|
||||||
|
(*stack)[stack_start + i] = std::move(tensor);
|
||||||
|
} else if (ivalue.isTensorList()) {
|
||||||
|
auto tensors = std::move(ivalue).toTensorList();
|
||||||
|
if (mut_arg) {
|
||||||
|
for(const auto j : c10::irange(tensors.size())) {
|
||||||
|
Tensor t = tensors[j];
|
||||||
|
t._set_conj(false);
|
||||||
|
at::conj_physical_(t);
|
||||||
|
mutable_inputs.emplace_back(t);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for(const auto j : c10::irange(tensors.size())) {
|
||||||
|
tensors[j] = at::resolve_conj(tensors[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(*stack)[stack_start + i] = std::move(tensors);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
|
||||||
|
|
||||||
|
for (auto& mutable_input : mutable_inputs) {
|
||||||
|
at::conj_physical_(mutable_input);
|
||||||
|
mutable_input._set_conj(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL(_, Conjugate, m) {
|
||||||
|
m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
|
||||||
|
m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("set_.source_Storage_storage_offset", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("set_.source_Tensor", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("set_", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("copy_", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("clone", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("conj", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("_conj", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("_conj_physical", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("conj_physical", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("conj_physical_", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("empty_like", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("empty.memory_format", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("empty.out", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("empty_strided", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("full_like", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("stride.int", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("stride.Dimname", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("size.int", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("size.Dimname", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("is_complex", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("_view_as_real_physical", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("view_as_real", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("imag", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("real", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("view", torch::CppFunction::makeFallthrough());
|
||||||
|
m.impl("reshape", torch::CppFunction::makeFallthrough());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace at
|
||||||
|
|
@ -117,7 +117,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
|
||||||
m.impl("clamp_min_.Tensor", CppFunction::makeFallthrough());
|
m.impl("clamp_min_.Tensor", CppFunction::makeFallthrough());
|
||||||
m.impl("clone", CppFunction::makeFallthrough());
|
m.impl("clone", CppFunction::makeFallthrough());
|
||||||
m.impl("conj", CppFunction::makeFallthrough());
|
m.impl("conj", CppFunction::makeFallthrough());
|
||||||
m.impl("conj.out", CppFunction::makeFallthrough());
|
|
||||||
m.impl("contiguous", CppFunction::makeFallthrough());
|
m.impl("contiguous", CppFunction::makeFallthrough());
|
||||||
m.impl("copy_", CppFunction::makeFallthrough());
|
m.impl("copy_", CppFunction::makeFallthrough());
|
||||||
m.impl("cos", CppFunction::makeFallthrough());
|
m.impl("cos", CppFunction::makeFallthrough());
|
||||||
|
|
|
||||||
|
|
@ -238,6 +238,9 @@ _(aten, coalesce) \
|
||||||
_(aten, combinations) \
|
_(aten, combinations) \
|
||||||
_(aten, _conj) \
|
_(aten, _conj) \
|
||||||
_(aten, conj) \
|
_(aten, conj) \
|
||||||
|
_(aten, conj_physical) \
|
||||||
|
_(aten, conj_physical_) \
|
||||||
|
_(aten, resolve_conj) \
|
||||||
_(aten, complex) \
|
_(aten, complex) \
|
||||||
_(aten, copysign) \
|
_(aten, copysign) \
|
||||||
_(aten, polar) \
|
_(aten, polar) \
|
||||||
|
|
@ -764,6 +767,7 @@ _(aten, zeros_like) \
|
||||||
_(aten, real) \
|
_(aten, real) \
|
||||||
_(aten, imag) \
|
_(aten, imag) \
|
||||||
_(aten, view_as_real) \
|
_(aten, view_as_real) \
|
||||||
|
_(aten, _view_as_real_physical) \
|
||||||
_(aten, view_as_complex) \
|
_(aten, view_as_complex) \
|
||||||
/* nothing */
|
/* nothing */
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,9 @@
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
|
// WARNING: this header contains non-inline functions and should be only
|
||||||
|
// included from ONE cpp file
|
||||||
|
|
||||||
namespace at { namespace native {
|
namespace at { namespace native {
|
||||||
|
|
||||||
// View tensor with new dtype, storage offset, sizes and strides
|
// View tensor with new dtype, storage offset, sizes and strides
|
||||||
|
|
@ -9,8 +12,9 @@ inline Tensor view_tensor(
|
||||||
const Tensor &tensor, ScalarType dtype,
|
const Tensor &tensor, ScalarType dtype,
|
||||||
int64_t offset, IntArrayRef sizes, IntArrayRef strides) {
|
int64_t offset, IntArrayRef sizes, IntArrayRef strides) {
|
||||||
Storage storage = tensor.storage();
|
Storage storage = tensor.storage();
|
||||||
|
auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
|
||||||
auto new_tensor = detail::make_tensor<TensorImpl>(
|
auto new_tensor = detail::make_tensor<TensorImpl>(
|
||||||
c10::TensorImpl::VIEW, std::move(storage), tensor.key_set(), scalarTypeToTypeMeta(dtype));
|
c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
|
||||||
auto * impl = new_tensor.unsafeGetTensorImpl();
|
auto * impl = new_tensor.unsafeGetTensorImpl();
|
||||||
impl->set_storage_offset(offset);
|
impl->set_storage_offset(offset);
|
||||||
impl->set_sizes_and_strides(sizes, strides);
|
impl->set_sizes_and_strides(sizes, strides);
|
||||||
|
|
@ -30,7 +34,12 @@ inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) {
|
||||||
// with corresponding real dtype containing the complex values
|
// with corresponding real dtype containing the complex values
|
||||||
// in the last two dimensions
|
// in the last two dimensions
|
||||||
Tensor view_as_real(const Tensor& self) {
|
Tensor view_as_real(const Tensor& self) {
|
||||||
TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
|
TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
|
||||||
|
return native::_view_as_real_physical(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor _view_as_real_physical(const Tensor& self) {
|
||||||
|
TORCH_CHECK(self.is_complex(), "view_as_real_physical is only supported for complex tensors");
|
||||||
auto old_sizes = self.sizes();
|
auto old_sizes = self.sizes();
|
||||||
DimVector new_sizes(old_sizes.size() + 1);
|
DimVector new_sizes(old_sizes.size() + 1);
|
||||||
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
|
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
|
||||||
|
|
@ -39,7 +48,8 @@ Tensor view_as_real(const Tensor& self) {
|
||||||
auto new_strides = computeStrideForViewAsReal(self.strides());
|
auto new_strides = computeStrideForViewAsReal(self.strides());
|
||||||
auto new_storage_offset = 2 * self.storage_offset();
|
auto new_storage_offset = 2 * self.storage_offset();
|
||||||
const auto float_type = c10::toValueType(self.scalar_type());
|
const auto float_type = c10::toValueType(self.scalar_type());
|
||||||
return view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides);
|
auto real_tensor = view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides);
|
||||||
|
return real_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) {
|
inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) {
|
||||||
|
|
|
||||||
|
|
@ -150,7 +150,7 @@ Tensor fft_r2c(c10::string_view function_name,
|
||||||
|
|
||||||
if (!forward) {
|
if (!forward) {
|
||||||
// FIXME: _fft_r2c doesn't support native r2c IFFT
|
// FIXME: _fft_r2c doesn't support native r2c IFFT
|
||||||
return out.defined() ? at::conj_out(out, ret) : at::conj(ret);
|
return out.defined() ? at::conj_physical_out(out, ret) : at::conj(ret);
|
||||||
} else {
|
} else {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -349,6 +349,8 @@ Tensor empty_like(
|
||||||
namedinference::propagate_names(result, self.names());
|
namedinference::propagate_names(result, self.names());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// never propagate Conjugate key
|
||||||
|
result._set_conj(false);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,10 @@ bool is_signed(const Tensor &self) {
|
||||||
return self.is_signed();
|
return self.is_signed();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_conj(const Tensor& self) {
|
||||||
|
return self.is_conj();
|
||||||
|
}
|
||||||
|
|
||||||
bool is_sparse(const Tensor& self) {
|
bool is_sparse(const Tensor& self) {
|
||||||
return self.is_sparse();
|
return self.is_sparse();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,8 @@ namespace at {
|
||||||
namespace meta {
|
namespace meta {
|
||||||
|
|
||||||
// Unary float operations always produce floating point
|
// Unary float operations always produce floating point
|
||||||
// outputs, even if their inputs are integer
|
// outputs for floating point and integral types
|
||||||
|
// For complex inputs, the output type should be the same as input type.
|
||||||
#define CREATE_UNARY_FLOAT_META_FUNC(func) \
|
#define CREATE_UNARY_FLOAT_META_FUNC(func) \
|
||||||
TORCH_META_FUNC(func) (const Tensor& self) { \
|
TORCH_META_FUNC(func) (const Tensor& self) { \
|
||||||
build_unary_float_op(maybe_get_output(), self); \
|
build_unary_float_op(maybe_get_output(), self); \
|
||||||
|
|
@ -363,7 +364,8 @@ Tensor angle(const Tensor& self) {
|
||||||
|
|
||||||
Tensor real(const Tensor& self) {
|
Tensor real(const Tensor& self) {
|
||||||
if (self.is_complex()) {
|
if (self.is_complex()) {
|
||||||
auto real_tensor = at::view_as_real(self);
|
// real is never affected by conjugate bit, safe to use physical version
|
||||||
|
auto real_tensor = at::_view_as_real_physical(self);
|
||||||
return at::select(real_tensor, real_tensor.dim() - 1, 0);
|
return at::select(real_tensor, real_tensor.dim() - 1, 0);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "real is not implemented for tensors with non-complex dtypes.");
|
TORCH_CHECK(false, "real is not implemented for tensors with non-complex dtypes.");
|
||||||
|
|
@ -379,17 +381,44 @@ Tensor imag(const Tensor& self) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor& conj_out(const Tensor& self, Tensor& result) {
|
Tensor& conj_physical_out(const Tensor& self, Tensor& result) {
|
||||||
return unary_op_impl_out(result, self, conj_stub);
|
return unary_op_impl_out(result, self, conj_physical_stub);
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor _conj(const Tensor& self) { return unary_op_impl(self, at::conj_out); }
|
Tensor _conj_physical(const Tensor& self) {
|
||||||
|
if (self.is_conj()) {
|
||||||
|
return self.conj().clone();
|
||||||
|
}
|
||||||
|
return unary_op_impl(self, at::conj_physical_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor conj_physical(const Tensor& self) {
|
||||||
|
if (!self.is_complex()) return self;
|
||||||
|
return at::_conj_physical(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor& conj_physical_(Tensor& self) {
|
||||||
|
if (!self.is_complex()) return self;
|
||||||
|
return unary_op_impl_out(self, self, conj_physical_stub);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor resolve_conj(const Tensor& self) {
|
||||||
|
if (!self.is_conj()) { return self; }
|
||||||
|
// conjugation is handled in `copy_()` that clone ultimately calls into
|
||||||
|
return self.clone(self.suggest_memory_format());
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor _conj(const Tensor& self) {
|
||||||
|
Tensor self_ = self.alias();
|
||||||
|
self_._set_conj(!self.is_conj());
|
||||||
|
namedinference::propagate_names(self_, self);
|
||||||
|
return self_;
|
||||||
|
}
|
||||||
|
|
||||||
Tensor conj(const Tensor& self) {
|
Tensor conj(const Tensor& self) {
|
||||||
if (!self.is_complex()) {
|
// This might look like an infinite recursion but it's not.
|
||||||
return self;
|
// This actually calls into `conj()` defined in the Tensor class.
|
||||||
}
|
return self.conj();
|
||||||
return at::_conj(self);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// special_exp2, alias for exp2
|
// special_exp2, alias for exp2
|
||||||
|
|
@ -689,7 +718,7 @@ DEFINE_DISPATCH(abs_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-va
|
||||||
DEFINE_DISPATCH(angle_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
DEFINE_DISPATCH(angle_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
DEFINE_DISPATCH(real_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
DEFINE_DISPATCH(real_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
DEFINE_DISPATCH(imag_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
DEFINE_DISPATCH(imag_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
DEFINE_DISPATCH(conj_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
DEFINE_DISPATCH(conj_physical_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
DEFINE_DISPATCH(acos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
DEFINE_DISPATCH(acos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
DEFINE_DISPATCH(acosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
DEFINE_DISPATCH(acosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
DEFINE_DISPATCH(asinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
DEFINE_DISPATCH(asinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ DECLARE_DISPATCH(unary_fn, abs_stub);
|
||||||
DECLARE_DISPATCH(unary_fn, angle_stub);
|
DECLARE_DISPATCH(unary_fn, angle_stub);
|
||||||
DECLARE_DISPATCH(unary_fn, real_stub);
|
DECLARE_DISPATCH(unary_fn, real_stub);
|
||||||
DECLARE_DISPATCH(unary_fn, imag_stub);
|
DECLARE_DISPATCH(unary_fn, imag_stub);
|
||||||
DECLARE_DISPATCH(unary_fn, conj_stub);
|
DECLARE_DISPATCH(unary_fn, conj_physical_stub);
|
||||||
DECLARE_DISPATCH(unary_fn, acos_stub);
|
DECLARE_DISPATCH(unary_fn, acos_stub);
|
||||||
DECLARE_DISPATCH(unary_fn, acosh_stub);
|
DECLARE_DISPATCH(unary_fn, acosh_stub);
|
||||||
DECLARE_DISPATCH(unary_fn, asinh_stub);
|
DECLARE_DISPATCH(unary_fn, asinh_stub);
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
#include <ATen/native/TensorIterator.h>
|
#include <ATen/native/TensorIterator.h>
|
||||||
#include <ATen/native/cpu/Loops.h>
|
#include <ATen/native/cpu/Loops.h>
|
||||||
#include <c10/util/TypeCast.h>
|
#include <c10/util/TypeCast.h>
|
||||||
|
#include <ATen/native/cpu/zmath.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at {
|
||||||
namespace native {
|
namespace native {
|
||||||
|
|
@ -13,6 +14,15 @@ namespace {
|
||||||
static void copy_kernel(TensorIterator& iter, bool non_blocking) {
|
static void copy_kernel(TensorIterator& iter, bool non_blocking) {
|
||||||
ScalarType dtype = iter.dtype(0);
|
ScalarType dtype = iter.dtype(0);
|
||||||
if (dtype == iter.dtype(1)) {
|
if (dtype == iter.dtype(1)) {
|
||||||
|
// TODO: as the majority of these operations can be done treating
|
||||||
|
// their datatypes as opaque bit patterns, we don't actually need
|
||||||
|
// separate instantiations per dtype; we only need a separate
|
||||||
|
// instantiation per dtype size. This would probably save us a
|
||||||
|
// little bit of code size here
|
||||||
|
// TODO: not sure if optimizer is able to compile two levels of
|
||||||
|
// conditionals into a single jump table. We should have a
|
||||||
|
// single jump table here; might be worth just writing out the
|
||||||
|
// dispatch statement by hand instead of using AT_DISPATCH
|
||||||
if (dtype == ScalarType::Half) {
|
if (dtype == ScalarType::Half) {
|
||||||
cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
|
cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
|
||||||
} else if (dtype == ScalarType::ComplexHalf) {
|
} else if (dtype == ScalarType::ComplexHalf) {
|
||||||
|
|
@ -25,12 +35,21 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
|
||||||
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
|
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
|
||||||
});
|
});
|
||||||
} else if (isComplexType(dtype)) {
|
} else if (isComplexType(dtype)) {
|
||||||
|
if (iter.tensor(0).is_conj() == iter.tensor(1).is_conj()) {
|
||||||
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
|
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
|
||||||
cpu_kernel_vec(
|
cpu_kernel_vec(
|
||||||
iter,
|
iter,
|
||||||
[=](scalar_t a) -> scalar_t { return a; },
|
[=](scalar_t a) -> scalar_t { return a; },
|
||||||
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
|
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
AT_DISPATCH_COMPLEX_TYPES(dtype, "conj_kernel", [&] {
|
||||||
|
cpu_kernel_vec(
|
||||||
|
iter,
|
||||||
|
[=](scalar_t a) -> scalar_t { return conj_impl(a); },
|
||||||
|
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.conj(); });
|
||||||
|
});
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
AT_DISPATCH_ALL_TYPES_AND2(
|
AT_DISPATCH_ALL_TYPES_AND2(
|
||||||
ScalarType::Bool, ScalarType::BFloat16,dtype, "copy_kernel", [&] {
|
ScalarType::Bool, ScalarType::BFloat16,dtype, "copy_kernel", [&] {
|
||||||
|
|
@ -62,6 +81,9 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
|
||||||
return c10::static_cast_with_inter_type<dest_t, scalar_t>::apply(src); });
|
return c10::static_cast_with_inter_type<dest_t, scalar_t>::apply(src); });
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
|
||||||
|
iter.tensor(0).conj_physical_();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -193,6 +193,7 @@ static void imag_kernel(TensorIteratorBase& iter) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NB: Ignores the negative bit on tensors
|
||||||
static void conj_kernel(TensorIteratorBase& iter) {
|
static void conj_kernel(TensorIteratorBase& iter) {
|
||||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
||||||
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cpu", [&]() {
|
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cpu", [&]() {
|
||||||
|
|
@ -716,7 +717,7 @@ REGISTER_DISPATCH(real_stub, &CPU_CAPABILITY::real_kernel);
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
REGISTER_DISPATCH(imag_stub, &CPU_CAPABILITY::imag_kernel);
|
REGISTER_DISPATCH(imag_stub, &CPU_CAPABILITY::imag_kernel);
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
REGISTER_DISPATCH(conj_stub, &CPU_CAPABILITY::conj_kernel);
|
REGISTER_DISPATCH(conj_physical_stub, &CPU_CAPABILITY::conj_kernel);
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
REGISTER_DISPATCH(exp2_stub, &CPU_CAPABILITY::exp2_kernel);
|
REGISTER_DISPATCH(exp2_stub, &CPU_CAPABILITY::exp2_kernel);
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,8 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
|
||||||
// We can memcpy the memory if both tensors have the same type AND both
|
// We can memcpy the memory if both tensors have the same type AND both
|
||||||
// tensors are contiguous after dimension coalescing and reordering.
|
// tensors are contiguous after dimension coalescing and reordering.
|
||||||
bool same_type = iter.dtype(0) == iter.dtype(1);
|
bool same_type = iter.dtype(0) == iter.dtype(1);
|
||||||
bool memcpy_eligible = same_type && iter.is_contiguous();
|
bool same_conj = iter.tensor(0).is_conj() == iter.tensor(1).is_conj();
|
||||||
|
bool memcpy_eligible = same_type && same_conj && iter.is_contiguous();
|
||||||
|
|
||||||
Device dst_device = iter.device(0);
|
Device dst_device = iter.device(0);
|
||||||
Device src_device = iter.device(1);
|
Device src_device = iter.device(1);
|
||||||
|
|
@ -70,6 +71,12 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
|
||||||
AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
|
AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
|
||||||
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
|
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
if (!same_conj && same_type) {
|
||||||
|
AT_DISPATCH_COMPLEX_TYPES(
|
||||||
|
dtype, "copy_conj_", [&] {
|
||||||
|
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return std::conj(x); });
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
||||||
kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
|
kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
|
||||||
|
|
@ -77,6 +84,7 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (src_device != dst_device) {
|
if (src_device != dst_device) {
|
||||||
// dst waits on src barrier (dst already waits on dst). We cannot
|
// dst waits on src barrier (dst already waits on dst). We cannot
|
||||||
|
|
@ -152,6 +160,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
|
||||||
src_contig = iter.tensor(1).expand_as(dst).contiguous();
|
src_contig = iter.tensor(1).expand_as(dst).contiguous();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// propagate the correct conjugate bit
|
||||||
|
dst_contig._set_conj(dst.is_conj());
|
||||||
|
src_contig._set_conj(iter.tensor(1).is_conj());
|
||||||
|
|
||||||
// perform a same-dtype copy on contiguous tensors
|
// perform a same-dtype copy on contiguous tensors
|
||||||
TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
|
TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
|
||||||
TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
|
TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
|
||||||
|
|
@ -201,6 +213,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
|
||||||
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
|
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
|
||||||
|
iter.tensor(0).conj_physical_();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_DISPATCH(copy_stub, &copy_kernel_cuda);
|
REGISTER_DISPATCH(copy_stub, &copy_kernel_cuda);
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,7 @@ __host__ __device__ static inline c10::complex<T> conj_wrapper(c10::complex<T> v
|
||||||
return std::conj(v);
|
return std::conj(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NB: Ignores the negative bit on tensors
|
||||||
void conj_kernel_cuda(TensorIteratorBase& iter) {
|
void conj_kernel_cuda(TensorIteratorBase& iter) {
|
||||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
||||||
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cuda", [&]() {
|
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cuda", [&]() {
|
||||||
|
|
@ -92,6 +93,6 @@ void conj_kernel_cuda(TensorIteratorBase& iter) {
|
||||||
REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
|
REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
|
||||||
REGISTER_DISPATCH(real_stub, &real_kernel_cuda);
|
REGISTER_DISPATCH(real_stub, &real_kernel_cuda);
|
||||||
REGISTER_DISPATCH(imag_stub, &imag_kernel_cuda);
|
REGISTER_DISPATCH(imag_stub, &imag_kernel_cuda);
|
||||||
REGISTER_DISPATCH(conj_stub, &conj_kernel_cuda);
|
REGISTER_DISPATCH(conj_physical_stub, &conj_kernel_cuda);
|
||||||
|
|
||||||
}} // namespace at::native
|
}} // namespace at::native
|
||||||
|
|
|
||||||
|
|
@ -271,6 +271,11 @@
|
||||||
dispatch:
|
dispatch:
|
||||||
CPU, CUDA: view_as_real
|
CPU, CUDA: view_as_real
|
||||||
|
|
||||||
|
- func: _view_as_real_physical(Tensor(a) self) -> Tensor(a)
|
||||||
|
variants: function
|
||||||
|
dispatch:
|
||||||
|
CPU, CUDA: _view_as_real_physical
|
||||||
|
|
||||||
- func: view_as_complex(Tensor(a) self) -> Tensor(a)
|
- func: view_as_complex(Tensor(a) self) -> Tensor(a)
|
||||||
variants: function
|
variants: function
|
||||||
dispatch:
|
dispatch:
|
||||||
|
|
@ -298,21 +303,36 @@
|
||||||
device_check: NoCheck # TensorIterator
|
device_check: NoCheck # TensorIterator
|
||||||
variants: function
|
variants: function
|
||||||
|
|
||||||
- func: conj(Tensor(a) self) -> Tensor(a)
|
- func: _conj(Tensor(a) self) -> Tensor(a)
|
||||||
device_check: NoCheck # TensorIterator
|
|
||||||
variants: function, method
|
variants: function, method
|
||||||
|
|
||||||
- func: conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
|
||||||
device_check: NoCheck # TensorIterator
|
|
||||||
dispatch:
|
|
||||||
CPU, CUDA: conj_out
|
|
||||||
SparseCPU, SparseCUDA: conj_out_sparse
|
|
||||||
|
|
||||||
- func: _conj(Tensor self) -> Tensor
|
|
||||||
variants: function
|
|
||||||
dispatch:
|
dispatch:
|
||||||
CompositeExplicitAutograd: _conj
|
CompositeExplicitAutograd: _conj
|
||||||
|
|
||||||
|
- func: conj(Tensor(a) self) -> Tensor(a)
|
||||||
|
variants: function, method
|
||||||
|
manual_cpp_binding: True
|
||||||
|
|
||||||
|
- func: _conj_physical(Tensor self) -> Tensor
|
||||||
|
variants: function, method
|
||||||
|
dispatch:
|
||||||
|
CompositeExplicitAutograd: _conj_physical
|
||||||
|
|
||||||
|
- func: conj_physical(Tensor self) -> Tensor
|
||||||
|
variants: function, method
|
||||||
|
|
||||||
|
- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||||
|
dispatch:
|
||||||
|
CPU, CUDA: conj_physical_out
|
||||||
|
SparseCPU, SparseCUDA: conj_physical_out_sparse
|
||||||
|
|
||||||
|
- func: conj_physical_(Tensor(a!) self) -> Tensor(a!)
|
||||||
|
variants: function, method
|
||||||
|
dispatch:
|
||||||
|
CompositeExplicitAutograd: conj_physical_
|
||||||
|
|
||||||
|
- func: resolve_conj(Tensor(a) self) -> Tensor(a)
|
||||||
|
variants: function, method
|
||||||
|
|
||||||
- func: acos(Tensor self) -> Tensor
|
- func: acos(Tensor self) -> Tensor
|
||||||
device_check: NoCheck # TensorIterator
|
device_check: NoCheck # TensorIterator
|
||||||
variants: function, method
|
variants: function, method
|
||||||
|
|
@ -2249,6 +2269,11 @@
|
||||||
device_guard: False
|
device_guard: False
|
||||||
manual_cpp_binding: True
|
manual_cpp_binding: True
|
||||||
|
|
||||||
|
- func: is_conj(Tensor self) -> bool
|
||||||
|
variants: function, method
|
||||||
|
device_guard: False
|
||||||
|
manual_cpp_binding: True
|
||||||
|
|
||||||
- func: isreal(Tensor self) -> Tensor
|
- func: isreal(Tensor self) -> Tensor
|
||||||
variants: function, method
|
variants: function, method
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1645,15 +1645,15 @@ Tensor& bmm_out_sparse_cpu(const SparseTensor& self, const Tensor& mat2, Tensor&
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor conj_sparse(const Tensor& input) {
|
// Tensor conj_physical_sparse(const Tensor& input) {
|
||||||
if (!input.is_complex()) {
|
// if (!input.is_complex()) {
|
||||||
return input;
|
// return input;
|
||||||
}
|
// }
|
||||||
Tensor result = at::native::empty_like(input);
|
// Tensor result = at::native::empty_like(input);
|
||||||
return conj_out_sparse(input, result);
|
// return conj_physical_out_sparse(input, result);
|
||||||
}
|
// }
|
||||||
|
|
||||||
Tensor& conj_out_sparse(const Tensor& input, Tensor& result) {
|
Tensor& conj_physical_out_sparse(const Tensor& input, Tensor& result) {
|
||||||
TORCH_INTERNAL_ASSERT(input.is_sparse());
|
TORCH_INTERNAL_ASSERT(input.is_sparse());
|
||||||
if (input.numel() == 0) {
|
if (input.numel() == 0) {
|
||||||
return result;
|
return result;
|
||||||
|
|
@ -1665,7 +1665,7 @@ Tensor& conj_out_sparse(const Tensor& input, Tensor& result) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
Tensor result_values = result._values();
|
Tensor result_values = result._values();
|
||||||
at::conj_out(result_values, input._values());
|
at::conj_physical_out(result_values, input._values());
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -216,4 +216,12 @@ inline bool is_inference(const Tensor& tensor) {
|
||||||
return tensor.is_inference();
|
return tensor.is_inference();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool is_conj(const Tensor& tensor) {
|
||||||
|
return tensor.is_conj();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Tensor conj(const Tensor& tensor) {
|
||||||
|
return tensor.conj();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -136,6 +136,17 @@ class TORCH_API Tensor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor conj() const {
|
||||||
|
if (!this->is_complex()) {
|
||||||
|
return *this;
|
||||||
|
} else {
|
||||||
|
if (this->is_sparse()) {
|
||||||
|
return this->conj_physical();
|
||||||
|
}
|
||||||
|
return this->_conj();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Should be used if *this can reasonably be expected to be contiguous and
|
/// Should be used if *this can reasonably be expected to be contiguous and
|
||||||
/// performance is important.
|
/// performance is important.
|
||||||
/// Compared to contiguous, it saves a reference count
|
/// Compared to contiguous, it saves a reference count
|
||||||
|
|
@ -363,6 +374,18 @@ class TORCH_API Tensor {
|
||||||
return !at::impl::variable_excluded_from_dispatch();
|
return !at::impl::variable_excluded_from_dispatch();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool is_conj() const {
|
||||||
|
return impl_->is_conj();
|
||||||
|
}
|
||||||
|
|
||||||
|
// sets the conjugate bit of a tensor.
|
||||||
|
// NOTE: Conjugate bit is supposed to be a read-only field. Only change this if you are extremely sure
|
||||||
|
// that's what you want. Changing this might lead to incorrect behavior since conjugation is
|
||||||
|
// a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
|
||||||
|
inline void _set_conj(bool conjugate) const {
|
||||||
|
impl_->_set_conj(conjugate);
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a `Tensor`'s layout.
|
/// Returns a `Tensor`'s layout.
|
||||||
Layout layout() const noexcept {
|
Layout layout() const noexcept {
|
||||||
return impl_->layout();
|
return impl_->layout();
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,8 @@ const char* toString(DispatchKey t) {
|
||||||
case DispatchKey::PrivateUse3:
|
case DispatchKey::PrivateUse3:
|
||||||
return "PrivateUse3";
|
return "PrivateUse3";
|
||||||
|
|
||||||
|
case DispatchKey::Conjugate:
|
||||||
|
return "Conjugate";
|
||||||
case DispatchKey::Meta:
|
case DispatchKey::Meta:
|
||||||
return "Meta";
|
return "Meta";
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -142,6 +142,11 @@ enum class DispatchKey : uint8_t {
|
||||||
// constituent parts.
|
// constituent parts.
|
||||||
Named,
|
Named,
|
||||||
|
|
||||||
|
// The Conjugate dispatch key is set for any tensors that need to perform
|
||||||
|
// conjugation
|
||||||
|
// This is implemented at a dispatch level right before any backends run
|
||||||
|
Conjugate,
|
||||||
|
|
||||||
// See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
|
// See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
|
||||||
// is to insert code after the "autograd subsystem" runs, so this key should
|
// is to insert code after the "autograd subsystem" runs, so this key should
|
||||||
// be directly after ADInplaceOrView and all of the autograd keys.
|
// be directly after ADInplaceOrView and all of the autograd keys.
|
||||||
|
|
|
||||||
|
|
@ -938,6 +938,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||||
*/
|
*/
|
||||||
const at::Tensor& grad() const;
|
const at::Tensor& grad() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether or not the imaginary part of the tensor should be negated
|
||||||
|
*/
|
||||||
|
inline bool is_conj() const {
|
||||||
|
return key_set_.has(DispatchKey::Conjugate);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set whether or not to take the conjugate of the tensor (flip the imaginary
|
||||||
|
* bit).
|
||||||
|
*/
|
||||||
|
void _set_conj(bool value) {
|
||||||
|
if (value) {
|
||||||
|
key_set_ = key_set_.add(DispatchKey::Conjugate);
|
||||||
|
TORCH_INTERNAL_ASSERT(isComplexType(typeMetaToScalarType(dtype())));
|
||||||
|
} else {
|
||||||
|
key_set_ = key_set_.remove(DispatchKey::Conjugate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the accumulated gradient of a tensor. This gradient is computed
|
* Return the accumulated gradient of a tensor. This gradient is computed
|
||||||
* using forward mode AD.
|
* using forward mode AD.
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,6 @@ For reference, here’s a full list of view ops in PyTorch:
|
||||||
- :attr:`~torch.Tensor.real`
|
- :attr:`~torch.Tensor.real`
|
||||||
- :attr:`~torch.Tensor.imag`
|
- :attr:`~torch.Tensor.imag`
|
||||||
- :meth:`~torch.Tensor.view_as_real`
|
- :meth:`~torch.Tensor.view_as_real`
|
||||||
- :meth:`~torch.Tensor.view_as_imag`
|
|
||||||
- :meth:`~torch.Tensor.unflatten`
|
- :meth:`~torch.Tensor.unflatten`
|
||||||
- :meth:`~torch.Tensor.unfold`
|
- :meth:`~torch.Tensor.unfold`
|
||||||
- :meth:`~torch.Tensor.unsqueeze`
|
- :meth:`~torch.Tensor.unsqueeze`
|
||||||
|
|
|
||||||
|
|
@ -276,6 +276,9 @@ Tensor class reference
|
||||||
Tensor.contiguous
|
Tensor.contiguous
|
||||||
Tensor.copy_
|
Tensor.copy_
|
||||||
Tensor.conj
|
Tensor.conj
|
||||||
|
Tensor.conj_physical
|
||||||
|
Tensor.conj_physical_
|
||||||
|
Tensor.resolve_conj
|
||||||
Tensor.copysign
|
Tensor.copysign
|
||||||
Tensor.copysign_
|
Tensor.copysign_
|
||||||
Tensor.cos
|
Tensor.cos
|
||||||
|
|
@ -410,6 +413,7 @@ Tensor class reference
|
||||||
Tensor.isnan
|
Tensor.isnan
|
||||||
Tensor.is_contiguous
|
Tensor.is_contiguous
|
||||||
Tensor.is_complex
|
Tensor.is_complex
|
||||||
|
Tensor.is_conj
|
||||||
Tensor.is_floating_point
|
Tensor.is_floating_point
|
||||||
Tensor.is_inference
|
Tensor.is_inference
|
||||||
Tensor.is_leaf
|
Tensor.is_leaf
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ Tensors
|
||||||
is_tensor
|
is_tensor
|
||||||
is_storage
|
is_storage
|
||||||
is_complex
|
is_complex
|
||||||
|
is_conj
|
||||||
is_floating_point
|
is_floating_point
|
||||||
is_nonzero
|
is_nonzero
|
||||||
set_default_dtype
|
set_default_dtype
|
||||||
|
|
@ -86,6 +87,7 @@ Indexing, Slicing, Joining, Mutating Ops
|
||||||
:nosignatures:
|
:nosignatures:
|
||||||
|
|
||||||
cat
|
cat
|
||||||
|
conj
|
||||||
chunk
|
chunk
|
||||||
dsplit
|
dsplit
|
||||||
column_stack
|
column_stack
|
||||||
|
|
@ -293,7 +295,7 @@ Pointwise Ops
|
||||||
ceil
|
ceil
|
||||||
clamp
|
clamp
|
||||||
clip
|
clip
|
||||||
conj
|
conj_physical
|
||||||
copysign
|
copysign
|
||||||
cos
|
cos
|
||||||
cosh
|
cosh
|
||||||
|
|
@ -511,6 +513,7 @@ Other Operations
|
||||||
vander
|
vander
|
||||||
view_as_real
|
view_as_real
|
||||||
view_as_complex
|
view_as_complex
|
||||||
|
resolve_conj
|
||||||
|
|
||||||
|
|
||||||
BLAS and LAPACK Operations
|
BLAS and LAPACK Operations
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,9 @@ allow_list = [
|
||||||
("aten::linalg_vector_norm", datetime.date(2021, 5, 15)),
|
("aten::linalg_vector_norm", datetime.date(2021, 5, 15)),
|
||||||
("aten::repeat_interleave", datetime.date(2021, 6, 26)),
|
("aten::repeat_interleave", datetime.date(2021, 6, 26)),
|
||||||
("aten::one_hot", datetime.date(2021, 6, 15)),
|
("aten::one_hot", datetime.date(2021, 6, 15)),
|
||||||
|
("aten::conj", datetime.date(2021, 8, 1)),
|
||||||
|
("aten::_conj", datetime.date(2021, 8, 1)),
|
||||||
|
("aten::conj.out", datetime.date(2021, 8, 1)),
|
||||||
]
|
]
|
||||||
|
|
||||||
def allow_listed(schema, allow_list):
|
def allow_listed(schema, allow_list):
|
||||||
|
|
|
||||||
|
|
@ -4279,7 +4279,7 @@ class TestAutograd(TestCase):
|
||||||
gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
|
gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
|
||||||
|
|
||||||
def basic_mul(x):
|
def basic_mul(x):
|
||||||
return torch.view_as_real(x * 1j)
|
return torch.view_as_real(torch.resolve_conj(x * 1j))
|
||||||
gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)
|
gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)
|
||||||
|
|
||||||
# Test for one input and one output being complex
|
# Test for one input and one output being complex
|
||||||
|
|
@ -5459,7 +5459,7 @@ complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone'
|
||||||
'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
|
'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
|
||||||
'chunk', 'split', 'split_with_sizes', 'zero_',
|
'chunk', 'split', 'split_with_sizes', 'zero_',
|
||||||
'__radd__', 'mul', '__rmul__', 'diagonal', 'fill_', 'sub', 'narrow',
|
'__radd__', 'mul', '__rmul__', 'diagonal', 'fill_', 'sub', 'narrow',
|
||||||
'swapaxes', 'swapdims', 'tensor_split'] + separate_complex_tests
|
'swapaxes', 'swapdims', 'tensor_split', 'select', 'clone'] + separate_complex_tests
|
||||||
|
|
||||||
# deny list for batched grad computation
|
# deny list for batched grad computation
|
||||||
EXCLUDE_BATCHED_GRAD_TESTS = set([
|
EXCLUDE_BATCHED_GRAD_TESTS = set([
|
||||||
|
|
@ -5591,7 +5591,9 @@ def add_test(
|
||||||
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
|
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
|
||||||
if not isinstance(output_variable, tuple):
|
if not isinstance(output_variable, tuple):
|
||||||
output_variable = (output_variable,)
|
output_variable = (output_variable,)
|
||||||
|
|
||||||
inplace_self_variable = deepcopy(self_variable)
|
inplace_self_variable = deepcopy(self_variable)
|
||||||
|
self.assertEqual(inplace_self_variable, self_variable)
|
||||||
inplace_self_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
|
inplace_self_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
|
||||||
for i in (inplace_self_variable,))
|
for i in (inplace_self_variable,))
|
||||||
inplace_args_variable = deepcopy(args_variable)
|
inplace_args_variable = deepcopy(args_variable)
|
||||||
|
|
|
||||||
|
|
@ -1946,6 +1946,7 @@ known_failures = [
|
||||||
|
|
||||||
# If your OpInfo test causes this test to fail, add it here
|
# If your OpInfo test causes this test to fail, add it here
|
||||||
skip_ops = [
|
skip_ops = [
|
||||||
|
'conj'
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_name(op):
|
def get_name(op):
|
||||||
|
|
|
||||||
109
test/test_ops.py
109
test/test_ops.py
|
|
@ -17,7 +17,7 @@ from torch.testing._internal.common_jit import JitCommonTestCase, check_against_
|
||||||
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
|
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
|
||||||
check_alias_annotation
|
check_alias_annotation
|
||||||
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
|
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
# Get names of all the operators which have entry in `method_tests` (legacy testing infra)
|
# Get names of all the operators which have entry in `method_tests` (legacy testing infra)
|
||||||
method_tested_operators = set(map(lambda test_details: test_details[0], method_tests()))
|
method_tested_operators = set(map(lambda test_details: test_details[0], method_tests()))
|
||||||
|
|
@ -110,7 +110,9 @@ class TestGradients(TestCase):
|
||||||
return variant.__wrapped__ is op.get_inplace()
|
return variant.__wrapped__ is op.get_inplace()
|
||||||
return variant is op.get_inplace()
|
return variant is op.get_inplace()
|
||||||
|
|
||||||
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||||
|
samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs)
|
||||||
|
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
if sample.broadcasts_input and is_inplace(variant):
|
if sample.broadcasts_input and is_inplace(variant):
|
||||||
continue
|
continue
|
||||||
|
|
@ -280,7 +282,8 @@ class TestCommon(JitCommonTestCase):
|
||||||
_requires_grad = (op.supports_autograd and
|
_requires_grad = (op.supports_autograd and
|
||||||
(dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type)))
|
(dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type)))
|
||||||
|
|
||||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||||
|
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
|
||||||
|
|
||||||
def _test_consistency_helper(samples, variants):
|
def _test_consistency_helper(samples, variants):
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
|
|
@ -381,7 +384,8 @@ class TestCommon(JitCommonTestCase):
|
||||||
_requires_grad = op.supports_autograd and (dtype.is_floating_point or
|
_requires_grad = op.supports_autograd and (dtype.is_floating_point or
|
||||||
op.supports_complex_autograd(torch.device(device).type))
|
op.supports_complex_autograd(torch.device(device).type))
|
||||||
|
|
||||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||||
|
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
|
||||||
|
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
# Acquires variants to test
|
# Acquires variants to test
|
||||||
|
|
@ -734,6 +738,103 @@ class TestCommon(JitCommonTestCase):
|
||||||
with self.assertRaises(RuntimeError, msg=msg_fail):
|
with self.assertRaises(RuntimeError, msg=msg_fail):
|
||||||
op_out(out=out)
|
op_out(out=out)
|
||||||
|
|
||||||
|
# Tests that
|
||||||
|
# 1. The operator's output for physically conjugated tensors and conjugate view tensors
|
||||||
|
# produces the same value
|
||||||
|
# 2. The gradients are same in both cases mentioned in (1)
|
||||||
|
# 3. If the operator's inplace variant is supported, tests that the inplace operation
|
||||||
|
# produces the correct value when called on a conjugate view tensor and that the output
|
||||||
|
# has its conj bit set to true
|
||||||
|
# This test only runs for C -> R and C -> C functions
|
||||||
|
# TODO: add tests for `R->C` functions
|
||||||
|
# Note: This test runs for functions that take both tensors and tensorlists as input.
|
||||||
|
@ops(op_db, allowed_dtypes=(torch.cfloat,))
|
||||||
|
def test_conj_view(self, device, dtype, op):
|
||||||
|
if not op.test_conjugated_samples:
|
||||||
|
self.skipTest("Operation doesn't support conjugated inputs.")
|
||||||
|
_requires_grad = (op.supports_autograd and op.supports_complex_autograd(torch.device(device).type))
|
||||||
|
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
||||||
|
inplace_variant = op.inplace_variant
|
||||||
|
|
||||||
|
# helper function to physically conjugate the tensor
|
||||||
|
def conjugate_physical(input):
|
||||||
|
if isinstance(input, torch.Tensor):
|
||||||
|
tensor_requires_grad = input.requires_grad
|
||||||
|
with torch.no_grad():
|
||||||
|
input = input.conj_physical()
|
||||||
|
return input.requires_grad_(tensor_requires_grad)
|
||||||
|
|
||||||
|
if isinstance(input, Sequence):
|
||||||
|
out = list(map(clone_input_helper, input))
|
||||||
|
out[0] = conjugate_physical(out[0])
|
||||||
|
return tuple(out)
|
||||||
|
|
||||||
|
# helper function to clone and conjugate the input if it's a tensor
|
||||||
|
# else clone the sequence and conjugate the first element in the sequence
|
||||||
|
# If a requires_grad argument is provided the tensor being conjugated will
|
||||||
|
# have its requires_grad set to that value.
|
||||||
|
def clone_conj_input_helper(input, **kwargs):
|
||||||
|
if isinstance(input, torch.Tensor):
|
||||||
|
requires_grad = kwargs.get('requires_grad', input.requires_grad)
|
||||||
|
with torch.no_grad():
|
||||||
|
input = input.clone()
|
||||||
|
# Note: .conj() is not called under no_grad mode since it's not allowed to modify a
|
||||||
|
# view created in no_grad mode. Here it's ok to do so, so as a workaround we call conj
|
||||||
|
# before resetting the requires_grad field for input
|
||||||
|
input = input.conj()
|
||||||
|
assert input.is_leaf
|
||||||
|
return input.requires_grad_(requires_grad)
|
||||||
|
|
||||||
|
if isinstance(input, Sequence):
|
||||||
|
out = list(map(clone_input_helper, input))
|
||||||
|
out[0] = clone_conj_input_helper(out[0])
|
||||||
|
return tuple(out)
|
||||||
|
|
||||||
|
for sample in samples:
|
||||||
|
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
|
||||||
|
cloned1 = clone_conj_input_helper(sample.input)
|
||||||
|
sample.input = conjugate_physical(sample.input)
|
||||||
|
|
||||||
|
# Computes function forward value with a physically conjugated tensor and
|
||||||
|
# a conj view tensor and verifies that the output in both case are equal.
|
||||||
|
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
|
||||||
|
forward_with_conjview = op(cloned1, *sample.args, **sample.kwargs)
|
||||||
|
self.assertEqual(expected_forward, forward_with_conjview)
|
||||||
|
|
||||||
|
# If the op has an inplace variant, and the input doesn't require broadcasting
|
||||||
|
# and has the same dtype as output, verify that the inplace operation on a conjugated
|
||||||
|
# input produces correct output, and the output tensor has the conj bit set to True
|
||||||
|
if inplace_variant is not None and not sample.broadcasts_input:
|
||||||
|
cloned2 = clone_conj_input_helper(tensor, requires_grad=False)
|
||||||
|
if (isinstance(expected_forward, torch.Tensor) and
|
||||||
|
expected_forward.dtype is tensor.dtype):
|
||||||
|
inplace_forward = inplace_variant(cloned2, *sample.args, **sample.kwargs)
|
||||||
|
self.assertTrue(inplace_forward.is_conj())
|
||||||
|
self.assertEqual(inplace_forward, expected_forward)
|
||||||
|
|
||||||
|
# TODO: backward consistency only supported for single tensor outputs
|
||||||
|
# TODO: backward consistency only checked on sample.input, not all
|
||||||
|
# tensor inputs
|
||||||
|
# TODO: update to handle checking grads of all tensor inputs as
|
||||||
|
# derived from each tensor output
|
||||||
|
if isinstance(expected_forward, torch.Tensor) and expected_forward.requires_grad:
|
||||||
|
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
|
||||||
|
expected_forward.sum().backward(retain_graph=True)
|
||||||
|
forward_with_conjview.sum().backward(retain_graph=True)
|
||||||
|
if tensor.grad is not None:
|
||||||
|
cloned1_tensor = cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
|
||||||
|
self.assertEqual(tensor.grad, cloned1_tensor.grad)
|
||||||
|
|
||||||
|
tensor.grad, cloned1_tensor.grad = None, None
|
||||||
|
|
||||||
|
# a repeat of the above test, with an explicit complex gradient, if output is complex valued
|
||||||
|
if (expected_forward.is_complex()):
|
||||||
|
grad = torch.randn_like(expected_forward)
|
||||||
|
expected_forward.backward(grad.conj_physical())
|
||||||
|
forward_with_conjview.backward(grad.conj())
|
||||||
|
|
||||||
|
self.assertEqual(tensor.grad, cloned1_tensor.grad)
|
||||||
|
|
||||||
|
|
||||||
instantiate_device_type_tests(TestOpInfo, globals())
|
instantiate_device_type_tests(TestOpInfo, globals())
|
||||||
instantiate_device_type_tests(TestGradients, globals())
|
instantiate_device_type_tests(TestGradients, globals())
|
||||||
|
|
|
||||||
|
|
@ -346,6 +346,14 @@ class TestViewOps(TestCase):
|
||||||
self.assertEqual(a[5:].real, a.real[5:])
|
self.assertEqual(a[5:].real, a.real[5:])
|
||||||
self.assertEqual(a[5:].imag, a.imag[5:])
|
self.assertEqual(a[5:].imag, a.imag[5:])
|
||||||
|
|
||||||
|
@onlyOnCPUAndCUDA
|
||||||
|
@dtypes(*torch.testing.get_all_complex_dtypes())
|
||||||
|
def test_conj_view(self, device, dtype) -> None:
|
||||||
|
t = _make_tensor((4, 5,), dtype, device)
|
||||||
|
v = t.conj()
|
||||||
|
self.assertTrue(self.is_view_of(t, v))
|
||||||
|
self.assertEqual(v, torch.from_numpy(t.cpu().numpy().conj()).to(device=device))
|
||||||
|
|
||||||
@onlyOnCPUAndCUDA
|
@onlyOnCPUAndCUDA
|
||||||
@dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
|
@dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
|
||||||
@suppress_warnings
|
@suppress_warnings
|
||||||
|
|
|
||||||
|
|
@ -380,16 +380,23 @@
|
||||||
|
|
||||||
- name: complex(Tensor real, Tensor imag) -> Tensor
|
- name: complex(Tensor real, Tensor imag) -> Tensor
|
||||||
real: at::real(grad)
|
real: at::real(grad)
|
||||||
imag: at::imag(grad)
|
imag: at::imag(grad.resolve_conj())
|
||||||
result: at::complex(real_t, imag_t)
|
result: at::complex(real_t, imag_t)
|
||||||
|
|
||||||
- name: polar(Tensor abs, Tensor angle) -> Tensor
|
- name: polar(Tensor abs, Tensor angle) -> Tensor
|
||||||
abs, angle: polar_backward(grad, result)
|
abs, angle: polar_backward(grad, result)
|
||||||
|
|
||||||
- name: _conj(Tensor self) -> Tensor
|
- name: _conj(Tensor(a) self) -> Tensor(a)
|
||||||
self: grad.conj()
|
self: grad.conj()
|
||||||
result: self_t.conj()
|
result: self_t.conj()
|
||||||
|
|
||||||
|
- name: _conj_physical(Tensor self) -> Tensor
|
||||||
|
self: grad.conj_physical()
|
||||||
|
result: self_t.conj_physical()
|
||||||
|
|
||||||
|
- name: conj_physical_(Tensor(a!) self) -> Tensor(a!)
|
||||||
|
self: grad.conj_physical()
|
||||||
|
|
||||||
- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
|
- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
|
||||||
self: copysign_tensor_self_backward(grad, self, result)
|
self: copysign_tensor_self_backward(grad, self, result)
|
||||||
other: zeros_like(other)
|
other: zeros_like(other)
|
||||||
|
|
@ -1306,12 +1313,16 @@
|
||||||
- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
|
- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
|
||||||
output_differentiability: [False]
|
output_differentiability: [False]
|
||||||
|
|
||||||
|
# If the conj bit for self is set, this is effectively a fused conj() and view_as_real
|
||||||
|
- name: _view_as_real_physical(Tensor(a) self) -> Tensor(a)
|
||||||
|
self: "self.is_conj() ? at::view_as_complex(grad.contiguous()) : at::view_as_complex(grad.contiguous()).conj()"
|
||||||
|
|
||||||
- name: view_as_real(Tensor(a) self) -> Tensor(a)
|
- name: view_as_real(Tensor(a) self) -> Tensor(a)
|
||||||
self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1
|
self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1
|
||||||
result: at::view_as_real(self_t)
|
result: at::view_as_real(self_t)
|
||||||
|
|
||||||
- name: view_as_complex(Tensor(a) self) -> Tensor(a)
|
- name: view_as_complex(Tensor(a) self) -> Tensor(a)
|
||||||
self: at::view_as_real(grad.contiguous()) # [gx, gy]
|
self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy]
|
||||||
result: at::view_as_complex(self_t)
|
result: at::view_as_complex(self_t)
|
||||||
|
|
||||||
- name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
|
- name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ from .gen_trace_type import (
|
||||||
#
|
#
|
||||||
# A map: function name => name of the argument that all outputs are view of
|
# A map: function name => name of the argument that all outputs are view of
|
||||||
|
|
||||||
VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_real', 'view_as_complex']
|
VIEW_FUNCTIONS_WITH_METADATA_CHANGE = ['view_as_complex', '_view_as_real_physical', 'view_as_real', '_conj']
|
||||||
|
|
||||||
VIEW_FUNCTIONS = {
|
VIEW_FUNCTIONS = {
|
||||||
'numpy_T': 'self',
|
'numpy_T': 'self',
|
||||||
|
|
|
||||||
|
|
@ -100,7 +100,8 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
|
||||||
'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
|
'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
|
||||||
'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
|
'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
|
||||||
'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
|
'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
|
||||||
'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm',
|
'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_view_as_real_physical', '_conj_physical',
|
||||||
|
'conj_physical_'
|
||||||
}
|
}
|
||||||
|
|
||||||
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
|
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import yaml
|
||||||
from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
|
from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
|
||||||
SavedAttribute, ForwardDerivative)
|
SavedAttribute, ForwardDerivative)
|
||||||
from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, BaseCType, VectorCType,
|
from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, BaseCType, VectorCType,
|
||||||
intArrayRefT, tensorOptionsT, typeAndSizeT, intT,
|
intArrayRefT, tensorOptionsT, typeAndSizeT, intT, boolT,
|
||||||
tensorGeometryT, scalarTypeT, SpecialArgName,
|
tensorGeometryT, scalarTypeT, SpecialArgName,
|
||||||
OptionalCType, stringT)
|
OptionalCType, stringT)
|
||||||
from tools.codegen.api import cpp
|
from tools.codegen.api import cpp
|
||||||
|
|
@ -498,6 +498,11 @@ def saved_variables(
|
||||||
'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
|
'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
|
||||||
'expr': stride_expr,
|
'expr': stride_expr,
|
||||||
}),
|
}),
|
||||||
|
# replace self.is_conj() with self_conjugate
|
||||||
|
(r'{}.is_conj\(\)', {
|
||||||
|
'suffix': '_conjugate',
|
||||||
|
'nctype': lambda name: NamedCType(name, BaseCType(boolT)),
|
||||||
|
})
|
||||||
]
|
]
|
||||||
|
|
||||||
# find which arguments need to be saved
|
# find which arguments need to be saved
|
||||||
|
|
|
||||||
|
|
@ -86,8 +86,10 @@ class Tensor(torch._C._TensorBase):
|
||||||
self.requires_grad,
|
self.requires_grad,
|
||||||
self._backward_hooks)
|
self._backward_hooks)
|
||||||
else:
|
else:
|
||||||
new_tensor = self.new()
|
new_tensor = self.new_empty([])
|
||||||
new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
|
new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
|
||||||
|
if self.is_conj():
|
||||||
|
new_tensor = new_tensor.conj_physical()
|
||||||
new_tensor.requires_grad = self.requires_grad
|
new_tensor.requires_grad = self.requires_grad
|
||||||
if self.grad is not None:
|
if self.grad is not None:
|
||||||
new_tensor.grad = self.grad.__deepcopy__(memo)
|
new_tensor.grad = self.grad.__deepcopy__(memo)
|
||||||
|
|
|
||||||
|
|
@ -893,6 +893,27 @@ conj() -> Tensor
|
||||||
See :func:`torch.conj`
|
See :func:`torch.conj`
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
add_docstr_all('conj_physical',
|
||||||
|
r"""
|
||||||
|
conj_physical() -> Tensor
|
||||||
|
|
||||||
|
See :func:`torch.conj_physical`
|
||||||
|
""")
|
||||||
|
|
||||||
|
add_docstr_all('conj_physical_',
|
||||||
|
r"""
|
||||||
|
conj_physical_() -> Tensor
|
||||||
|
|
||||||
|
In-place version of :meth:`~Tensor.conj_physical`
|
||||||
|
""")
|
||||||
|
|
||||||
|
add_docstr_all('resolve_conj',
|
||||||
|
r"""
|
||||||
|
resolve_conj() -> Tensor
|
||||||
|
|
||||||
|
See :func:`torch.resolve_conj`
|
||||||
|
""")
|
||||||
|
|
||||||
add_docstr_all('copysign',
|
add_docstr_all('copysign',
|
||||||
r"""
|
r"""
|
||||||
copysign(other) -> Tensor
|
copysign(other) -> Tensor
|
||||||
|
|
@ -1977,6 +1998,13 @@ is_inference() -> bool
|
||||||
See :func:`torch.is_inference`
|
See :func:`torch.is_inference`
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
add_docstr_all('is_conj',
|
||||||
|
r"""
|
||||||
|
is_conj() -> bool
|
||||||
|
|
||||||
|
Returns True if the conjugate bit of :attr:`self` is set to true.
|
||||||
|
""")
|
||||||
|
|
||||||
add_docstr_all('is_signed',
|
add_docstr_all('is_signed',
|
||||||
r"""
|
r"""
|
||||||
is_signed() -> bool
|
is_signed() -> bool
|
||||||
|
|
|
||||||
|
|
@ -153,11 +153,12 @@ class _Formatter(object):
|
||||||
def _scalar_str(self, formatter1, formatter2=None):
|
def _scalar_str(self, formatter1, formatter2=None):
|
||||||
if formatter2 is not None:
|
if formatter2 is not None:
|
||||||
real_str = _scalar_str(self.real, formatter1)
|
real_str = _scalar_str(self.real, formatter1)
|
||||||
imag_str = _scalar_str(self.imag, formatter2) + "j"
|
imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
|
||||||
if self.imag < 0:
|
# handles negative numbers, +0.0, -0.0
|
||||||
return real_str + imag_str.lstrip()
|
if imag_str[0] == '+' or imag_str[0] == '-':
|
||||||
|
return real_str + imag_str
|
||||||
else:
|
else:
|
||||||
return real_str + "+" + imag_str.lstrip()
|
return real_str + "+" + imag_str
|
||||||
else:
|
else:
|
||||||
return formatter1.format(self.item())
|
return formatter1.format(self.item())
|
||||||
|
|
||||||
|
|
@ -174,11 +175,12 @@ def _vector_str(self, indent, summarize, formatter1, formatter2=None):
|
||||||
def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
|
def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
|
||||||
if formatter2 is not None:
|
if formatter2 is not None:
|
||||||
real_str = formatter1.format(val.real)
|
real_str = formatter1.format(val.real)
|
||||||
imag_str = formatter2.format(val.imag) + "j"
|
imag_str = (formatter2.format(val.imag) + "j").lstrip()
|
||||||
if val.imag < 0:
|
# handles negative numbers, +0.0, -0.0
|
||||||
return real_str + imag_str.lstrip()
|
if imag_str[0] == '+' or imag_str[0] == '-':
|
||||||
|
return real_str + imag_str
|
||||||
else:
|
else:
|
||||||
return real_str + "+" + imag_str.lstrip()
|
return real_str + "+" + imag_str
|
||||||
else:
|
else:
|
||||||
return formatter1.format(val)
|
return formatter1.format(val)
|
||||||
|
|
||||||
|
|
@ -235,6 +237,7 @@ def _tensor_str(self, indent):
|
||||||
self = self.float()
|
self = self.float()
|
||||||
|
|
||||||
if self.dtype.is_complex:
|
if self.dtype.is_complex:
|
||||||
|
self = self.resolve_conj()
|
||||||
real_formatter = _Formatter(get_summarized_data(self.real) if summarize else self.real)
|
real_formatter = _Formatter(get_summarized_data(self.real) if summarize else self.real)
|
||||||
imag_formatter = _Formatter(get_summarized_data(self.imag) if summarize else self.imag)
|
imag_formatter = _Formatter(get_summarized_data(self.imag) if summarize else self.imag)
|
||||||
return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter)
|
return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter)
|
||||||
|
|
|
||||||
|
|
@ -2181,15 +2181,18 @@ Example::
|
||||||
tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128)
|
tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
add_docstr(torch.conj,
|
add_docstr(torch.conj_physical,
|
||||||
r"""
|
r"""
|
||||||
conj(input, *, out=None) -> Tensor
|
conj_physical(input, *, out=None) -> Tensor
|
||||||
|
|
||||||
Computes the element-wise conjugate of the given :attr:`input` tensor. If :attr:`input` has a non-complex dtype,
|
Computes the element-wise conjugate of the given :attr:`input` tensor.
|
||||||
this function just returns :attr:`input`.
|
If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`.
|
||||||
|
|
||||||
.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
|
.. note::
|
||||||
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj`
|
This performs the conjugate operation regardless of whether the conjugate bit is set or not.
|
||||||
|
|
||||||
|
.. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of
|
||||||
|
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical`
|
||||||
when :attr:`input` is of non-complex dtype to be compatible with this change.
|
when :attr:`input` is of non-complex dtype to be compatible with this change.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
|
|
@ -2203,10 +2206,63 @@ Keyword args:
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> torch.conj(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))
|
>>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))
|
||||||
tensor([-1 - 1j, -2 - 2j, 3 + 3j])
|
tensor([-1 - 1j, -2 - 2j, 3 + 3j])
|
||||||
""".format(**common_args))
|
""".format(**common_args))
|
||||||
|
|
||||||
|
add_docstr(torch.conj,
|
||||||
|
r"""
|
||||||
|
conj(input) -> Tensor
|
||||||
|
|
||||||
|
Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype,
|
||||||
|
this function just returns :attr:`input`.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
:func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized
|
||||||
|
at any time using :func:`torch.resolve_conj`.
|
||||||
|
|
||||||
|
.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
|
||||||
|
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj`
|
||||||
|
when :attr:`input` is of non-complex dtype to be compatible with this change.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
{input}
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
|
||||||
|
>>> x.is_conj()
|
||||||
|
False
|
||||||
|
>>> y = torch.conj(x)
|
||||||
|
>>> y.is_conj()
|
||||||
|
True
|
||||||
|
|
||||||
|
""")
|
||||||
|
|
||||||
|
add_docstr(torch.resolve_conj,
|
||||||
|
r"""
|
||||||
|
resolve_conj(input) -> Tensor
|
||||||
|
|
||||||
|
Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`,
|
||||||
|
else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
{input}
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
|
||||||
|
>>> y = x.conj()
|
||||||
|
>>> y.is_conj()
|
||||||
|
True
|
||||||
|
>>> z = y.resolve_conj()
|
||||||
|
>>> z
|
||||||
|
tensor([-1 - 1j, -2 - 2j, 3 + 3j])
|
||||||
|
>>> z.is_conj()
|
||||||
|
False
|
||||||
|
|
||||||
|
""")
|
||||||
|
|
||||||
add_docstr(torch.copysign,
|
add_docstr(torch.copysign,
|
||||||
r"""
|
r"""
|
||||||
copysign(input, other, *, out=None) -> Tensor
|
copysign(input, other, *, out=None) -> Tensor
|
||||||
|
|
@ -4201,6 +4257,15 @@ Args:
|
||||||
{input}
|
{input}
|
||||||
""".format(**common_args))
|
""".format(**common_args))
|
||||||
|
|
||||||
|
add_docstr(torch.is_conj, r"""
|
||||||
|
is_conj(input) -> (bool)
|
||||||
|
|
||||||
|
Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
{input}
|
||||||
|
""".format(**common_args))
|
||||||
|
|
||||||
add_docstr(torch.is_nonzero, r"""
|
add_docstr(torch.is_nonzero, r"""
|
||||||
is_nonzero(input) -> (bool)
|
is_nonzero(input) -> (bool)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -545,7 +545,7 @@ def _get_analytical_vJu_backward_mode(inputs, outputs, nondet_tol, check_grad_dt
|
||||||
# the error checking logic from slow mode
|
# the error checking logic from slow mode
|
||||||
vJ = vJ.T.squeeze(0)
|
vJ = vJ.T.squeeze(0)
|
||||||
if vJ.is_complex(): # C -> R
|
if vJ.is_complex(): # C -> R
|
||||||
tv = torch.view_as_real(vJ)
|
tv = torch.view_as_real(vJ.resolve_conj())
|
||||||
tr = tv.select(-1, 0)
|
tr = tv.select(-1, 0)
|
||||||
ti = tv.select(-1, 1)
|
ti = tv.select(-1, 1)
|
||||||
jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
|
jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
|
||||||
|
|
@ -894,7 +894,11 @@ def _real_and_imag_output(fn):
|
||||||
outs = _as_tuple(fn(*inputs))
|
outs = _as_tuple(fn(*inputs))
|
||||||
return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)
|
return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)
|
||||||
return wrapped_fn
|
return wrapped_fn
|
||||||
return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)
|
|
||||||
|
# TODO(@anjali411): remove this workaround once neg bit is added.
|
||||||
|
def torch_imag(x):
|
||||||
|
return x.resolve_conj().imag
|
||||||
|
return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch_imag)
|
||||||
|
|
||||||
def _real_and_imag_input(fn, complex_inp_indices):
|
def _real_and_imag_input(fn, complex_inp_indices):
|
||||||
# returns new functions that take real inputs instead of complex inputs and compute fn(x + 0 * 1j)
|
# returns new functions that take real inputs instead of complex inputs and compute fn(x + 0 * 1j)
|
||||||
|
|
@ -935,7 +939,7 @@ def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, e
|
||||||
if complex_inp_indices:
|
if complex_inp_indices:
|
||||||
real_fn, imag_fn = _real_and_imag_input(func, complex_inp_indices)
|
real_fn, imag_fn = _real_and_imag_input(func, complex_inp_indices)
|
||||||
|
|
||||||
imag_inputs = [inp.imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs]
|
imag_inputs = [inp.resolve_conj().imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs]
|
||||||
imag_func_out = imag_fn(*imag_inputs)
|
imag_func_out = imag_fn(*imag_inputs)
|
||||||
diff_imag_func_out = _differentiable_outputs(imag_func_out)
|
diff_imag_func_out = _differentiable_outputs(imag_func_out)
|
||||||
gradcheck_fn(imag_fn, imag_func_out, imag_inputs, diff_imag_func_out, eps,
|
gradcheck_fn(imag_fn, imag_func_out, imag_inputs, diff_imag_func_out, eps,
|
||||||
|
|
|
||||||
|
|
@ -248,7 +248,7 @@ Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & expone
|
||||||
Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) {
|
Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) {
|
||||||
Tensor cond;
|
Tensor cond;
|
||||||
if (exponent.is_complex()) {
|
if (exponent.is_complex()) {
|
||||||
auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
|
auto is_real_exp = at::logical_and(at::imag(exponent.resolve_conj()) == 0, at::real(exponent) >= 0);
|
||||||
cond = at::logical_and(self == 0, is_real_exp);
|
cond = at::logical_and(self == 0, is_real_exp);
|
||||||
} else {
|
} else {
|
||||||
cond = at::logical_and(self == 0, exponent >= 0);
|
cond = at::logical_and(self == 0, exponent >= 0);
|
||||||
|
|
@ -264,7 +264,7 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp
|
||||||
if (base.equal(0.0)) {
|
if (base.equal(0.0)) {
|
||||||
auto cond = [](auto exp) {
|
auto cond = [](auto exp) {
|
||||||
if (exp.is_complex()) {
|
if (exp.is_complex()) {
|
||||||
return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
|
return at::logical_and(at::imag(exp.resolve_conj()) == 0, at::real(exp) >= 0);
|
||||||
} else {
|
} else {
|
||||||
return exp >=0;
|
return exp >=0;
|
||||||
}
|
}
|
||||||
|
|
@ -2154,7 +2154,7 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
|
||||||
if (self.is_complex() && gu.defined()) {
|
if (self.is_complex() && gu.defined()) {
|
||||||
Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1);
|
Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1);
|
||||||
at::real(L).zero_();
|
at::real(L).zero_();
|
||||||
at::imag(L).mul_(sigma_inv);
|
at::imag(L.resolve_conj()).mul_(sigma_inv);
|
||||||
Tensor imag_term = at::matmul(u * L.unsqueeze(-2), vh);
|
Tensor imag_term = at::matmul(u * L.unsqueeze(-2), vh);
|
||||||
return u_term + sigma_term + v_term + imag_term;
|
return u_term + sigma_term + v_term + imag_term;
|
||||||
}
|
}
|
||||||
|
|
@ -2195,7 +2195,7 @@ Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const T
|
||||||
}
|
}
|
||||||
// path for torch.linalg.eig with always a complex tensor of eigenvalues
|
// path for torch.linalg.eig with always a complex tensor of eigenvalues
|
||||||
else {
|
else {
|
||||||
is_imag_eigvals_zero = (at::imag(D) == 0.0).min().item<bool>();
|
is_imag_eigvals_zero = (at::imag(D.resolve_conj()) == 0.0).min().item<bool>();
|
||||||
// insert an additional dimension to be compatible with torch.eig.
|
// insert an additional dimension to be compatible with torch.eig.
|
||||||
// Recall that it produces 2D tensors.
|
// Recall that it produces 2D tensors.
|
||||||
// We extract only the real parts as there is no support for
|
// We extract only the real parts as there is no support for
|
||||||
|
|
|
||||||
|
|
@ -83,23 +83,6 @@ void nnc_aten_sgn(
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void nnc_aten_conj(
|
|
||||||
int64_t bufs_num,
|
|
||||||
void** buf_data,
|
|
||||||
int64_t* buf_ranks,
|
|
||||||
int64_t* buf_dims,
|
|
||||||
int8_t* buf_dtypes,
|
|
||||||
int64_t args_num,
|
|
||||||
int64_t* extra_args) {
|
|
||||||
std::vector<at::Tensor> tensors =
|
|
||||||
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
|
||||||
at::Tensor& r = tensors[0];
|
|
||||||
const at::Tensor& self = tensors[1];
|
|
||||||
try {
|
|
||||||
at::conj_out(r, self);
|
|
||||||
} catch (...) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
void nnc_aten_acos(
|
void nnc_aten_acos(
|
||||||
int64_t bufs_num,
|
int64_t bufs_num,
|
||||||
void** buf_data,
|
void** buf_data,
|
||||||
|
|
@ -2758,9 +2741,6 @@ const static RegisterNNCExternalFunction nnc_angle(
|
||||||
"nnc_aten_angle",
|
"nnc_aten_angle",
|
||||||
nnc_aten_angle);
|
nnc_aten_angle);
|
||||||
const static RegisterNNCExternalFunction nnc_sgn("nnc_aten_sgn", nnc_aten_sgn);
|
const static RegisterNNCExternalFunction nnc_sgn("nnc_aten_sgn", nnc_aten_sgn);
|
||||||
const static RegisterNNCExternalFunction nnc_conj(
|
|
||||||
"nnc_aten_conj",
|
|
||||||
nnc_aten_conj);
|
|
||||||
const static RegisterNNCExternalFunction nnc_acos(
|
const static RegisterNNCExternalFunction nnc_acos(
|
||||||
"nnc_aten_acos",
|
"nnc_aten_acos",
|
||||||
nnc_aten_acos);
|
nnc_aten_acos);
|
||||||
|
|
|
||||||
|
|
@ -228,6 +228,8 @@ def get_ignored_functions() -> Set[Callable]:
|
||||||
Tensor.to_sparse_csr,
|
Tensor.to_sparse_csr,
|
||||||
Tensor._reduce_ex_internal,
|
Tensor._reduce_ex_internal,
|
||||||
Tensor._fix_weakref,
|
Tensor._fix_weakref,
|
||||||
|
Tensor._conj,
|
||||||
|
Tensor._conj_physical,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -349,6 +351,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.polar: lambda abs, ang: -1,
|
torch.polar: lambda abs, ang: -1,
|
||||||
torch.linalg.cond: lambda input, ord=None: -1,
|
torch.linalg.cond: lambda input, ord=None: -1,
|
||||||
torch.conj: lambda input, out=None: -1,
|
torch.conj: lambda input, out=None: -1,
|
||||||
|
torch.conj_physical: lambda input, out=None: -1,
|
||||||
|
torch.resolve_conj: lambda input, out=None: -1,
|
||||||
torch.constant_pad_nd: lambda input, pad, value=0: -1,
|
torch.constant_pad_nd: lambda input, pad, value=0: -1,
|
||||||
torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
|
torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
|
||||||
torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
|
torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
|
||||||
|
|
@ -500,6 +504,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.linalg.inv: lambda input, out=None: -1,
|
torch.linalg.inv: lambda input, out=None: -1,
|
||||||
torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
|
torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
|
||||||
torch.is_complex: lambda input: -1,
|
torch.is_complex: lambda input: -1,
|
||||||
|
torch.is_conj: lambda input: -1,
|
||||||
torch.is_distributed: lambda input: -1,
|
torch.is_distributed: lambda input: -1,
|
||||||
torch.is_inference: lambda input: -1,
|
torch.is_inference: lambda input: -1,
|
||||||
torch.is_floating_point: lambda input: -1,
|
torch.is_floating_point: lambda input: -1,
|
||||||
|
|
@ -797,6 +802,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.real: lambda input, out=None: -1,
|
torch.real: lambda input, out=None: -1,
|
||||||
torch.vdot: lambda input, other, out=None: -1,
|
torch.vdot: lambda input, other, out=None: -1,
|
||||||
torch.view_as_real: lambda input: -1,
|
torch.view_as_real: lambda input: -1,
|
||||||
|
torch._view_as_real_physical: lambda input: -1,
|
||||||
torch.view_as_complex: lambda input: -1,
|
torch.view_as_complex: lambda input: -1,
|
||||||
torch.reciprocal: lambda input, out=None: -1,
|
torch.reciprocal: lambda input, out=None: -1,
|
||||||
torch.relu: lambda input, inplace=False: -1,
|
torch.relu: lambda input, inplace=False: -1,
|
||||||
|
|
|
||||||
|
|
@ -134,6 +134,8 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e
|
||||||
# Compares complex tensors' real and imaginary parts separately.
|
# Compares complex tensors' real and imaginary parts separately.
|
||||||
# (see NOTE Test Framework Tensor "Equality")
|
# (see NOTE Test Framework Tensor "Equality")
|
||||||
if a.is_complex():
|
if a.is_complex():
|
||||||
|
a = a.resolve_conj()
|
||||||
|
b = b.resolve_conj()
|
||||||
if equal_nan == "relaxed":
|
if equal_nan == "relaxed":
|
||||||
a = a.clone()
|
a = a.clone()
|
||||||
b = b.clone()
|
b = b.clone()
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ import random
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch._six import inf
|
from torch._six import inf
|
||||||
from torch.autograd import Variable
|
|
||||||
import collections.abc
|
import collections.abc
|
||||||
|
|
||||||
from typing import List, Sequence, Tuple, Dict, Any, Union
|
from typing import List, Sequence, Tuple, Dict, Any, Union
|
||||||
|
|
@ -231,6 +230,7 @@ class OpInfo(object):
|
||||||
# function around gradcheck (testing._internal.common_utils.gradcheck)
|
# function around gradcheck (testing._internal.common_utils.gradcheck)
|
||||||
inplace_variant=_NOTHING, # explicitly pass the inplace variant of the operator if required
|
inplace_variant=_NOTHING, # explicitly pass the inplace variant of the operator if required
|
||||||
method_variant=_NOTHING, # explicitly pass the method variant of the operator if required
|
method_variant=_NOTHING, # explicitly pass the method variant of the operator if required
|
||||||
|
test_conjugated_samples=True,
|
||||||
):
|
):
|
||||||
|
|
||||||
# Validates the dtypes are generated from the dispatch-related functions
|
# Validates the dtypes are generated from the dispatch-related functions
|
||||||
|
|
@ -300,6 +300,8 @@ class OpInfo(object):
|
||||||
if aliases is not None:
|
if aliases is not None:
|
||||||
self.aliases = tuple(AliasInfo(a) for a in aliases) # type: ignore[assignment]
|
self.aliases = tuple(AliasInfo(a) for a in aliases) # type: ignore[assignment]
|
||||||
|
|
||||||
|
self.test_conjugated_samples = test_conjugated_samples
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
"""Calls the function variant of the operator."""
|
"""Calls the function variant of the operator."""
|
||||||
return self.op(*args, **kwargs)
|
return self.op(*args, **kwargs)
|
||||||
|
|
@ -326,6 +328,37 @@ class OpInfo(object):
|
||||||
"""
|
"""
|
||||||
return self.operator_variant
|
return self.operator_variant
|
||||||
|
|
||||||
|
def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
|
||||||
|
"""Returns an iterable of SampleInputs but with the tensor input or first
|
||||||
|
tensor in a sequence input conjugated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: Remove the try/except once all operators have sample_inputs_func with
|
||||||
|
# **kwargs in their signature.
|
||||||
|
try:
|
||||||
|
samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
|
||||||
|
except TypeError:
|
||||||
|
samples = self.sample_inputs_func(self, device, dtype, requires_grad)
|
||||||
|
|
||||||
|
conj_samples = list(samples)
|
||||||
|
|
||||||
|
def conjugate(tensor):
|
||||||
|
_requires_grad = tensor.requires_grad
|
||||||
|
with torch.no_grad():
|
||||||
|
tensor = tensor.conj()
|
||||||
|
return tensor.requires_grad_(_requires_grad)
|
||||||
|
|
||||||
|
for i in range(len(samples)):
|
||||||
|
sample = conj_samples[i]
|
||||||
|
# Note: it is assumed that the input here is either a tensor or tensorlist
|
||||||
|
if isinstance(sample.input, torch.Tensor):
|
||||||
|
sample.input = conjugate(sample.input)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
sample.input[0] = conjugate(sample.input[0])
|
||||||
|
|
||||||
|
return tuple(conj_samples)
|
||||||
|
|
||||||
def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
|
def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
|
||||||
"""Returns an iterable of SampleInputs.
|
"""Returns an iterable of SampleInputs.
|
||||||
|
|
||||||
|
|
@ -339,6 +372,13 @@ class OpInfo(object):
|
||||||
samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
|
samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
samples = self.sample_inputs_func(self, device, dtype, requires_grad)
|
samples = self.sample_inputs_func(self, device, dtype, requires_grad)
|
||||||
|
|
||||||
|
if 'include_conjugated_inputs' in kwargs and kwargs.get('include_conjugated_inputs'):
|
||||||
|
conj_samples = self.conjugate_sample_inputs(device, dtype, requires_grad, **kwargs)
|
||||||
|
samples_list = list(samples)
|
||||||
|
samples_list.extend(conj_samples)
|
||||||
|
samples = tuple(samples_list)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
# Returns True if the test should be skipped and False otherwise
|
# Returns True if the test should be skipped and False otherwise
|
||||||
|
|
@ -1001,21 +1041,21 @@ def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs):
|
||||||
return tuple(sample_inputs)
|
return tuple(sample_inputs)
|
||||||
|
|
||||||
def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
|
def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
|
||||||
test_cases = [((S, S), (S, S), (S, S)),
|
test_cases = [(((S, S), (S, S), (S, S)), False),
|
||||||
((S, S), (S, 1), (1, S)),
|
(((S, S), (S, 1), (1, S)), False),
|
||||||
((1,), (S, S, 1), (1, S)),
|
(((1,), (S, S, 1), (1, S)), True),
|
||||||
((), (), ()),
|
(((), (), ()), False),
|
||||||
((S, S), (), ()),
|
(((S, S), (), ()), True),
|
||||||
((), (S, S, 1), (1, S)),
|
(((), (S, S, 1), (1, S)), True)
|
||||||
]
|
]
|
||||||
|
|
||||||
sample_inputs = []
|
sample_inputs = []
|
||||||
for input_args in test_cases:
|
for input_args, broadcasts_input in test_cases:
|
||||||
args = tuple(make_tensor(arg, device, dtype, requires_grad=requires_grad) if isinstance(arg, tuple) else arg
|
args = tuple(make_tensor(arg, device, dtype, requires_grad=requires_grad) if isinstance(arg, tuple) else arg
|
||||||
for arg in input_args)
|
for arg in input_args)
|
||||||
sample_inputs.append(SampleInput(args[0], args=args[1:]))
|
sample_inputs.append(SampleInput(args[0], args=args[1:], broadcasts_input=broadcasts_input))
|
||||||
|
|
||||||
sample_inputs.append(SampleInput(args[0], args=args[1:], kwargs=dict(value=3.14)))
|
sample_inputs.append(SampleInput(args[0], args=args[1:], kwargs=dict(value=3.14), broadcasts_input=broadcasts_input))
|
||||||
|
|
||||||
return tuple(sample_inputs)
|
return tuple(sample_inputs)
|
||||||
|
|
||||||
|
|
@ -1057,7 +1097,8 @@ def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
|
||||||
make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
|
make_tensor((), device, dtype, low=None, high=None, requires_grad=requires_grad),
|
||||||
args=(
|
args=(
|
||||||
make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
|
make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
|
||||||
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)))
|
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
|
||||||
|
broadcasts_input=True)
|
||||||
|
|
||||||
if dtype.is_complex:
|
if dtype.is_complex:
|
||||||
alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j
|
alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j
|
||||||
|
|
@ -1078,7 +1119,8 @@ def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
|
||||||
args=(
|
args=(
|
||||||
make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
|
make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
|
||||||
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
|
make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)),
|
||||||
kwargs=dict(beta=beta, alpha=alpha))
|
kwargs=dict(beta=beta, alpha=alpha),
|
||||||
|
broadcasts_input=True)
|
||||||
|
|
||||||
return (input1, input2, input3, input4)
|
return (input1, input2, input3, input4)
|
||||||
|
|
||||||
|
|
@ -2196,36 +2238,7 @@ def np_unary_ufunc_integer_promotion_wrapper(fn):
|
||||||
|
|
||||||
return wrapped_fn
|
return wrapped_fn
|
||||||
|
|
||||||
|
def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
|
||||||
# Metadata class for Fast Fourier Transforms in torch.fft.
|
|
||||||
class SpectralFuncInfo(OpInfo):
|
|
||||||
"""Operator information for torch.fft transforms. """
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
name, # the string name of the function
|
|
||||||
*,
|
|
||||||
ref=None, # Reference implementation (probably in np.fft namespace)
|
|
||||||
dtypes=floating_and_complex_types(),
|
|
||||||
ndimensional: bool, # Whether dim argument can be a tuple
|
|
||||||
decorators=None,
|
|
||||||
**kwargs):
|
|
||||||
decorators = list(decorators) if decorators is not None else []
|
|
||||||
decorators += [
|
|
||||||
skipCPUIfNoMkl,
|
|
||||||
skipCUDAIfRocm,
|
|
||||||
# gradgrad is quite slow
|
|
||||||
DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'),
|
|
||||||
]
|
|
||||||
|
|
||||||
super().__init__(name=name,
|
|
||||||
dtypes=dtypes,
|
|
||||||
decorators=decorators,
|
|
||||||
**kwargs)
|
|
||||||
self.ref = ref if ref is not None else _getattr_qual(np, name)
|
|
||||||
self.ndimensional = ndimensional
|
|
||||||
|
|
||||||
|
|
||||||
def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
|
|
||||||
nd_tensor = make_tensor((S, S + 1, S + 2), device, dtype, low=None, high=None,
|
nd_tensor = make_tensor((S, S + 1, S + 2), device, dtype, low=None, high=None,
|
||||||
requires_grad=requires_grad)
|
requires_grad=requires_grad)
|
||||||
tensor = make_tensor((31,), device, dtype, low=None, high=None,
|
tensor = make_tensor((31,), device, dtype, low=None, high=None,
|
||||||
|
|
@ -2252,6 +2265,35 @@ class SpectralFuncInfo(OpInfo):
|
||||||
for dim in [-1, -2, -3]),
|
for dim in [-1, -2, -3]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Metadata class for Fast Fourier Transforms in torch.fft.
|
||||||
|
class SpectralFuncInfo(OpInfo):
|
||||||
|
"""Operator information for torch.fft transforms. """
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
name, # the string name of the function
|
||||||
|
*,
|
||||||
|
ref=None, # Reference implementation (probably in np.fft namespace)
|
||||||
|
dtypes=floating_and_complex_types(),
|
||||||
|
ndimensional: bool, # Whether dim argument can be a tuple
|
||||||
|
sample_inputs_func=sample_inputs_spectral_ops,
|
||||||
|
decorators=None,
|
||||||
|
**kwargs):
|
||||||
|
decorators = list(decorators) if decorators is not None else []
|
||||||
|
decorators += [
|
||||||
|
skipCPUIfNoMkl,
|
||||||
|
skipCUDAIfRocm,
|
||||||
|
# gradgrad is quite slow
|
||||||
|
DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'),
|
||||||
|
]
|
||||||
|
|
||||||
|
super().__init__(name=name,
|
||||||
|
dtypes=dtypes,
|
||||||
|
decorators=decorators,
|
||||||
|
sample_inputs_func=sample_inputs_func,
|
||||||
|
**kwargs)
|
||||||
|
self.ref = ref if ref is not None else _getattr_qual(np, name)
|
||||||
|
self.ndimensional = ndimensional
|
||||||
|
|
||||||
|
|
||||||
class ShapeFuncInfo(OpInfo):
|
class ShapeFuncInfo(OpInfo):
|
||||||
"""Early version of a specialized OpInfo for Shape manipulating operations like tile and roll"""
|
"""Early version of a specialized OpInfo for Shape manipulating operations like tile and roll"""
|
||||||
|
|
@ -2737,13 +2779,13 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
|
||||||
test_cases = (
|
test_cases = (
|
||||||
((2, 2), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, False),
|
((2, 2), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, False),
|
||||||
((2, 2), 0, 5, 1e-3, requires_grad, (1,), 0, 1, 0.1, requires_grad, False),
|
((2, 2), 0, 5, 1e-3, requires_grad, (1,), 0, 1, 0.1, requires_grad, False),
|
||||||
((), 1e-3, 1e-3 + 1, 0, True, (), 0.1, 1.1, 0, False, False),
|
((), 1e-3, 1e-3 + 1, 0, requires_grad, (), 0.1, 1.1, 0, False, False),
|
||||||
((2, 2), 0, 5, 1e-3, requires_grad, (), 0.1, 1.1, 1, False, False),
|
((2, 2), 0, 5, 1e-3, requires_grad, (), 0.1, 1.1, 1, False, False),
|
||||||
)
|
)
|
||||||
tests_require_resizing = (
|
tests_require_resizing = (
|
||||||
((1,), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, True),
|
((1,), 0, 5, 1e-3, requires_grad, (2, 2), 0, 1, 0.1, requires_grad, requires_grad),
|
||||||
((2, 1, 2), 0, 5, 1e-3, requires_grad, (1, 2, 1), 0, 1, 0.1, requires_grad, True),
|
((2, 1, 2), 0, 5, 1e-3, requires_grad, (1, 2, 1), 0, 1, 0.1, requires_grad, requires_grad),
|
||||||
((), 1e-3, 1e-3 + 1, 0, True, (1, S, 1), 0, 1, 0.1, requires_grad, True),
|
((), 1e-3, 1e-3 + 1, 0, requires_grad, (1, S, 1), 0, 1, 0.1, requires_grad, requires_grad),
|
||||||
)
|
)
|
||||||
cases = test_cases + tests_require_resizing
|
cases = test_cases + tests_require_resizing
|
||||||
samples = list(SampleInput(make_tensor(shape_b, low=low_b, high=high_b,
|
samples = list(SampleInput(make_tensor(shape_b, low=low_b, high=high_b,
|
||||||
|
|
@ -2757,7 +2799,7 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
|
||||||
high_e, additive_e, e_grad, broadcasts_input in cases)
|
high_e, additive_e, e_grad, broadcasts_input in cases)
|
||||||
tensor_scalar_inputs = (
|
tensor_scalar_inputs = (
|
||||||
((2, 2), 0, 5, 1e-3, requires_grad, (3.14,)),
|
((2, 2), 0, 5, 1e-3, requires_grad, (3.14,)),
|
||||||
((), 1e-3, 1e-3 + 1, 0, True, (3.14,))
|
((), 1e-3, 1e-3 + 1, 0, requires_grad, (3.14,))
|
||||||
)
|
)
|
||||||
more_samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device,
|
more_samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device,
|
||||||
high=high, low=low,
|
high=high, low=low,
|
||||||
|
|
@ -2768,8 +2810,8 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
|
||||||
elif dtype in [torch.complex64, torch.complex128]:
|
elif dtype in [torch.complex64, torch.complex128]:
|
||||||
args_tuple = (
|
args_tuple = (
|
||||||
((2, 2), 0, 5, requires_grad, (3.14,)),
|
((2, 2), 0, 5, requires_grad, (3.14,)),
|
||||||
((), 0, 1, True, (3.14,)),
|
((), 0, 1, requires_grad, (3.14,)),
|
||||||
((), 0, 1, True, (3.14j,))
|
((), 0, 1, requires_grad, (3.14j,))
|
||||||
)
|
)
|
||||||
samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device,
|
samples = list(SampleInput(make_tensor(shape, dtype=dtype, device=device,
|
||||||
high=high, low=low,
|
high=high, low=low,
|
||||||
|
|
@ -4298,7 +4340,7 @@ op_db: List[OpInfo] = [
|
||||||
SkipInfo('TestGradients', 'test_method_grad',
|
SkipInfo('TestGradients', 'test_method_grad',
|
||||||
device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
|
device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
|
||||||
SkipInfo('TestGradients', 'test_forward_mode_AD',
|
SkipInfo('TestGradients', 'test_forward_mode_AD',
|
||||||
dtypes=[torch.cdouble], active_if=IS_WINDOWS),
|
dtypes=[torch.cdouble]),
|
||||||
)),
|
)),
|
||||||
OpInfo('add',
|
OpInfo('add',
|
||||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
|
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
|
||||||
|
|
@ -4670,28 +4712,30 @@ op_db: List[OpInfo] = [
|
||||||
supports_forward_ad=True,
|
supports_forward_ad=True,
|
||||||
),
|
),
|
||||||
UnaryUfuncInfo('conj',
|
UnaryUfuncInfo('conj',
|
||||||
|
ref=np.conj,
|
||||||
|
dtypes=all_types_and_complex_and(torch.bool,
|
||||||
|
torch.bfloat16, torch.half),
|
||||||
|
supports_forward_ad=True,
|
||||||
|
supports_out=False),
|
||||||
|
UnaryUfuncInfo('conj_physical',
|
||||||
ref=np.conj,
|
ref=np.conj,
|
||||||
dtypes=all_types_and_complex_and(torch.bool,
|
dtypes=all_types_and_complex_and(torch.bool,
|
||||||
torch.bfloat16, torch.half),
|
torch.bfloat16, torch.half),
|
||||||
supports_forward_ad=True,
|
supports_forward_ad=True,
|
||||||
skips=(
|
skips=(
|
||||||
# File "test_unary_ufuncs.py", line 289, in test_reference_numerics
|
SkipInfo('TestCommon', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
|
||||||
# if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype):
|
|
||||||
# KeyError: <class 'numpy.intc'>
|
|
||||||
# Following error in Windows CI
|
|
||||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
|
|
||||||
dtypes=[torch.int],
|
|
||||||
active_if=IS_WINDOWS),
|
|
||||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
|
|
||||||
dtypes=[torch.int],
|
|
||||||
active_if=IS_WINDOWS),
|
|
||||||
# TODO fix the formula for complex forward AD
|
|
||||||
SkipInfo('TestGradients', 'test_forward_mode_AD'),
|
|
||||||
)),
|
)),
|
||||||
|
OpInfo('resolve_conj',
|
||||||
|
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
|
||||||
|
sample_inputs_func=sample_inputs_view_as_real,
|
||||||
|
supports_forward_ad=True,
|
||||||
|
supports_out=False,
|
||||||
|
),
|
||||||
OpInfo('view_as_real',
|
OpInfo('view_as_real',
|
||||||
dtypes=complex_types(),
|
dtypes=complex_types(),
|
||||||
supports_forward_ad=True,
|
supports_forward_ad=True,
|
||||||
sample_inputs_func=sample_inputs_view_as_real,
|
sample_inputs_func=sample_inputs_view_as_real,
|
||||||
|
test_conjugated_samples=False,
|
||||||
),
|
),
|
||||||
OpInfo('view_as_complex',
|
OpInfo('view_as_complex',
|
||||||
dtypes=floating_types_and(torch.half),
|
dtypes=floating_types_and(torch.half),
|
||||||
|
|
@ -5120,7 +5164,9 @@ op_db: List[OpInfo] = [
|
||||||
ref=np.imag,
|
ref=np.imag,
|
||||||
dtypes=complex_types(),
|
dtypes=complex_types(),
|
||||||
supports_out=False,
|
supports_out=False,
|
||||||
supports_autograd=False,
|
supports_forward_ad=True,
|
||||||
|
# TODO(@anjali411): Test this once neg bit is added.
|
||||||
|
test_conjugated_samples=False,
|
||||||
skips=(
|
skips=(
|
||||||
# Skip since real and imag don't have out variants.
|
# Skip since real and imag don't have out variants.
|
||||||
SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
|
SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
|
||||||
|
|
@ -5251,8 +5297,6 @@ op_db: List[OpInfo] = [
|
||||||
supports_autograd=False,
|
supports_autograd=False,
|
||||||
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
|
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
|
||||||
skips=(
|
skips=(
|
||||||
# skip because `linalg_lstsq` is not differentiable
|
|
||||||
SkipInfo('TestGradients', 'test_fn_grad'),
|
|
||||||
SkipInfo('TestCommon', 'test_variant_consistency_jit'),
|
SkipInfo('TestCommon', 'test_variant_consistency_jit'),
|
||||||
)),
|
)),
|
||||||
OpInfo('linalg.matrix_power',
|
OpInfo('linalg.matrix_power',
|
||||||
|
|
@ -5471,7 +5515,8 @@ op_db: List[OpInfo] = [
|
||||||
# "RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when
|
# "RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when
|
||||||
# calling cublasGemmStridedBatchedExFix."
|
# calling cublasGemmStridedBatchedExFix."
|
||||||
SkipInfo('TestOpInfo', 'test_supported_backward',
|
SkipInfo('TestOpInfo', 'test_supported_backward',
|
||||||
device_type='cuda', dtypes=(torch.bfloat16,)),)),
|
device_type='cuda', dtypes=(torch.bfloat16,)),
|
||||||
|
SkipInfo('TestCommon', 'test_conj_view', device_type='cpu'),),),
|
||||||
OpInfo('max',
|
OpInfo('max',
|
||||||
op=torch.max,
|
op=torch.max,
|
||||||
variant_test_name='binary',
|
variant_test_name='binary',
|
||||||
|
|
@ -5699,7 +5744,9 @@ op_db: List[OpInfo] = [
|
||||||
assert_autodiffed=True),
|
assert_autodiffed=True),
|
||||||
OpInfo('float_power',
|
OpInfo('float_power',
|
||||||
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
|
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
|
||||||
sample_inputs_func=sample_inputs_pow),
|
sample_inputs_func=sample_inputs_pow,
|
||||||
|
skips=(
|
||||||
|
SkipInfo('TestCommon', 'test_conj_view', device_type='cuda'),),),
|
||||||
OpInfo('prod',
|
OpInfo('prod',
|
||||||
dtypes=all_types_and_complex_and(torch.bool),
|
dtypes=all_types_and_complex_and(torch.bool),
|
||||||
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||||
|
|
@ -5739,7 +5786,6 @@ op_db: List[OpInfo] = [
|
||||||
ref=np.real,
|
ref=np.real,
|
||||||
dtypes=complex_types(),
|
dtypes=complex_types(),
|
||||||
supports_out=False,
|
supports_out=False,
|
||||||
supports_autograd=False,
|
|
||||||
skips=(
|
skips=(
|
||||||
# Skip since real and imag don't have out variants.
|
# Skip since real and imag don't have out variants.
|
||||||
SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
|
SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
|
||||||
|
|
@ -7068,23 +7114,26 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
|
||||||
def maybe_non_contig(tensor):
|
def maybe_non_contig(tensor):
|
||||||
return tensor if not non_contiguous else make_non_contiguous(tensor)
|
return tensor if not non_contiguous else make_non_contiguous(tensor)
|
||||||
|
|
||||||
|
def conjugate(tensor):
|
||||||
|
return tensor.conj()
|
||||||
|
|
||||||
if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
|
if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
|
||||||
return arg
|
return arg
|
||||||
elif isinstance(arg, tuple) and len(arg) == 0:
|
elif isinstance(arg, tuple) and len(arg) == 0:
|
||||||
var = torch.randn((), dtype=dtype, device=device)
|
var = conjugate(torch.randn((), dtype=dtype, device=device))
|
||||||
var.requires_grad = requires_grad
|
var.requires_grad = requires_grad
|
||||||
return var
|
return var
|
||||||
elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
|
elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
|
||||||
return Variable(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device)), requires_grad=requires_grad)
|
return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)
|
||||||
# double check casting
|
# double check casting
|
||||||
elif isinstance(arg, non_differentiable):
|
elif isinstance(arg, non_differentiable):
|
||||||
if isinstance(arg.tensor, torch.Tensor):
|
if isinstance(arg.tensor, torch.Tensor):
|
||||||
if arg.tensor.dtype == torch.float:
|
if arg.tensor.dtype == torch.float:
|
||||||
return maybe_non_contig(arg.tensor.to(dtype=torch.double, device=device))
|
return maybe_non_contig(arg.tensor.to(dtype=torch.double, device=device))
|
||||||
if arg.tensor.dtype == torch.cfloat:
|
if arg.tensor.dtype == torch.cfloat:
|
||||||
return maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device))
|
return conjugate(maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device)))
|
||||||
return maybe_non_contig(arg.tensor.to(device=device))
|
return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
|
||||||
return maybe_non_contig(arg.tensor.to(device=device))
|
return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
|
||||||
elif isinstance(arg, torch.Tensor):
|
elif isinstance(arg, torch.Tensor):
|
||||||
if arg.dtype == torch.float:
|
if arg.dtype == torch.float:
|
||||||
arg = arg.double()
|
arg = arg.double()
|
||||||
|
|
@ -7094,7 +7143,7 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
|
||||||
raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ",
|
raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ",
|
||||||
"which is not supported for now")
|
"which is not supported for now")
|
||||||
# NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
|
# NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
|
||||||
v = maybe_non_contig(arg).detach().to(device=device).clone()
|
v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
|
||||||
v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
|
v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
|
||||||
return v
|
return v
|
||||||
elif callable(arg):
|
elif callable(arg):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user