diff --git a/torch/_library/simple_registry.py b/torch/_library/simple_registry.py index bf25cde9cb5..1f11914e8e9 100644 --- a/torch/_library/simple_registry.py +++ b/torch/_library/simple_registry.py @@ -1,5 +1,4 @@ -# mypy: allow-untyped-defs -from typing import Callable, Optional +from typing import Any, Callable, Optional from .fake_impl import FakeImplHolder from .utils import RegistrationHandle @@ -24,8 +23,8 @@ class SimpleLibraryRegistry: (including the overload) to SimpleOperatorEntry. """ - def __init__(self): - self._data = {} + def __init__(self) -> None: + self._data: dict[str, SimpleOperatorEntry] = {} def find(self, qualname: str) -> "SimpleOperatorEntry": res = self._data.get(qualname, None) @@ -44,7 +43,7 @@ class SimpleOperatorEntry: registered to. """ - def __init__(self, qualname: str): + def __init__(self, qualname: str) -> None: self.qualname: str = qualname self.fake_impl: FakeImplHolder = FakeImplHolder(qualname) self.torch_dispatch_rules: GenericTorchDispatchRuleHolder = ( @@ -53,17 +52,17 @@ class SimpleOperatorEntry: # For compatibility reasons. We can delete this soon. @property - def abstract_impl(self): + def abstract_impl(self) -> FakeImplHolder: return self.fake_impl class GenericTorchDispatchRuleHolder: - def __init__(self, qualname): - self._data = {} - self.qualname = qualname + def __init__(self, qualname: str) -> None: + self._data: dict[type, Callable[..., Any]] = {} + self.qualname: str = qualname def register( - self, torch_dispatch_class: type, func: Callable + self, torch_dispatch_class: type, func: Callable[..., Any] ) -> RegistrationHandle: if self.find(torch_dispatch_class): raise RuntimeError( @@ -71,16 +70,18 @@ class GenericTorchDispatchRuleHolder: ) self._data[torch_dispatch_class] = func - def deregister(): + def deregister() -> None: del self._data[torch_dispatch_class] return RegistrationHandle(deregister) - def find(self, torch_dispatch_class): + def find(self, torch_dispatch_class: type) -> Optional[Callable[..., Any]]: return self._data.get(torch_dispatch_class, None) -def find_torch_dispatch_rule(op, torch_dispatch_class: type) -> Optional[Callable]: +def find_torch_dispatch_rule( + op: Any, torch_dispatch_class: type +) -> Optional[Callable[..., Any]]: return singleton.find(op.__qualname__).torch_dispatch_rules.find( torch_dispatch_class )