mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Changelog: - torch.library.impl_abstract optionally accepts a torch.library.Library object. If one is passed in, the lifetime of the registration is tied to that Library object. - We've also changed torch.library.impl_abstract to work on all operators, including overloads. - We refactored the `torch._custom_ops.*` and `torch._custom_op.*` impl_abstract APIs and put them under torch._library. This is their final resting place. I will follow up by deleting all the `torch._custom_ops.*` stuff later. - There is a new "SimpleOperatorRegistry" (the class is named `SimpleLibraryRegistry` in the code) where we actually collect the abstract_impl. We will expand this to also hold the other torch._custom_ops.* APIs when we move those to torch.library. NB: Previously we had designed `impl_abstract` assuming a very high-level, Python-only custom op API. We've since revisited that; now impl_abstract works for all custom ops, whether Python or C++ and regardless of schema. The new refactored design reflects this better. Test Plan: - existing and new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/109912 Approved by: https://github.com/ezyang
44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
from .abstract_impl import AbstractImplHolder
|
|
|
|
# Public API of this module: the registry class, its per-overload entry
# type, and the module-level registry instance.
__all__ = [
    "SimpleLibraryRegistry",
    "SimpleOperatorEntry",
    "singleton",
]
|
|
|
|
|
|
class SimpleLibraryRegistry:
    """Central store for the "simple" torch.library APIs.

    The "simple" torch.library APIs sit one level above the raw PyTorch
    DispatchKey registration APIs. At present they cover:
    - abstract impl

    These registrations are kept here instead of in the PyTorch
    dispatcher's table, because they may not map directly onto a
    DispatchKey -- an abstract impl, for instance, is a Python function
    that FakeTensor invokes.

    A SimpleLibraryRegistry maps each fully qualified operator name
    (overload included) to its SimpleOperatorEntry.
    """

    def __init__(self):
        # qualname -> SimpleOperatorEntry; entries are created lazily by find().
        self._data = {}

    def find(self, qualname: str) -> "SimpleOperatorEntry":
        """Return the SimpleOperatorEntry for ``qualname``.

        The first lookup of a given ``qualname`` creates and stores a new
        entry; every later lookup returns that same stored object.
        """
        try:
            return self._data[qualname]
        except KeyError:
            # First time this operator overload is seen: register a fresh entry.
            fresh_entry = SimpleOperatorEntry(qualname)
            self._data[qualname] = fresh_entry
            return fresh_entry
|
|
|
|
|
|
# Module-level SimpleLibraryRegistry instance, exported via __all__.
# NOTE(review): presumably this is the shared registry that other modules
# query via ``singleton.find(qualname)``; those callers are not visible here.
singleton: SimpleLibraryRegistry = SimpleLibraryRegistry()
|
|
|
|
|
|
class SimpleOperatorEntry:
    """Registration record for a single operator overload.

    There is a one-to-one correspondence between entries and operator
    overloads. Every field of an entry is a Holder into which kernels can
    be registered.
    """

    def __init__(self, qualname: str):
        # Fully qualified operator name, including the overload.
        self.qualname: str = qualname
        # Holder for this overload's abstract impl, keyed by the same name.
        self.abstract_impl: AbstractImplHolder = AbstractImplHolder(qualname)
|