import functools
from typing import Any, Callable, Dict, Optional, Sequence

import torch

import torch._prims
import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides

from torch._prims_common import torch_function_passthrough


@functools.lru_cache(None)
def torch_to_refs_map():
    """
    Mapping of torch API functions to torch._refs functions.
    E.g. torch_to_refs_map()[torch.add] == torch._refs.add

    The mapping is assembled from three sources, later ones overriding
    earlier ones on key collisions:

    1. a hand-written table of Tensor dunder/factory/in-place methods,
    2. every name listed in ``__all__`` of each refs module, paired with the
       same-named attribute of the corresponding torch module,
    3. every attribute of ``torch.Tensor`` whose name is listed in
       ``torch._refs.__all__``.

    Names that have no counterpart on either side are skipped, so neither
    the keys nor the values of the returned dict are ever ``None``.

    The result is cached; callers share one dict and should not mutate it.
    """
    modules = [
        (torch, torch._refs),
        (torch.nn, torch._refs.nn),
        (torch.nn.functional, torch._refs.nn.functional),
        (torch.special, torch._refs.special),
        (torch.fft, torch._refs.fft),
        (torch.linalg, torch._refs.linalg),
    ]
    r: Dict[Any, Any] = {
        torch.Tensor.__invert__: torch._refs.bitwise_not,
        torch.Tensor.__xor__: torch._refs.bitwise_xor,
        torch.Tensor.__and__: torch._refs.bitwise_and,
        torch.Tensor.__or__: torch._refs.bitwise_or,
        torch.Tensor.__eq__: torch._refs.eq,
        torch.Tensor.new_empty: torch._refs.new_empty,
        torch.Tensor.new_full: torch._refs.new_full,
        torch.Tensor.new_zeros: torch._refs.new_zeros,
        torch.Tensor.new_ones: torch._refs.new_ones,
        torch.Tensor.fill_: torch._refs.fill_,
        torch.Tensor.zero_: torch._refs.zero_,
        # TODO: Should these methods be mapped some other way?
        torch.Tensor.copy_: torch._prims.copy_to,
        torch.Tensor.resize: torch._prims.resize,
    }

    for mod_torch, mod_refs in modules:
        for s in mod_refs.__all__:  # type: ignore[attr-defined]
            torch_fn = mod_torch.__dict__.get(s)
            ref_fn = mod_refs.__dict__.get(s)
            # A ref may exist only as a Tensor method (no function of that
            # name in the torch module); without this guard such names would
            # all collapse onto a bogus ``None`` key.
            if torch_fn is None or ref_fn is None:
                continue
            r[torch_fn] = ref_fn

    # Support remapping torch.Tensor.foo to _refs.foo
    ref_names = set(torch._refs.__all__)  # O(1) membership inside the loop
    for s in dir(torch.Tensor):
        if s not in ref_names:
            continue
        ref_fn = torch._refs.__dict__.get(s)
        # Never overwrite an existing entry with a missing ref.
        if ref_fn is not None:
            r[getattr(torch.Tensor, s)] = ref_fn

    return r


@functools.lru_cache(None)
def all_prims():
    """
    Set of all prim functions, e.g., torch._prims.add in all_prims()

    Built from the names in ``torch._prims.__all__``, looked up in the
    module's ``__dict__``; the result is cached.
    """
    return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}


class TorchRefsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.* functions and Tensor methods to
    use PrimTorch refs in torch._refs.  (Direct calls to _refs are unaffected.)

    >>> with TorchRefsMode.push():
    ...     torch.add(x, y)  # calls torch._refs.add(x, y)

    By default, this context manager will fall back on the torch.* if the
    ref does not exist; set strict=True to error if this occurs.
    """

    def __init__(self, strict=False):
        # When True, calling a function that has no ref raises RuntimeError
        # instead of falling back to the original torch function.
        self.strict = strict

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        """
        Dispatch ``orig_func`` to its torch._refs counterpart when one exists.

        Args:
            orig_func: the function being intercepted.
            types: accepted for the ``__torch_function__`` protocol; not used
                here.
            args: positional arguments for the call.
            kwargs: keyword arguments for the call; ``None`` means none.

        Returns:
            Whatever the selected implementation (ref or original) returns.

        Raises:
            RuntimeError: if ``self.strict`` is set and no ref is mapped for
                ``orig_func``.
        """
        if kwargs is None:
            kwargs = {}
        # For primitive operations, run them as is without interception
        if orig_func in torch_function_passthrough or orig_func in all_prims():
            return orig_func(*args, **kwargs)
        mapping = torch_to_refs_map()
        func = mapping.get(orig_func, None)
        if func is not None:
            # torch calls inside func should be interpreted as refs calls
            with torch.overrides.enable_torch_function_mode(self, replace=self.inner):
                return func(*args, **kwargs)
        if self.strict:
            raise RuntimeError(
                f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
            )
        # No ref available and not strict: fall back to the original function.
        return orig_func(*args, **kwargs)