Add __torch_function__ support for generated tensor methods/properties of PrivateUse1 (#121723)

Support the following case (previously, the Tensor methods/properties generated for a renamed PrivateUse1 backend, e.g. `is_foo` below, did not dispatch through `__torch_function__`, so a `torch.Tensor` subclass could not intercept them):
```python
import torch
...
class CustomFooTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        ...

a = CustomFooTensor([3])
print(a.is_foo)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121723
Approved by: https://github.com/albanD
Author: cdzhan
Date: 2024-04-19 22:34:25 +00:00
Committed by: PyTorch MergeBot
Parent: 19850d770d
Commit: f8f7cfbeee
2 changed files with 16 additions and 0 deletions

torch/overrides.py:

```diff
@@ -1433,6 +1433,11 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
     }
+    privateuse1_backend_name = torch.utils.backend_registration._privateuse1_backend_name
+    if hasattr(Tensor, privateuse1_backend_name):
+        ret[getattr(Tensor, privateuse1_backend_name)] = lambda self, device=None, non_blocking=False, **kwargs: -1
+        ret[getattr(Tensor, f'is_{privateuse1_backend_name}').__get__] = lambda self: -1  # noqa: B009
 
     ret2 = {}
     ignored = get_ignored_functions()
```
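The effect of this hunk can be checked from Python. A minimal sketch, assuming the backend has been renamed to `foo` (an illustrative name) and the Tensor attributes generated:

```python
import torch
from torch.overrides import get_testing_overrides

# Illustrative setup: rename PrivateUse1 to "foo" and generate Tensor.foo
# and Tensor.is_foo before building the overrides table.
torch.utils.rename_privateuse1_backend("foo")
torch.utils.generate_methods_for_privateuse1_backend()

overrides = get_testing_overrides()
# Mirrors the two ret[...] assignments above: the generated method and the
# property getter now carry dummy entries in the testing-overrides table.
assert torch.Tensor.foo in overrides
assert torch.Tensor.is_foo.__get__ in overrides
```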

torch/utils/backend_registration.py:

```diff
@@ -1,4 +1,8 @@
 import torch
+from torch.overrides import (
+    handle_torch_function,
+    has_torch_function_unary,
+)
 from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name
 from typing import List, Optional, Union
@@ -126,9 +130,13 @@ def _normalization_device(custom_backend_name: str, device: Optional[Union[int,
 def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
     @property  # type: ignore[misc]
     def wrap_tensor_backend(self: torch.Tensor) -> bool:
+        if has_torch_function_unary(self):
+            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
+            return handle_torch_function(wrap_tensor_backend.__get__, (self,), self)  # type: ignore[attr-defined]
         return self.device.type == custom_backend_name
 
     _check_register_once(torch.Tensor, f'is_{custom_backend_name}')
     wrap_tensor_backend.fget.__name__ = f'is_{custom_backend_name}'  # type: ignore[attr-defined]
     setattr(torch.Tensor, f'is_{custom_backend_name}', wrap_tensor_backend)
 
     def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, torch.device]] = None, non_blocking=False,
@@ -147,10 +155,13 @@ def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
                 the argument has no effect.
             **kwargs (dict): For compatibility, may contain the key ``memory_format`` argument.
         """
+        if has_torch_function_unary(self):
+            return handle_torch_function(wrap_tensor_to, (self,), self, device=device, non_blocking=non_blocking, **kwargs)
         device_idx = _normalization_device(custom_backend_name, device)
         return self.to(device=torch.device(f'{custom_backend_name}:{device_idx}'), non_blocking=non_blocking, **kwargs)
 
     _check_register_once(torch.Tensor, custom_backend_name)
     wrap_tensor_to.__name__ = custom_backend_name
     setattr(torch.Tensor, custom_backend_name, wrap_tensor_to)
```
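With `wrap_tensor_to` dispatching as well, a subclass can intercept the generated conversion method. A minimal sketch, again assuming the illustrative `foo` rename; short-circuiting the call means no real `foo` device is needed:

```python
import torch

# Illustrative setup, as in the sketches above.
torch.utils.rename_privateuse1_backend("foo")
torch.utils.generate_methods_for_privateuse1_backend()

class FakeFooTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.Tensor.foo:  # the method generated by wrap_tensor_to
            return args[0]            # pretend the copy to "foo" already happened
        return super().__torch_function__(func, types, args, kwargs)

t = torch.ones(2).as_subclass(FakeFooTensor)
print(t.foo() is t)  # True: the subclass intercepted the generated method
```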