Commit Graph

53 Commits

Author SHA1 Message Date
vasiliy
d19791e4cd add autocast keys to pybind11 DispatchKey object (#90821)
Summary:

This is useful for debugging what autocast is doing when
it's running on top of torchdynamo; without this, the Python dispatch
key for autocast prints as `???`.

Test Plan:

```
import torch
dir(torch._C.DispatchKey)
# the autocast keys show up now
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90821
Approved by: https://github.com/ezyang
2022-12-15 00:15:07 +00:00
Edward Z. Yang
5266953443 Add crossref debug mode for functionalization, catches stride errors (#89498)
The idea is to add a custom handler to Functionalize key in Python
dispatcher that runs the functionalized version along side a non
functionalized version, and checks that their outputs agree in the
end.  (Technically, for metadata mutation we should also check the
inputs, but for now we're relying on those functions returning self.)
I turned this on for test_functionalize.py (new TestCrossRefFunctionalize)
and found a bunch of failures that look legit.

This probably doesn't interact that nicely if you're also tracing at
the same time; that probably needs more special logic (most directly,
just disabling tracing when we create the nested fake tensor mode,
but I don't know if there's a more principled way to organize this).

There are some misc fixups which I can split if people really want.

- xfail_inherited_tests moved to test common_utils
- Bindings for _dispatch_tls_set_dispatch_key_included,
  _dispatch_tls_is_dispatch_key_included and _functionalization_reapply_views_tls
- Type stubs for _enable_functionalization, _disable_functionalization
- all_known_overloads utility to let you iterate over all OpOverloads
  in all namespaces.  Iterator support on all torch._ops objects to let
  you iterate over their members.
- suspend_functionalization lets you temporarily disable functionalization mode
  in a context
- check_metadata_matches for easily comparing outputs of functions and seeing
  if they match (TODO: there are a few copies of this logic, consolidate!)
- _fmt for easily printing the metadata of a tensor without its data
- _uncache_dispatch for removing a particular dispatch key from the cache,
  so that we force it to regenerate
- check_significant_strides: new kwarg only_cuda so you can run the stride
  test even when inputs are not CUDA
- Functionalize in torch._C.DispatchKey

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89498
Approved by: https://github.com/malfet
2022-11-23 04:18:25 +00:00
Edward Z. Yang
57ed94804e Bind DispatchKey.Functionalize in pybind11 (#89452)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89452
Approved by: https://github.com/albanD, https://github.com/bdhirsh
2022-11-22 00:32:30 +00:00
zhxchen17
c3938bb97a [functorch] introduce an experimental map() op. (#88767)
Summary:
We want to introduce an experimental control flow op: map() to export some models as FX graphs correctly.

Some clarification on the basic requirements we have in mind:
1. This op can nest cond() and other control flow primitives internally.
2. We don't necessarily need loop-carried dependencies for the models we've seen.
3. This map() op can handle dynamically shaped tensors as input and return dynamically shaped outputs based on input shapes.
4. We should be able to pass additional arguments through to the loop body.

In this diff we introduce a new control flow op `map()` which has the following semantics:
```
def map(f: Callable, xs: Tensor, *args):
    # one possible implementation:
    return torch.stack([f(x, *args) for x in xs])
```
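As a rough illustration of these semantics outside of torch (a plain-Python toy, using nested lists in place of tensors and list collection in place of `torch.stack`):

```python
def map_ref(f, xs, *args):
    # reference semantics: apply f to each element along the leading
    # dimension, threading extra args through, then collect the results
    # (torch.stack in the real op; a plain list here)
    return [f(x, *args) for x in xs]

# extra arguments are passed through to the loop body unchanged
out = map_ref(lambda x, scale: [v * scale for v in x], [[1, 2], [3, 4]], 10)
# out == [[10, 20], [30, 40]]
```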

Test Plan:
pytest functorch/test_control_flow.py
CI

Differential Revision: D41165796

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88767
Approved by: https://github.com/zou3519
2022-11-19 00:19:50 +00:00
Richard Zou
3bc327993f PyDispatcher integration with functorch (#88785)
This PR teaches PyDispatcher and PyOperator about functorch transforms.
It is important that PyDispatcher/PyOperator dispatch with functorch
transforms, because this is our plan for higher-order operators
(operators that accept functions as arguments). Examples of these
include:
- functorch transforms over the existing cond operator (control flow)
- autograd.Function support for functorch (which I am working towards),
- AOTDispatcher (should be a higher order operator)

Concretely, the problem with teaching PyDispatcher/PyOperator about
functorch is that the stack-based dispatching logic (DynamicLayerStack)
is hidden inside the fallbacks for two dispatch keys
(DynamicLayer{Front, Back}). PyDispatcher doesn't know about C++ boxed
fallbacks; our plan of record is to reimplement all of them in Python
(though we can call helper functions in C++ to make our lives easier).

Instead of exposing all of what DynamicLayer{Front, Back} do to python,
this PR takes the approach of re-implementing part of the stack-based
dispatching in Python. The motivation is that this is more sane and
follows what the "ideal" implementation of functorch would have been:
- each transform should be a "mode"
- there should be no TLS dispatch key set hackery. functorch needs to do
this hackery today to re-use VariableType implementations.

This PR:
- exposes the DynamicLayerStack to Python
- The DynamicLayerStack is a stack of Interpreters.
These get exposed to Python as well.
- Interpreters can run operations (Interpreter.process) or lower them to
the next interpreter in the stack (Interpreter.lower)
- To use a PyOperator with functorch transforms, a developer needs to
register a rule for each transform (vmap, grad, jvp, ...).
- The PyOperator API is NOT user-facing. Things like autograd.Function
support for functorch will end up going through the autograd.Function
API.
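A toy sketch of the stack-based dispatching described above (hypothetical names, not the real torch APIs): each transform is an interpreter on a stack, and an operation is either handled by the top interpreter's rule or lowered to the next interpreter down, bottoming out at a base kernel.

```python
BASE_IMPLS = {"mul": lambda a, b: a * b}  # "real" kernels at the bottom

class Interpreter:
    """One transform level; rules maps op name -> handler for that transform."""
    def __init__(self, rules):
        self.rules = rules

    def process(self, stack, op, *args):
        handler = self.rules.get(op)
        if handler is not None:
            return handler(stack, *args)
        return lower(stack, op, *args)  # no rule for this op: fall through

def lower(stack, op, *args):
    # run the op at the next interpreter down, or at the base kernel
    if stack:
        return stack[-1].process(stack[:-1], op, *args)
    return BASE_IMPLS[op](*args)

# a grad-like transform that wraps results, lowering for the actual math
tagger = Interpreter({"mul": lambda st, a, b: ("tagged", lower(st, "mul", a, b))})
result = lower([tagger], "mul", 3, 4)  # ("tagged", 12)
```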

Question for reviewers:
- Does this design make sense?
- I'm trying to split up the "functorch support for autograd.Function"
work into logical pieces. Would it be better if I didn't? (the full
thing is a bit long - 1000-2000 LOC).

Test Plan:
- new tests that construct PyOperator and compose them with functorch
transforms
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88785
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-11-16 00:46:59 +00:00
Edward Z. Yang
f884e817d4 Make Python op registration work with torchdeploy/multipy (#87162)
See strategy at PythonOpRegistrationTrampoline.cpp for the
big picture.

Along the way, I made OperatorHandle support == and hashing,
and slightly changed the low-level python_dispatch impl API
to disallow empty strings for the dispatch key. This had the knock-on
effect of requiring us to explicitly pass in
CompositeImplicitAutograd where we would have passed in "" (I didn't apply
this to the rest of the file because I'm lazy.)

Test strategy is we delete the logic for preventing Python op
registrations in torch from being skipped in a torchdeploy context
and show CI still works.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87162
Approved by: https://github.com/anjali411, https://github.com/bdhirsh
2022-11-03 12:56:44 +00:00
Sherlock Huang
ab901b4817 Python binding for dispatcher getAllOpNames (#87422)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87422
Approved by: https://github.com/bdhirsh
2022-10-21 06:55:10 +00:00
Richard Zou
cd32a86bf2 Stop monkeypatching Tensor.backward() on import functorch (#85152)
Monkeypatching is bad; we should never be doing it. This PR removes
functorch's monkeypatching of Tensor.backward() by adding the logic directly to
the implementation of Tensor.backward().

As an alternative, we could have done an `import functorch` and used
`functorch._C.are_transforms_active` directly in
`torch/autograd/__init__.py`. The problem with that is that it runs into a
bunch of circular imports.

NB: https://github.com/pytorch/pytorch/issues/72179 is still on my mind.
I didn't choose to do it right now because:
- This PR doesn't make the situation worse than it already is (no
monkeypatching is better than having the monkeypatch)
- We don't have a design for #72179 yet.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85152
Approved by: https://github.com/soulitzer
2022-09-19 17:06:15 +00:00
Michael Voznesensky
8ca1839d32 Python Dispatcher integration with C++ dispatcher (#85050)
#84826 but without ghstack
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85050
Approved by: https://github.com/malfet
2022-09-15 00:43:36 +00:00
PyTorch MergeBot
706b990306 Revert "Python Dispatcher integration with C++ dispatcher (#84826)"
This reverts commit 35f6a69191.

Reverted https://github.com/pytorch/pytorch/pull/84826 on behalf of https://github.com/malfet due to Broke dynamo, see 35f6a69191
2022-09-14 14:07:58 +00:00
Michael Voznesensky
35f6a69191 Python Dispatcher integration with C++ dispatcher (#84826)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

From @ezyang's original PR:

There are a number of situations where we have non-backend kernels (e.g., CompositeImplicitAutograd, batching rules) which we would like to port to Python, but we have no way to integrate these ports with the overall system while otherwise using preexisting C++ registrations. This PR changes that by introducing a Python dispatcher (which can have its own kernels directly in Python), which can interpose over ordinary C++ dispatch. The ingredients:

- We introduce a new PythonDispatcher dispatch key, which has the same tenor as FuncTorchDynamicLayerFrontMode: it works by getting triggered before every other dispatch key in the dispatch key set, and shunting to a Python implementation.
- The Python dispatcher is a per-interpreter global object that is enabled/disabled via the guard EnablePythonDispatcher/DisablePythonDispatcher. We don't make it compositional as I have no idea what a compositional version of this feature would look like. Because it is global, we don't need to memory manage it, so I use a simpler SafePyHandle (newly added) to control access to this pointer from non-Python C++. Like __torch_dispatch__, we use PyInterpreter to get to the Python interpreter to handle the dispatch.
- I need to reimplement dispatch table computation logic in Python. To do this, I expose a lot more helper functions for doing computations on alias dispatch keys and similar. I also improve the pybind11 handling for DispatchKey so that you can pass either the pybind11-bound enum or a string; this simplifies our binding code. See https://github.com/pybind/pybind11/issues/483#issuecomment-1237418106 for how this works; the technique is generally useful.
- I need to be able to call backend fallbacks. I do this by permitting you to call at a dispatch key which doesn't have a kernel for the operator; if the kernel doesn't exist, we check the backend fallback table instead.
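The backend-fallback lookup just described can be sketched with a toy table (plain Python, hypothetical names, not the real dispatcher): calling at a key with no kernel for the operator consults a per-key fallback table instead.

```python
# toy tables: (op, key) -> kernel, plus per-key backend fallbacks
kernels = {("aten::add", "CPU"): lambda a, b: a + b}
backend_fallbacks = {
    "TESTING_ONLY_GenericWrapper": lambda op, *args: ("fallback", op, args),
}

def call_at_key(op, key, *args):
    kernel = kernels.get((op, key))
    if kernel is not None:
        return kernel(*args)
    # no kernel registered at this key: check the backend fallback table
    return backend_fallbacks[key](op, *args)

direct = call_at_key("aten::add", "CPU", 1, 2)  # 3
via_fb = call_at_key("aten::add", "TESTING_ONLY_GenericWrapper", 1, 2)
# via_fb == ("fallback", "aten::add", (1, 2))
```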

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84826
Approved by: https://github.com/ezyang
2022-09-14 06:57:19 +00:00
Michael Voznesensky
ced2ca8f86 Torch cond operator, python dispatch, pyoperator (#83154)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83154
Approved by: https://github.com/ezyang
2022-08-25 20:11:53 +00:00
Brian Hirsh
1a51efd8bb dispatch API for checking computed table, use it in prim decomps (#82358)
Fixes https://github.com/pytorch/pytorch/issues/82331

Expose a `torch._C._dispatch_has_computed_kernel_for_dispatch_key` to check if an operator has a kernel registered to the given dispatch key in the **computed table**.

Use it in the prim registration logic, making it more accurate and robust (so that it e.g. picks up `CompositeExplicitAutograd` kernels).

It looks like before this change we'd register 134 prim ops to the meta key, and after we only register 62. So that's 72 ops that now use an existing C++ decomp to get meta working, instead of going directly through the prim decomp.
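The distinction can be sketched with a toy model (illustrative only, not the real dispatcher internals): the computed table expands alias keys like CompositeExplicitAutograd into concrete backend entries, so a key can have a computed kernel even without a direct registration.

```python
# toy registration table and alias-key expansion (illustrative only)
direct_regs = {"aten::foo": {"CompositeExplicitAutograd"}}
ALIAS_EXPANSION = {"CompositeExplicitAutograd": {"CPU", "CUDA", "Meta"}}

def has_computed_kernel(op, key):
    registered = direct_regs.get(op, set())
    computed = set(registered)
    for k in registered:
        computed |= ALIAS_EXPANSION.get(k, set())  # alias keys fill backends
    return key in computed

# no direct Meta registration, but the computed table has an entry for it
has_meta = has_computed_kernel("aten::foo", "Meta")  # True
```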

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82358
Approved by: https://github.com/ezyang
2022-08-10 23:42:02 +00:00
Edward Z. Yang
df69660832 Revert "Revert "Add a lint rule for torch/csrc/util/pybind.h include (#82552)"" (#82599)
This reverts commit 532b8a9e00.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82599
Approved by: https://github.com/albanD
2022-08-02 19:37:02 +00:00
PyTorch MergeBot
532b8a9e00 Revert "Add a lint rule for torch/csrc/util/pybind.h include (#82552)"
This reverts commit 9465c0e0b5.

Reverted https://github.com/pytorch/pytorch/pull/82552 on behalf of https://github.com/zengk95 due to This seems to be breaking windows binary wheels
2022-08-01 20:25:35 +00:00
Edward Z. Yang
9465c0e0b5 Add a lint rule for torch/csrc/util/pybind.h include (#82552)
We define specializations for pybind11 defined templates
(in particular, PYBIND11_DECLARE_HOLDER_TYPE) and consequently
it is important that these specializations *always* be #include'd
when making use of pybind11 templates whose behavior depends on
these specializations, otherwise we can cause an ODR violation.

The easiest way to ensure that all the specializations are always
loaded is to designate a header (in this case, torch/csrc/util/pybind.h)
that ensures the specializations are defined, and then add a lint
to ensure this header is included whenever pybind11 headers are
included.

The existing grep linter didn't have enough knobs to do this
conveniently, so I added some features.  I'm open to suggestions
for how to structure the features better.  The main changes:

- Added an --allowlist-pattern flag, which turns off the grep lint
  if some other line exists.  This is used to stop the grep
  lint from complaining about pybind11 includes if the util
  include already exists.

- Added --match-first-only flag, which lets grep only match against
  the first matching line.  This is because, even if there are multiple
  includes that are problematic, I only need to fix one of them.
  We don't /really/ need this, but when I was running lintrunner -a
  to fixup the preexisting codebase it was annoying without this,
  as the lintrunner overall driver fails if there are multiple edits
  on the same file.

I excluded any files that didn't otherwise have a dependency on
torch/ATen, this was mostly caffe2 and the valgrind wrapper compat
bindings.

Note the grep replacement is kind of crappy, but clang-tidy lint
cleaned it up in most cases.

See also https://github.com/pybind/pybind11/issues/4099

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82552
Approved by: https://github.com/albanD
2022-08-01 17:16:58 +00:00
Michael Suo
30fb2c4aba [lint] autoformat test/cpp and torch/csrc
Let's have some fun.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78828

Approved by: https://github.com/ezyang
2022-06-11 21:11:16 +00:00
Edward Z. Yang
eb856daf0f Do not treat all dense tensors as isTensorSubclassLike
Fixes https://github.com/pytorch/pytorch/issues/79079

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79098

Approved by: https://github.com/soulitzer, https://github.com/albanD
2022-06-09 03:00:57 +00:00
Horace He
bbbfbbeddc Added "dump ops" API to return ops instead of print (#78995)
Useful for grabbing info programmatically, instead of the hacky "redirect C++ output" approach I currently use.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78995
Approved by: https://github.com/ezyang
2022-06-07 05:19:07 +00:00
Edward Z. Yang
80f2c175be Follow up on CR for "Replace TensorMeta with FakeTensor"
See https://github.com/pytorch/pytorch/pull/78836

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78895

Approved by: https://github.com/albanD
2022-06-06 22:20:40 +00:00
Edward Z. Yang
587efdb5fa Replace TensorMeta with FakeTensor
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78836

Approved by: https://github.com/albanD, https://github.com/mruberry
2022-06-05 11:51:27 +00:00
anjali411
5984bc8233 Allow specifying alias analysis while registering new ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77690

Approved by: https://github.com/ezyang
2022-05-19 21:11:40 +00:00
Edward Z. Yang
4941e72e40 Revert "Revert "Implement sym_sizes to create proper IR for sym ints representing tensor sizes (#76836)""
This reverts commit c35bd8d423.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77719

Approved by: https://github.com/Chillee, https://github.com/malfet
2022-05-18 18:40:57 +00:00
PyTorch MergeBot
48581d74ad Revert "Add dispatch mode testing for meta tensors and other stuff"
This reverts commit c1cdb1216b.

Reverted https://github.com/pytorch/pytorch/pull/77477 on behalf of https://github.com/malfet
2022-05-18 02:56:48 +00:00
Edward Z. Yang
c1cdb1216b Add dispatch mode testing for meta tensors and other stuff
We don't have any coverage for meta tensor correctness for backwards
because torch function mode only lets us interpose on
Python torch API calls, while backwards invocations happen from C++.
To make this possible, I add a torch_dispatch_meta test which runs the
tests with __torch_dispatch__.

While doing this, I needed to generate fresh expected failure / skip
lists for the new test suite, and I discovered that my original
scaffolding for this purpose was woefully insufficient.  So I rewrote
how the test framework worked, and at the same time rewrote the
__torch_function__ code to also use the new logic.  Here's whats
new:

- Expected failure / skip is now done on a per function call basis,
  rather than the entire test.  This means that separate OpInfo
  samples for a function don't affect each other.

- There are now only two lists: the expected failure list (where the test
  consistently fails on all runs) and the skip list (where the test
  sometimes passes and sometimes fails).

- We explicitly notate the dtype that failed.  I considered detecting
  when something failed on all dtypes, but this was complicated and
  listing everything out seemed to be nice and simple.  To keep the
  dtypes short, I introduce a shorthand notation for dtypes.

- Conversion to meta tensors is factored into its own class
  MetaConverter

- To regenerate the expected failure / skip lists, just run with
  PYTORCH_COLLECT_EXPECT and filter on a specific test type
  (test_meta or test_dispatch_meta) for whichever you want to update.

Other misc fixes:

- Fix max_pool1d to work with BFloat16 in all circumstances, by making
  it dispatch and then fixing a minor compile error (constexpr doesn't
  work with BFloat16)

- Add resolve_name for turning random torch API functions into string
  names

- Add push classmethod to the Mode classes, so that you can more easily
  push a mode onto the mode stack

- Add some more skips for missing LAPACK

- Added an API to let you query if there's already a registration for
  a function, added a test to check that we register_meta for all
  decompositions (except detach, that decomp is wrong lol), and then
  update all the necessary sites to make the test pass.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77477

Approved by: https://github.com/zou3519
2022-05-18 00:18:34 +00:00
anjali411
17653a53d5 Forward fix failing TestDispatch tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77442

Approved by: https://github.com/janeyx99
2022-05-13 20:08:10 +00:00
Anjali Chourdia
2c1de3aa47 Back out Dispatcher change that makes Messenger Desktop crash on M1 devices (#77414)
Summary:
This change causes Messenger Desktop to crash on M1 devices when the user enables the background during a call. The change apparently causes the compiler to emit AVX instructions that are not supported by Rosetta.

This is a surgical backout that only backs out the changes in C++ side,
and not Python bindings which I believe are not shipped with Workplace Chat.

Test Plan:
Run the application and make sure that it doesn't crash when the background is enabled
https://pxl.cl/23VSH

Reviewed By: ezyang

Differential Revision: D36358832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77414
Approved by: https://github.com/bigfootjon
2022-05-13 17:33:53 +00:00
Edward Z. Yang
0a14a4c280 Register prims as operators.
This makes prims look as if they were defined in native_functions.yaml,
but they're still all written in Python.  You now need to give a full
schema string for your prims.  The returned prim object is now a
torch.ops.prim overload (prims are not allowed to be overloaded,
so we return the overload, not the overload packet, for speed.)
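The overload-vs-packet distinction can be sketched with toy classes (hypothetical, not torch internals): since prims are never overloaded, returning the single Overload directly skips per-call overload resolution.

```python
class Overload:
    """A single resolved overload: a name plus its callable."""
    def __init__(self, name, fn):
        self.name, self.fn = name, fn
    def __call__(self, *args):
        return self.fn(*args)

class OverloadPacket:
    """A bundle of overloads sharing an op name; lookup picks one."""
    def __init__(self, name):
        self.name, self.overloads = name, {}
    def add(self, key, fn):
        ov = Overload(f"{self.name}.{key}", fn)
        self.overloads[key] = ov
        return ov  # hand back the overload itself, not the packet

packet = OverloadPacket("prims::add")
default = packet.add("default", lambda a, b: a + b)
# calling the Overload directly avoids any per-call overload resolution
result = default(2, 3)  # 5
```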

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77117

Approved by: https://github.com/mruberry, https://github.com/albanD
2022-05-11 16:38:14 +00:00
anjali411
3d28ab0709 Minor follow up fixes for python registration
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76892

Approved by: https://github.com/albanD
2022-05-05 13:46:48 +00:00
anjali411
07f766df54 Allow creating new libraries and defining new operators from Python
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76250

Approved by: https://github.com/ezyang
2022-05-05 03:33:08 +00:00
anjali411
55f55a4cf6 Allow users to override kernels for existing C++ ops through Python
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75905

Approved by: https://github.com/ezyang
2022-05-05 03:31:39 +00:00
Brian Hirsh
bcc6e3ab5e add python API to print all operators that have kernels registered to a particular DispatchKey (#63575)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63575

Test Plan: Imported from OSS

Reviewed By: ezyang, Chillee

Differential Revision: D30426919

Pulled By: bdhirsh

fbshipit-source-id: b0e487e48dfe02f7b9d678403f0a2b5bfe146f4e
2021-09-22 09:15:55 -07:00
Alex Suhan
b176feec1e Add device and key for lazy tensors (#61621)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61621

Test Plan: CI

Reviewed By: mruberry

Differential Revision: D29912934

Pulled By: asuhan

fbshipit-source-id: 493c32063a3e756d93cbf1d876563a35eaafb537
2021-07-26 23:00:22 -07:00
Jiewen Tan
d5be67a338 Expose findDanglingImpls to Python (#60827)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60827

This diff exposed Dispatcher.findDanglingImpls to Python as _C._dispatch_find_dangling_impls.
ghstack-source-id: 132799970

Test Plan: buck test mode/dev //caffe2/test:others -- test_find_dangling_impls

Reviewed By: ezyang

Differential Revision: D29416330

fbshipit-source-id: d2f26054b6e247be1bb9e818eaa7cb9e68a4a913
2021-06-30 12:31:19 -07:00
Edward Yang
13b1ca9466 Rename DefaultBackend to CompositeExplicitAutograd (#54470)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54470

```
git grep -l 'DefaultBackend' | xargs sed -i 's/DefaultBackend/CompositeExplicitAutograd/g'
```

Plus a quick fixup in native/README.md

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D27253240

Pulled By: ezyang

fbshipit-source-id: 964df951ea8b52fa72937f3cc66aeaf49a702e6f
2021-03-26 10:53:30 -07:00
Edward Yang
145bc5cd51 Rename Math to CompositeImplicitAutograd (#54466)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54466

I had to very carefully audit all the use sites since there are a lot
of other uses of the string Math; I did most of the conversion by
grepping for all occurrences of Math and then doing a search
replace.

I also updated documentation for clarity.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D27253239

Pulled By: ezyang

fbshipit-source-id: afb485d07ff39575742a4f0e1e205179b60bc953
2021-03-24 13:49:24 -07:00
Scott Wolchok
1935880860 [PyTorch] Remove unnecessary dispatcher.h include in torch/library.h (#51162)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51162

It's unused.
ghstack-source-id: 120427120

Test Plan: CI

Reviewed By: bhosmer

Differential Revision: D25859010

fbshipit-source-id: 7bb21312843debaedaa6a969727c171b2bb0e6b2
2021-01-26 22:19:32 -08:00
Ailing Zhang
0ddcc0ce35 Add alias dispatch key DefaultBackend. (#45718)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45718

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D24165892

Pulled By: ailzhang

fbshipit-source-id: ed28bf62b7c6320d966fd10b7a44b14efffe2f62
2020-10-09 12:02:44 -07:00
Ailing Zhang
10f287539f Align casing in test_dispatch with dispatch keys. (#44933)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44933

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D23778247

Pulled By: ailzhang

fbshipit-source-id: bc3725eae670b03543015afe763cb3bb16baf8f6
2020-09-22 10:50:08 -07:00
Ailing Zhang
92f8f75c59 Add alias dispatch key Math. (#44354)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44354

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D23591481

Pulled By: ailzhang

fbshipit-source-id: 6e93c4ec99a07f3fc920ba2d09dc222e6ced5adf
2020-09-21 11:10:39 -07:00
Ailing Zhang
24efd29d19 Check commutativity for computed dispatch table and add a test to check entries. (#44088)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44088

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D23492793

Pulled By: ailzhang

fbshipit-source-id: 37502f2a8a4d755219b400fcbb029e49d6cdb6e9
2020-09-09 12:48:34 -07:00
Ailing Zhang
1b2da9ed82 Expose alias key info in dumpState and update test_dispatch. (#44081)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44081

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D23492794

Pulled By: ailzhang

fbshipit-source-id: 27a2978591900463bda2e92e0201c9fd719f9792
2020-09-06 18:43:05 -07:00
Ailing Zhang
224232032c Move Autograd to an alias dispatch key (#43070)
Summary:
This PR moves `DispatchKey::Autograd` to an alias dispatch key mapping to `AutogradCPU, AutogradCUDA, AutogradXLA, AutogradOther, AutogradPrivate*` keys.

A few things are handled in this PR:
- Update alias dispatch key mapping and precompute dispatchTable logic
- Move `Autograd` key from `always_included` set to TensorImpl constructor.
- Update `dummyTensor` constructor to take `requires_grad` as optional argument so that it's closer to the real application in op_registration_test.
- Use `BackendSelect` key for both backend select before and after autograd layer. (1 liner in backend_select codegen)

A few planned followups ordered by priority:
- [cleanup] Update `test_dispatch.py` to include testing `Autograd`.
- [cleanup] Add Math alias key and move catchAll to Math. (to remove 2.2 in `computeDispatchTableEntryWithDebug`)
- [new feature] Add support for Math in native_functions.yaml
- [cleanup] Add iterator like functionality to DispatchKeySet
- [cleanup/large] Only add Autograd backend keys when tensor requires grad. (cc: ljk53 ?)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43070

Reviewed By: ezyang

Differential Revision: D23281535

Pulled By: ailzhang

fbshipit-source-id: 9ad00b17142e9b83304f63cf599f785500f28f71
2020-09-01 09:05:29 -07:00
Edward Yang
a4cabd1a3c Generalize Python dispatcher testing API; disallow overwriting fallback (#40469)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40469

- The old testing interface C._dispatch_import was based off the old
  c10::import variation, which meant the API lined up in a strange
  way with the actual torch/library.h.  This diff reduces the
  differences by letting you program the Library constructor directly.

- Using this newfound flexibility, we add a test for backend fallbacks
  from Python; specifically testing that we disallow registering a
  backend fallback twice.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D22236351

Pulled By: ezyang

fbshipit-source-id: f8365e3033e9410c7e6eaf9f78aa32e1f7d55833
2020-06-26 09:01:28 -07:00
Edward Yang
a894fff265 Back out "Revert D21089648: Put TORCH_LIBRARY in torch/library.h; add custom class API"
Summary: Original commit changeset: 636e8a11afc6

Test Plan: export to OSS

Reviewed By: malfet

Differential Revision: D21170502

fbshipit-source-id: e8f35f103c4924aedbcaaf868475008d24bdeeab
2020-04-22 09:18:23 -07:00
James Reed
2ccdc39dce Revert D21089648: Put TORCH_LIBRARY in torch/library.h; add custom class API
Test Plan: revert-hammer

Differential Revision:
D21089648

Original commit changeset: 8d54329c1252

fbshipit-source-id: 636e8a11afc628a4cdae9d44824985c10c70555e
2020-04-21 12:21:45 -07:00
Edward Yang
01100cb477 Put TORCH_LIBRARY in torch/library.h; add custom class API (#36742)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36742

Now, you can define a custom class inside a TORCH_LIBRARY block.
It looks very similar to what you did before.  Instead of

```
static auto m = torch::class_<Class>("Namespace", "Class").def("foo", foo);
```

you write

```
TORCH_LIBRARY(Namespace, m) {
  m.class_<Class>("Class")
    .def("foo", foo);
}
```

All the old usages still work, but at some point we should start
updating the tutorials when we're ready to go 100% live with the
new pybind11 style API.

The custom class API previously lived in the torch/ folder and in the torch
namespace, so for consistency, the new TORCH_LIBRARY also got
moved to torch/library.h.  The definition of Library::class_ is at the
bottom of that header because I need all of the class_ constructors
available, but there is a circular dependency between the two headers.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D21089648

Test Plan: Imported from OSS

Pulled By: ezyang

fbshipit-source-id: 8d54329c125242605336c22fa1642aae6940b507
2020-04-21 10:05:21 -07:00
Edward Yang
e29348f828 Switch to pybind11 style registration function API. (#36258)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36258

Previously we had a &&-chaining style API.  There are some downsides to
this API:

- It's easy to forget the 'static' qualifier in front, leading to
  subtle ODR bugs.
- It is not compatible with torchbind class_ definitions, as these
  need multiple levels of chaining.  So in practice people end
  up having to define multiple static initializers, one per class.
- It's not like pybind11.
- There's no way to conveniently get the file and line number of
  the registration, as there is no macro point in the API.
- The old API doesn't really encourage people to put all of their
  definitions for a library in one place, and to give a custom
  namespace for it.  Similarly, the old API wasn't very DRY, because
  you had to keep repeating the namespace/dispatch key you
  were writing implementations for.

The new API is modeled exactly off of the PYBIND11_MODULE macro:
you write:

```
TORCH_LIBRARY(aten, m) {
  m.def("aten::add(Tensor self, Tensor other) -> Tensor");
  ...
}
```

in a non-chaining fashion, and under the hood the macro expands to
define a function, and define a static initializer that allocates
c10::Library (previously called c10::Module, but we renamed it
to avoid confusion with the existing NN module concept), passes
it to your function, and then retains it for the rest of the lifetime
of the program.  Specification of the namespace is mandatory,
and in a later commit I plan to make it a hard error to TORCH_LIBRARY
the same library name twice.

If you are specifying an implementation for an existing operator
(e.g., you're the XLA backend, or even if you're just putting
registrations for implementations at the implementation site),
you should use TORCH_LIBRARY_IMPL, which instead takes a backend
argument (instead of namespace) and can be used to specify an
implementation for a backend.  Unlike TORCH_LIBRARY, you can do
as many of these as you want for a backend.

This needs updates to the mobile code analyzer.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D20929257

Pulled By: ezyang

fbshipit-source-id: ba04d78492e8c93ae7190165fb936f6872896ada
2020-04-16 10:44:21 -07:00
Edward Yang
dd64e738c5 Expunge TensorId from all DispatchKey names. (#36240)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36240

It's annoying, historical, and unnecessary (enum class is already
namespaced).  I did this codemod with:

```
git grep -l 'CPUTensorId' | xargs sed -i 's/CPUTensorId/CPU/g'
git grep -l 'CUDATensorId' | xargs sed -i 's/CUDATensorId/CUDA/g'
git grep -l 'VariableTensorId' | xargs sed -i 's/VariableTensorId/Autograd/g'
git grep -l 'HIPTensorId' | xargs sed -i 's/HIPTensorId/HIP/g'
git grep -l 'MSNPUTensorId' | xargs sed -i 's/MSNPUTensorId/MSNPU/g'
git grep -l 'XLATensorId' | xargs sed -i 's/XLATensorId/XLA/g'
git grep -l 'PrivateUse1_TensorId' | xargs sed -i 's/PrivateUse1_TensorId/PrivateUse1/g'
git grep -l 'PrivateUse2_TensorId' | xargs sed -i 's/PrivateUse2_TensorId/PrivateUse2/g'
git grep -l 'PrivateUse3_TensorId' | xargs sed -i 's/PrivateUse3_TensorId/PrivateUse3/g'
git grep -l 'AutocastTensorId' | xargs sed -i 's/AutocastTensorId/Autocast/g'
git grep -l '_PreAutogradTensorId' | xargs sed -i 's/_PreAutogradTensorId/_PreAutograd/g'
git grep -l 'TESTING_ONLY_GenericWrapperTensorId' | xargs sed -i 's/TESTING_ONLY_GenericWrapperTensorId/TESTING_ONLY_GenericWrapper/g'
git grep -l 'TESTING_ONLY_GenericModeTensorId' | xargs sed -i 's/TESTING_ONLY_GenericModeTensorId/TESTING_ONLY_GenericMode/g'
```

Then I did a git grep for remaining TensorId occurrences, and manually
killed those (mostly in codegen, and some docs that needed updating).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D20929255

Pulled By: ezyang

fbshipit-source-id: dc371b6aa6e6ea7c0a5660137c14debde806a09d
2020-04-13 23:33:44 -07:00
Edward Yang
4f4ed5c108 Disable c10::import(ns) (#35398)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35398

This disables namespaced c10::import which is broken with custom
mobile op builds.  This is to help prevent people from accidentally
breaking the custom mobile build in a mysterious way; if they use
the longform version it will work.  Fixing the analyzer is tracked
in https://github.com/pytorch/pytorch/issues/35397

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D20680519

Pulled By: ezyang

fbshipit-source-id: a18ac8df7e72bf399807870beedb828131273e48
2020-03-30 21:12:49 -07:00