pytorch/torch/_subclasses
Richard Zou e6f9bc500b CustomOp simple abstract implementation registration (#99439)
This PR:
- adds an abstract registration API for CustomOp (CustomOp.impl_abstract)
that is used for both FakeTensor and meta tensors
- deletes CustomOp.impl_meta

The user story behind this API is that it is the one-stop shop for
registering implementations for data-less Tensors, i.e. FakeTensors and
meta tensors.

The abstract implementation provided by the user:
- gets registered as the FakeTensor implementation AND the meta formula
- can be written like a regular meta formula. If the user decides that
they need something more special (e.g. a data-dependent output shape),
they can query the current context object (FakeTensorImplCtx),
which has methods to construct new unbacked symints.
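As a rough illustration of the registration flow described above, here is a toy, pure-Python model. `ToyCustomOp`, `ToyAbstractCtx`, and `nonzero_abstract` are hypothetical names, shapes are plain tuples, and unbacked symints are modeled as strings such as `"u0"`; this is not the PyTorch API. It shows one function registered through `impl_abstract` being stored under both a "fake" and a "meta" key, and a data-dependent output dimension being obtained from a context object. Passing `ctx` explicitly is a simplification, since the PR describes querying the current context instead.

```python
# Toy model only: ToyCustomOp, ToyAbstractCtx and nonzero_abstract are
# hypothetical names, NOT the PyTorch API.
import itertools
from typing import Callable, Dict, Tuple, Union

Dim = Union[int, str]  # an unbacked symint is modeled as a name like "u0"


class ToyAbstractCtx:
    """Hypothetical stand-in for FakeTensorImplCtx."""

    def __init__(self) -> None:
        self._counter = itertools.count()

    def create_unbacked_symint(self) -> str:
        # Each call returns a fresh symbol for a size unknown until runtime.
        return f"u{next(self._counter)}"


class ToyCustomOp:
    def __init__(self, name: str) -> None:
        self.name = name
        self.kernels: Dict[str, Callable] = {}

    def impl_abstract(self) -> Callable[[Callable], Callable]:
        def register(fn: Callable) -> Callable:
            # The single user function serves both data-less backends.
            self.kernels["fake"] = fn
            self.kernels["meta"] = fn
            return fn

        return register


nonzero = ToyCustomOp("mylib::nonzero")


@nonzero.impl_abstract()
def nonzero_abstract(shape: Tuple[int, ...], ctx: ToyAbstractCtx) -> Tuple[Dim, ...]:
    # Written like a meta formula: it sees only the input shape, never data.
    # The number of nonzero elements is data-dependent, so it becomes a
    # fresh unbacked symint; the second dim is the input's rank.
    return (ctx.create_unbacked_symint(), len(shape))


ctx = ToyAbstractCtx()
print(nonzero.kernels["meta"]((3, 4), ctx))  # ('u0', 2)
```

Because both keys point at the same function, a change to the abstract implementation applies to FakeTensor and meta tensors alike.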

Caveats:
- we really need to make FakeTensor/FakeTensorMode public. Otherwise,
there isn't a way for the user to interactively test that their abstract
implementation is correct without running through large pieces of the
PT2 stack (make_fx or torch.compile).
- We do not memoize the symints produced by
ctx.create_unbacked_symint(). It is possible to do this in the
future, but it is difficult to do soundly and I am not convinced of
the utility outside of the nonzero() use case mentioned in #95399.
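A minimal sketch of what "no memoization" means in practice, using a hypothetical toy allocator (`ToyCtx` and `nonzero_rows` are illustrative names, not PyTorch code): each call mints a new symbol, so two abstract evaluations of `nonzero` on the same input produce row counts that a symbolic shape system has no way to know are equal.

```python
# Hypothetical toy allocator; not the PyTorch implementation.
import itertools


class ToyCtx:
    def __init__(self) -> None:
        self._counter = itertools.count()

    def create_unbacked_symint(self) -> str:
        # No memoization: each call mints a new, unrelated symbol.
        return f"u{next(self._counter)}"


def nonzero_rows(ctx: ToyCtx) -> str:
    # Stand-in for the data-dependent row count of nonzero(x).
    return ctx.create_unbacked_symint()


ctx = ToyCtx()
first = nonzero_rows(ctx)   # "u0"
second = nonzero_rows(ctx)  # "u1"
# Even for the same input x, the two results carry distinct symbols,
# so equality of the two shapes cannot be established symbolically.
print(first, second, first == second)  # u0 u1 False
```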

Public API:
- More docs will come when we actually expose this API to users by
putting it in a public namespace, unless you folks want it now.
- The APIs listed in `__all__` are the ones intended to be
public.

Test Plan:
- Updated existing custom_op_db operators
- Added new numpy_nonzero and numpy_nms operators to test operators
with data-dependent output shapes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99439
Approved by: https://github.com/ezyang
2023-04-28 13:45:39 +00:00
__init__.py Add Fake Cross Ref Mode, migrate sparse to it (#85382) 2022-09-21 17:15:47 +00:00
fake_tensor.py CustomOp simple abstract implementation registration (#99439) 2023-04-28 13:45:39 +00:00
fake_utils.py [MPS] Add group_norm[fwd+backward] and mean_var (take 2) (#91190) 2022-12-22 08:54:37 +00:00
meta_utils.py Don't detach to create parameters in MetaConverter (#99618) 2023-04-24 19:01:26 +00:00
schema_check_mode.py Get SchemaCheckMode to error on ops that return inputs directly. Expose as a dynamo backend, eager_debug (#99744) 2023-04-27 20:12:42 +00:00