mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54987

Based on the prototypes by ezyang (https://github.com/pytorch/pytorch/pull/44799) and bdhirsh (https://github.com/pytorch/pytorch/pull/43702).

Here's a summary of the changes in this PR:
This PR adds a new dispatch key called Conjugate. This enables us to make the conjugate operation a view and to leverage the specialized library functions that fast path with the hermitian operation (conj + transpose).

1. The conjugate operation now returns a view with the conj bit set (1) for complex tensors and returns self for non-complex tensors, as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical`, which returns the real tensor regardless of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor.

2. NEW API:
    a) `.conj()` -- now returns a view.
    b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for the input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated values in its memory.
    c) `.conj_physical_()`, and the `out=` variant.
    d) `.resolve_conj()` -- materializes the conjugation. Returns self if the conj bit is unset, else returns a new tensor with conjugated values and the conj bit set to 0.
    e) `.resolve_conj_()` -- in-place version of (d).
    f) `view_as_real_physical` -- as described in (1), it's functionally the same as `view_as_real`, except that it doesn't error out on conjugated tensors.
    g) `view_as_real` -- existing function, but now errors out on conjugated tensors.

3. Conjugate Fallback
    a) The vast majority of PyTorch functions currently use this fallback when they are called on a conjugated tensor.
    b) This fallback is well equipped to handle the following cases:
        - functional operation, e.g., `torch.sin(input)`
        - mutable inputs and in-place operations, e.g., `tensor.add_(2)`
        - out-of-place operation, e.g., `torch.sin(input, out=out)`
        - Tensorlist input args
        - NOTE: Meta tensors don't work with the conjugate fallback.

4. Autograd
    a) `resolve_conj()` is an identity function w.r.t. autograd.
    b) Everything else works as expected.

5. Testing:
    a) All method_tests run with conjugate view tensors.
    b) OpInfo tests that run with conjugate views:
        - test_variant_consistency_eager/jit
        - gradcheck, gradgradcheck
        - test_conj_views (which only runs for the `torch.cfloat` dtype)

NOTE: functions like `empty_like`, `zeros_like`, `randn_like`, and `clone` don't propagate the conjugate bit.

Follow up work:
1. conjugate view RFC
2. Add neg bit to re-enable view operation on conjugated tensors
3. Update linalg functions to call into specialized functions that fast path with the hermitian operation.

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D28227315

Pulled By: anjali411

fbshipit-source-id: acab9402b9d6a970c6d512809b627a290c8def5f

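
The view semantics described in the commit summary above can be illustrated with a small toy model. This is only a sketch in plain Python, not PyTorch code: the class `ToyComplexTensor` and its `storage`, `conj_bit`, and `values` members are hypothetical names invented for illustration, and the model assumes a 1-D tensor backed by a Python list. The method names `conj`, `conj_physical`, `resolve_conj`, and `is_conj` mirror the API listed in the summary.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyComplexTensor:
    """Hypothetical toy model of a complex tensor with a conj bit (not PyTorch)."""

    storage: List[complex]  # physical values in memory, possibly shared by views
    conj_bit: bool = False  # True means "read every element as its conjugate"

    def is_conj(self) -> bool:
        return self.conj_bit

    def values(self) -> List[complex]:
        # Logical values: the conj bit is applied lazily, on read.
        return [v.conjugate() if self.conj_bit else v for v in self.storage]

    def conj(self) -> "ToyComplexTensor":
        # Lazy conjugate: share the same storage and flip the bit (a view).
        return ToyComplexTensor(self.storage, not self.conj_bit)

    def resolve_conj(self) -> "ToyComplexTensor":
        # Materialize: return self if the bit is unset, otherwise a new
        # tensor holding the conjugated values with the bit cleared.
        if not self.conj_bit:
            return self
        return ToyComplexTensor(self.values(), False)

    def conj_physical(self) -> "ToyComplexTensor":
        # Eager conjugate: always allocate new storage that holds the
        # conjugate of the logical values, with the bit cleared.
        return ToyComplexTensor([v.conjugate() for v in self.values()], False)


x = ToyComplexTensor([1 + 2j, 3 - 4j])
y = x.conj()           # view: shares x's storage, conj bit set
z = y.resolve_conj()   # new storage with conjugated values, conj bit cleared
w = y.conj_physical()  # conjugating the conjugate gives back x's values

x.storage[0] = 5 + 6j  # a write through x is visible through the view y
```

In this model, `conj()` costs nothing beyond flipping a flag, which is the point of making conjugation a view; the copy is deferred until `resolve_conj()` or `conj_physical()` is called.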
||
|---|---|---|
| _static | ||
| _templates | ||
| community | ||
| elastic | ||
| notes | ||
| rpc | ||
| scripts | ||
| __config__.rst | ||
| amp.rst | ||
| autograd.rst | ||
| backends.rst | ||
| benchmark_utils.rst | ||
| bottleneck.rst | ||
| checkpoint.rst | ||
| complex_numbers.rst | ||
| conf.py | ||
| cpp_extension.rst | ||
| cpp_index.rst | ||
| cuda.rst | ||
| cudnn_persistent_rnn.rst | ||
| cudnn_rnn_determinism.rst | ||
| data.rst | ||
| ddp_comm_hooks.rst | ||
| distributed.elastic.rst | ||
| distributed.optim.rst | ||
| distributed.rst | ||
| distributions.rst | ||
| dlpack.rst | ||
| docutils.conf | ||
| fft.rst | ||
| futures.rst | ||
| fx.rst | ||
| hub.rst | ||
| index.rst | ||
| jit_builtin_functions.rst | ||
| jit_language_reference_v2.rst | ||
| jit_language_reference.rst | ||
| jit_python_reference.rst | ||
| jit_unsupported.rst | ||
| jit.rst | ||
| linalg.rst | ||
| math-quantizer-equation.png | ||
| mobile_optimizer.rst | ||
| model_zoo.rst | ||
| multiprocessing.rst | ||
| name_inference.rst | ||
| named_tensor.rst | ||
| nn.functional.rst | ||
| nn.init.rst | ||
| nn.rst | ||
| onnx.rst | ||
| optim.rst | ||
| package.rst | ||
| pipeline.rst | ||
| profiler.rst | ||
| quantization-support.rst | ||
| quantization.rst | ||
| random.rst | ||
| rpc.rst | ||
| sparse.rst | ||
| special.rst | ||
| storage.rst | ||
| tensor_attributes.rst | ||
| tensor_view.rst | ||
| tensorboard.rst | ||
| tensors.rst | ||
| testing.rst | ||
| torch.nn.intrinsic.qat.rst | ||
| torch.nn.intrinsic.quantized.rst | ||
| torch.nn.intrinsic.rst | ||
| torch.nn.qat.rst | ||
| torch.nn.quantized.dynamic.rst | ||
| torch.nn.quantized.rst | ||
| torch.overrides.rst | ||
| torch.quantization.rst | ||
| torch.rst | ||
| type_info.rst | ||