mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add __torch_function__ support for generated tensor methods/property of PrivateUse1 (#121723)
support following case:
```python
import torch
...
class CustomFooTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
...
a = CustomFooTensor([3])
print(a.is_foo)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121723
Approved by: https://github.com/albanD
This commit is contained in:
parent
19850d770d
commit
f8f7cfbeee
|
|
@ -1433,6 +1433,11 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
|
||||
}
|
||||
|
||||
privateuse1_backend_name = torch.utils.backend_registration._privateuse1_backend_name
|
||||
if hasattr(Tensor, privateuse1_backend_name):
|
||||
ret[getattr(Tensor, privateuse1_backend_name)] = lambda self, device=None, non_blocking=False, **kwargs: -1
|
||||
ret[getattr(Tensor, f'is_{privateuse1_backend_name}').__get__] = lambda self: -1 # noqa: B009
|
||||
|
||||
ret2 = {}
|
||||
ignored = get_ignored_functions()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
import torch
|
||||
from torch.overrides import (
|
||||
handle_torch_function,
|
||||
has_torch_function_unary,
|
||||
)
|
||||
from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
|
@ -126,9 +130,13 @@ def _normalization_device(custom_backend_name: str, device: Optional[Union[int,
|
|||
def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
|
||||
@property # type: ignore[misc]
|
||||
def wrap_tensor_backend(self: torch.Tensor) -> bool:
|
||||
if has_torch_function_unary(self):
|
||||
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
|
||||
return handle_torch_function(wrap_tensor_backend.__get__, (self,), self) # type: ignore[attr-defined]
|
||||
return self.device.type == custom_backend_name
|
||||
|
||||
_check_register_once(torch.Tensor, f'is_{custom_backend_name}')
|
||||
wrap_tensor_backend.fget.__name__ = f'is_{custom_backend_name}' # type: ignore[attr-defined]
|
||||
setattr(torch.Tensor, f'is_{custom_backend_name}', wrap_tensor_backend)
|
||||
|
||||
def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, torch.device]] = None, non_blocking=False,
|
||||
|
|
@ -147,10 +155,13 @@ def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -
|
|||
the argument has no effect.
|
||||
**kwargs (dict): For compatibility, may contain the key ``memory_format`` argument.
|
||||
"""
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(wrap_tensor_to, (self,), self, device=device, non_blocking=False, **kwargs)
|
||||
device_idx = _normalization_device(custom_backend_name, device)
|
||||
return self.to(device=torch.device(f'{custom_backend_name}:{device_idx}'), non_blocking=non_blocking, **kwargs)
|
||||
|
||||
_check_register_once(torch.Tensor, custom_backend_name)
|
||||
wrap_tensor_to.__name__ = custom_backend_name
|
||||
setattr(torch.Tensor, custom_backend_name, wrap_tensor_to)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user