torch.empty_permuted is a generalized version of torch.empty(memory_format=...), where you can pass an arbitrary physical layout as a tuple of dims, allowing you to set up dense, non-overlapping tensors with a non-standard memory format. Check the docblock for a full description of the semantics.

The initial motivation for this PR is guard-less unbacked SymInts. Traditionally, the way we allocate dense tensors with arbitrary layout is with `empty_strided`. However, `empty_strided` does not know that the given strides are actually contiguous, and must test this manually to find out whether that is the case. With `empty_permuted`, this is known statically to be the case, which helps us skip some 0/1 guards.

However, I also think torch.empty_permuted is a useful API in its own right. It is technically possible to simulate this with an empty and a permute; however, there are some downsides:

* The manual incant is tricky to work out. To allocate an NHWC tensor, the invocation is `torch.empty(N, H, W, C).permute(0, 3, 1, 2)`; the permute call has to take NHWC to NCHW, and is the *inverse* of the permutation people are typically thinking of when they talk about NHWC (0, 2, 3, 1). Instead, torch.empty_permuted lets you say `torch.empty_permuted((N, C, H, W), (0, 2, 3, 1))`, letting you provide the intuitive permutation. It can literally be read off as NHWC if you assign N=0, C=1, H=2, W=3.
* An empty(requires_grad=True).permute() is no longer a leaf tensor. You can force it to be a leaf with a detach(), but it is more straightforward and less error-prone to allow directly allocating a tensor with the correct permutation.

It is also technically possible to simulate this with empty_strided. However, this requires the user to manually compute the contiguous output strides and is bad from a reduction-of-guards perspective. For what it's worth, this is one of the more common uses of as_strided in the wild, and it would be nice to get rid of it.

A nice enhancement of this feature would be to accept `physical_layout` anywhere `memory_format` is accepted. However, this would be a pretty involved change, so I'm doing the easy thing instead.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95069
Approved by: https://github.com/malfet, https://github.com/ngimel, https://github.com/albanD, https://github.com/dagitses
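To make the point about the inverse permutation concrete, here is a small sketch (the sizes are made up for illustration; the assertions restate the channels-last layout described above):

import torch

N, C, H, W = 2, 3, 4, 5

# Direct allocation: logical size (N, C, H, W), physical layout NHWC,
# read off from the dims (0, 2, 3, 1) with N=0, C=1, H=2, W=3.
a = torch.empty_permuted((N, C, H, W), (0, 2, 3, 1))

# The manual incant: allocate NHWC storage, then permute back to NCHW.
b = torch.empty(N, H, W, C).permute(0, 3, 1, 2)

# The empty_strided spelling requires computing the contiguous strides by hand.
c = torch.empty_strided((N, C, H, W), (H * W * C, 1, W * C, C))

assert a.shape == b.shape == c.shape == (N, C, H, W)
assert a.stride() == b.stride() == c.stride()
assert a.is_contiguous(memory_format=torch.channels_last)

Unlike the permute spelling, the empty_permuted result remains a leaf tensor when requires_grad=True is passed.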
77 lines · 2.2 KiB · Python
import torch
from torch.overrides import TorchFunctionMode
from torch.utils._contextlib import context_decorator
import functools


@functools.lru_cache(1)
def _device_constructors():
    return {
        # standard ones
        torch.empty,
        torch.empty_permuted,
        torch.empty_strided,
        torch.empty_quantized,
        torch.ones,
        torch.arange,
        torch.bartlett_window,
        torch.blackman_window,
        torch.eye,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        torch.full,
        torch.fill,
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        torch.linspace,
        torch.logspace,
        torch.nested.nested_tensor,
        # This function doesn't actually take a device argument
        # torch.normal,
        torch.ones,
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        torch.range,
        torch.sparse_coo_tensor,
        torch.sparse_compressed_tensor,
        torch.sparse_csr_tensor,
        torch.sparse_csc_tensor,
        torch.sparse_bsr_tensor,
        torch.sparse_bsc_tensor,
        torch.tril_indices,
        torch.triu_indices,
        torch.vander,
        torch.zeros,
        torch.asarray,
        # weird ones
        torch.tensor,
        torch.as_tensor,
        torch.scalar_tensor,
    }


# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
    def __init__(self, device):
        self.device = torch.device(device)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)


# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
    return context_decorator(lambda: device, func)


def set_device(device):
    """
    Decorator which sets the default device inside of the wrapped
    function. If you would like to use this as a context manager,
    use device as a context manager directly, e.g.,
    ``with torch.device(device)``.
    """
    return lambda func: device_decorator(torch.device(device), func)
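For orientation, a minimal sketch of how the pieces above compose, assuming the module is importable as torch.utils._device (the "meta" device is just an example target; any valid device string works):

import torch
from torch.utils._device import DeviceContext, set_device

# As a decorator: factory functions called inside the wrapped function
# pick up the default device when no explicit device= is passed.
@set_device("meta")
def make_buffer():
    return torch.empty(3, 3)

assert make_buffer().device.type == "meta"

# The mode can also be pushed directly; per the NB comments, this is the
# object the C++ side installs when torch.device(...) is used as a
# context manager.
with DeviceContext("meta"):
    assert torch.ones(2).device.type == "meta"

Only the constructors listed in _device_constructors() are redirected, and an explicit device= argument always wins, because the mode only fills in the kwarg when it is None.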