Commit Graph

57 Commits

Author SHA1 Message Date
Edward Z. Yang
0a14a4c280 Register prims as operators.
This makes prims look as if they were defined in native_functions.yaml
but they're still all written in Python.  You now need to give a full
schema string for your prims.  The returned prim object is now
torch.ops.prim overload (prims are not allowed to be overloaded,
so we return the overload, not the overload packet, for speed.)

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77117

Approved by: https://github.com/mruberry, https://github.com/albanD
2022-05-11 16:38:14 +00:00
Edward Z. Yang
f2eed9400d Register PrimTorch refs as decompositions.
For the most part, PrimTorch refs have the same signature as their
ATen equivalents.  I modify most PrimTorch refs to register themselves
as decompositions, using the prim name they wrap to find the aten name
(except for a few cases where the prim/aten names mismatch).  There are
some exclusions, falling into one of two categories:

- The torch equivalent was already implemented as a CompositeImplicitAutograd
  decomposition in C++

- The ref doesn't support enough features (e.g., the real deal has more
  kwargs / overloads than are currently implemented)

PrimTorch refs are written as a single function that supports all
overloads, and this style is convenient for cases where we have a bundle
of overloads for what morally is a single overload with a Union type
on an argument (which we ought to have supported in
native_functions.yaml but blah); to support registering a single decomp
for all the overloads, we modify register_decomposition to register
to ALL overloads if you pass it an overload packet.  This is technically
BC breaking but no tests started failing because of it.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76835

Approved by: https://github.com/Chillee, https://github.com/mruberry
2022-05-06 20:11:45 +00:00
Horace He
6917034afb Added logit/reciprocal decomps, fixed var for complex, moved type promotion logic to standardize on primtorch's
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76633
Approved by: https://github.com/ezyang
2022-05-04 21:29:52 +00:00
Horace He
ed18181d83 Added gelu decomposition
^
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76763
Approved by: https://github.com/ezyang
2022-05-03 23:23:18 +00:00
samdow
598e7e5f19 [Reland] Change 'python mode' to 'torch dispatch mode'
Changes Python Mode name to Torch Dispatch Mode because there is now a Torch Function Mode, so Torch Dispatch Mode and Torch Function Mode are consistent with each other
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76562
Approved by: https://github.com/zou3519, https://github.com/albanD
2022-05-02 20:06:43 +00:00
Horace He
fb24614011 Port functorch decomps over and fix some tests
Still some stuff to fix up, will finish later.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76621
Approved by: https://github.com/ezyang
2022-05-01 08:48:48 +00:00
Edward Z. Yang
a3f10ec281 Move functorch decompositions to PyTorch
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76311

Approved by: https://github.com/Chillee
2022-04-30 16:47:53 +00:00