Commit Graph

75 Commits

Author SHA1 Message Date
George Qi
a90f006fe5 add strides to slow path
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78610

Approved by: https://github.com/ezyang
2022-06-10 16:59:14 +00:00
Edward Z. Yang
eb856daf0f Do not treat all dense tensors as isTensorSubclassLike
Fixes https://github.com/pytorch/pytorch/issues/79079

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79098

Approved by: https://github.com/soulitzer, https://github.com/albanD
2022-06-09 03:00:57 +00:00
Zachary DeVito
ab6c7b4b3f fix __torch_function__ bug in __getitem__ that causes an "error not set" exception
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78781

Approved by: https://github.com/ezyang
2022-06-06 17:02:57 +00:00
samdow
184e0065b3 add better error message for class method
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78821

Approved by: https://github.com/ezyang
2022-06-06 13:31:32 +00:00
Michael Suo
22b10873f3 Allow torchdispatch to customize dim()
This follows the template in
https://github.com/pytorch/pytorch/pull/77396

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78691

Approved by: https://github.com/ezyang
2022-06-02 20:54:13 +00:00
anjali411
79ddc32b6a Add a check to ensure input func to Library.impl is callable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77990

Approved by: https://github.com/albanD
2022-06-02 16:55:39 +00:00
Michael Suo
876c359347 Generalize sizes and strides policy on _make_wrapper_subclass
Previously, there was a `dispatch_strides` boolean arg. Change this to
a string argument that directly maps onto `SizesStridesPolicy`.
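
A minimal sketch of the new-style call, assuming the string argument is exposed as `dispatch_sizes_strides_policy` (its name in current PyTorch; the values "default", "sizes", and "strides" map onto `SizesStridesPolicy`):

```
import torch

# Hedged sketch: the `dispatch_sizes_strides_policy` keyword and its values
# are assumptions based on current PyTorch; this PR replaced the old
# `dispatch_strides` boolean with a string argument of this kind.
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.shape,
            dtype=elem.dtype,
            device=elem.device,
            # "strides" routes stride (and size) queries through
            # __torch_dispatch__; "sizes" and "default" are the other policies.
            dispatch_sizes_strides_policy="strides",
        )
```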

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78646

Approved by: https://github.com/ezyang
2022-06-02 02:06:38 +00:00
samdow
aa06d05297 enable with semantics
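
For context, a sketch of the `with`-statement usage this enables for dispatch modes; `TorchDispatchMode` and its import path are taken from current PyTorch and assumed to apply here:

```
import torch
from torch.utils._python_dispatch import TorchDispatchMode

# Assumed API: TorchDispatchMode as exposed in current PyTorch.
class PrintingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # log each aten op as it dispatches
        return func(*args, **(kwargs or {}))

# Modes can now be entered with a plain `with` statement.
with PrintingMode():
    torch.ones(2) + torch.ones(2)
```
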
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78214

Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-06-01 21:14:45 +00:00
Elias Ellison
678213ead2 Fake Tensor Part 1
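
For orientation, a hedged sketch of fake tensors as they surface in current PyTorch (import path assumed; "Part 1" may cover only a subset of this):

```
import torch
from torch._subclasses import FakeTensorMode

# Fake tensors carry metadata (shape/dtype/device) but no storage,
# like meta tensors that still report a real device.
with FakeTensorMode():
    x = torch.empty(8, device="cpu")  # created inside the mode -> fake
    y = x + x
    print(y.shape, y.device)  # torch.Size([8]) cpu, but no real data
```
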
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77969

Approved by: https://github.com/ezyang
2022-05-31 16:20:35 +00:00
soulitzer
f3af51069d Modernize LoggingTensorMode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77667

Approved by: https://github.com/malfet
2022-05-24 22:41:49 +00:00
Elias Ellison
2d93e1fada Add slow path for device
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77684

Approved by: https://github.com/ezyang
2022-05-24 21:56:01 +00:00
George Qi
294fff16ec add slow path for is_contiguous (#77906)
Test Plan: CI

Reviewed By: malfet, b0noI

Differential Revision: D36493890

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77906
Approved by: https://github.com/malfet
2022-05-19 22:52:45 +00:00
anjali411
5984bc8233 Allow specifying alias analysis while registering new ops
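
A small sketch, assuming the knob is the `alias_analysis` keyword on `torch.library.Library.define` (its form in current PyTorch), with kinds such as "CONSERVATIVE" or "FROM_SCHEMA":

```
import torch

# Assumption: define() accepts an alias analysis kind per this PR.
my_lib = torch.library.Library("my_ns", "DEF")
my_lib.define("mul_two(Tensor x) -> Tensor", alias_analysis="CONSERVATIVE")
my_lib.impl("mul_two", lambda x: x * 2, "CPU")
```
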
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77690

Approved by: https://github.com/ezyang
2022-05-19 21:11:40 +00:00
PyTorch MergeBot
00a187c373 Revert "add slow path for is_contiguous"
This reverts commit f6beda89c6.

Reverted https://github.com/pytorch/pytorch/pull/77396 on behalf of https://github.com/malfet
2022-05-19 17:07:54 +00:00
Edward Z. Yang
4941e72e40 Revert "Revert "Implement sym_sizes to create proper IR for sym ints representing tensor sizes (#76836)""
This reverts commit c35bd8d423.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77719

Approved by: https://github.com/Chillee, https://github.com/malfet
2022-05-18 18:40:57 +00:00
PyTorch MergeBot
48581d74ad Revert "Add dispatch mode testing for meta tensors and other stuff"
This reverts commit c1cdb1216b.

Reverted https://github.com/pytorch/pytorch/pull/77477 on behalf of https://github.com/malfet
2022-05-18 02:56:48 +00:00
George Qi
f6beda89c6 add slow path for is_contiguous
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77396

Approved by: https://github.com/ezyang, https://github.com/cpuhrsch
2022-05-18 02:25:27 +00:00
Edward Z. Yang
c1cdb1216b Add dispatch mode testing for meta tensors and other stuff
We don't have any coverage for meta tensor correctness for backwards
because torch function mode only lets us interpose on Python torch API
calls, while backwards invocations happen from C++. To make this
possible, I add a torch_dispatch_meta test which runs the tests with
__torch_dispatch__.
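
The gist of that test, as a simplified sketch (the real harness is far more thorough; names here are illustrative): a wrapper subclass whose `__torch_dispatch__` re-runs every op on meta tensors and cross-checks the metadata, so calls arriving from C++ (e.g. during backward) are covered too:

```
import torch
from torch.utils._pytree import tree_map

# Illustrative sketch only; assumes the ops encountered have meta support.
class MetaCrossRefTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        unwrap = lambda t: t.elem if isinstance(t, cls) else t
        to_meta = lambda t: t.elem.to("meta") if isinstance(t, cls) else t
        kwargs = kwargs or {}
        real = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        meta = func(*tree_map(to_meta, args), **tree_map(to_meta, kwargs))
        # For single-Tensor outputs, the meta kernel must agree on metadata.
        if isinstance(real, torch.Tensor) and isinstance(meta, torch.Tensor):
            assert meta.shape == real.shape and meta.dtype == real.dtype
        return real
```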

While doing this, I needed to generate fresh expected failure / skip
lists for the new test suite, and I discovered that my original
scaffolding for this purpose was woefully insufficient.  So I rewrote
how the test framework worked, and at the same time rewrote the
__torch_function__ code to also use the new logic.  Here's whats
new:

- Expected failure / skip is now done on a per function call basis,
  rather than the entire test.  This means that separate OpInfo
  samples for a function don't affect each other.

- There are now only two lists: an expected-failure list (where the test
  consistently fails on all runs) and a skip list (where the test
  sometimes passes and sometimes fails).

- We explicitly notate the dtype that failed.  I considered detecting
  when something failed on all dtypes, but this was complicated and
  listing everything out seemed to be nice and simple.  To keep the
  dtypes short, I introduce a shorthand notation for dtypes.

- Conversion to meta tensors is factored into its own class
  MetaConverter

- To regenerate the expected failure / skip lists, just run with
  PYTORCH_COLLECT_EXPECT and filter on a specific test type
  (test_meta or test_dispatch_meta) for whichever you want to update.

Other misc fixes:

- Fix max_pool1d to work with BFloat16 in all circumstances, by making
  it dispatch and then fixing a minor compile error (constexpr doesn't
  work with BFloat16)

- Add resolve_name for turning random torch API functions into string
  names

- Add push classmethod to the Mode classes, so that you can more easily
  push a mode onto the mode stack

- Add some more skips for missing LAPACK

- Added an API to let you query if there's already a registration for
  a function, added a test to check that we register_meta for all
  decompositions (except detach, that decomp is wrong lol), and then
  updated all the necessary sites to make the test pass.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77477

Approved by: https://github.com/zou3519
2022-05-18 00:18:34 +00:00
Edward Z. Yang
b5bc954a71 Fix optional dtype/layout/memory_format pycall; fix memory format
Double-header bug fix:

- As reported by jansel, dtypes are still showing up as integers
  when the schema is an optional dtype.  This is simple enough to
  fix and I added a test for it.  But while I was at it...

- I noticed that the THPMemoryFormat_new idiom with the "unused" name
  doesn't actually work: the repr of the returned memory format
  object is wrong, and this shows up when we try to log the args/kwargs.
  So I fixed memory format to do it properly along with everything
  else.

Fixes https://github.com/pytorch/pytorch/issues/77135

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77543

Approved by: https://github.com/albanD, https://github.com/jansel
2022-05-16 16:46:08 +00:00
Sherlock Huang
61dcde88a6 Jiterator with Python Registration (#77121)
You can now do a lot of crazy things to redefine the behavior of an operator, and still be fast on CUDA!

Example 1: swapping where's branches
```
code_string = "template <typename T> T inverted_where(bool cond, T a, T b){ return !cond ? a : b; }"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::where.self', jitted_fn, "CUDA")

# torch.where is now overridden
```
Example 2: approximating gelu with relu
```
code_string = "template <typename T> T fast_gelu(T a){ return a > 0 ? a : 0;}"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::gelu', jitted_fn, "CUDA")

# torch.nn.GELU and torch.nn.functional.gelu are now overridden
```
Example 3: clipping output for numerically unstable kernels
```
code_string = "template <typename T> T clipped_exp(T a){ return a > T(10.0) ? T(22026.4657948) : exp(a); }"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::exp', jitted_fn, "CUDA")

# torch.exp(x) and x.exp() are now overridden
```
Example 4: simulating buggy hardware behaviors
```
code_string = "template <typename T> T buggy_add(T a, T b){ return a + b + T(1); }"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::add.Tensor', jitted_fn, "CUDA")

# torch.add(x, y), "x + y" and x.add(y) are now overridden
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77121
Approved by: https://github.com/anjali411
2022-05-10 20:54:23 +00:00
anjali411
767af8e335 Add meta tensor support for some operations using python registration
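
Hedged sketch of what this pattern looks like with the `torch.library` API from the entries below (op and namespace names are illustrative):

```
import torch

lib = torch.library.Library("demo", "DEF")
lib.define("double_it(Tensor x) -> Tensor")
lib.impl("double_it", lambda x: x * 2, "CPU")
# The "Meta" kernel computes only output metadata, no data.
lib.impl("double_it", lambda x: torch.empty_like(x), "Meta")

y = torch.ops.demo.double_it(torch.empty(4, device="meta"))
print(y.shape, y.device)  # torch.Size([4]) meta
```
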
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76916

Approved by: https://github.com/ezyang
2022-05-10 17:55:06 +00:00
Edward Z. Yang
f2eed9400d Register PrimTorch refs as decompositions.
For the most part, PrimTorch refs have the same signature as their
ATen equivalents.  I modify most PrimTorch refs to register themselves
as decompositions, using the prim name they wrap to find the aten name
(except for a few cases where the prim/aten names mismatch).  There are
some exclusions, falling into one of two categories:

- The torch equivalent was already implemented as a CompositeImplicitAutograd
  decomposition in C++

- The ref doesn't support enough features (e.g., the real deal has more
  kwargs / overloads than are currently implemented)

PrimTorch refs are written as a single function that supports all
overloads, and this style is convenient for cases where we have a bundle
of overloads for what morally is a single overload with a Union type
on an argument (which we ought to have supported in
native_functions.yaml but blah); to support registering a single decomp
for all the overloads, we modify register_decomposition to register
to ALL overloads if you pass it an overload packet.  This is technically
BC breaking but no tests started failing because of it.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76835

Approved by: https://github.com/Chillee, https://github.com/mruberry
2022-05-06 20:11:45 +00:00
anjali411
07f766df54 Allow creating new libraries and defining new operators from Python
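
A minimal sketch of the flow this enables (namespace and op name are illustrative):

```
import torch

# "DEF" creates a fresh library namespace to define new operators in.
my_lib = torch.library.Library("mylib", "DEF")
my_lib.define("add_one(Tensor x) -> Tensor")

def add_one_cpu(x):
    return x + 1

my_lib.impl("add_one", add_one_cpu, "CPU")

print(torch.ops.mylib.add_one(torch.zeros(3)))  # tensor([1., 1., 1.])
```
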
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76250

Approved by: https://github.com/ezyang
2022-05-05 03:33:08 +00:00
anjali411
55f55a4cf6 Allow users to override kernels for existing C++ ops through Python
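
Sketch of such an override (see also the Jiterator examples further down this log); the replacement computes sin via cos so it doesn't recurse into the kernel it replaces:

```
import math

import torch

# "IMPL" attaches implementations to the existing "aten" library.
my_lib = torch.library.Library("aten", "IMPL")

def sin_via_cos(x):
    # Don't call torch.sin here: the CPU kernel is now this function,
    # so that would recurse. sin(x) == cos(x - pi/2).
    return torch.cos(x - math.pi / 2.0)

my_lib.impl("aten::sin", sin_via_cos, "CPU")
```
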
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75905

Approved by: https://github.com/ezyang
2022-05-05 03:31:39 +00:00
samdow
6779366f27 add nested mode to python mode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75965

Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
2022-05-04 13:01:06 +00:00
samdow
598e7e5f19 [Reland] Change 'python mode' to 'torch dispatch mode'
Renames Python Mode to Torch Dispatch Mode: now that there is a Torch Function Mode, this keeps Torch Dispatch Mode and Torch Function Mode consistent with each other.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76562
Approved by: https://github.com/zou3519, https://github.com/albanD
2022-05-02 20:06:43 +00:00
PyTorch MergeBot
395a620a4f Revert "Change 'python mode' to 'torch dispatch mode'"
This reverts commit 7203a73986.

Reverted https://github.com/pytorch/pytorch/pull/76562 on behalf of https://github.com/janeyx99
2022-05-02 14:42:11 +00:00
samdow
7203a73986 Change 'python mode' to 'torch dispatch mode'
Renames Python Mode to Torch Dispatch Mode: now that there is a Torch Function Mode, this keeps Torch Dispatch Mode and Torch Function Mode consistent with each other.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76562
Approved by: https://github.com/zou3519
2022-05-02 13:33:58 +00:00
albanD
cd0591dff3 Change default TLS behavior in dispatch to favor is-a style
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75827

Approved by: https://github.com/ezyang
2022-04-20 17:32:29 +00:00
Edward Z. Yang
2772870860 Preserve Python dispatch keys upon copy_tensor_metadata_except_version_counter
Whether or not this is a reasonable operation to do in the presence of
subclasses is a good question in and of itself, but this fixes an
obvious invariant violation, which is that if a Tensor reports that
it is a tensor subclass, it had better have the Python dispatch key.
Previously, the dispatch key would have gotten unconditionally cleared;
now we preserve whatever the original bit was.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75644

Approved by: https://github.com/albanD
2022-04-15 13:26:23 +00:00
Jane Xu
a1e284d9c8 Remove high priority as an owner for tests (#74555)
Summary:
Following triage review discussion, it would be best for these tests not to be triaged as high priority by automation, but by the triagers on the oncall.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74555

Reviewed By: albanD

Differential Revision: D35099202

Pulled By: janeyx99

fbshipit-source-id: 657a0317141de3a598476a6f601ec26cc26231b1
(cherry picked from commit 057519cb2494d0f9a0b169f359ac87ba9e89f088)
2022-03-24 14:29:52 +00:00
Sherlock Huang
f4a0da8695 Supports super().__torch_dispatch__ with arguments list
Summary:
For THPModule_disable_torch_(dispatch|function), converts list arguments to tuples before invoking PyObject_Call.

Fixes  #73933

Tasks: 114830027

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74509

Fix PyObject leak issue

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74567

Approved by: https://github.com/ezyang
2022-03-23 23:33:44 +00:00
Duncan Hill
0988dc481a [Codemod][Codemod deprecated unittest asserts] fbcode//caffe2/test (#71708)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71708

In Python 3.2, a number of asserts were deprecated.

In Python 3.11, these asserts are deleted completely. The files in this change still use the deprecated asserts.

Switch over to the supported syntax for 3.2 onwards.

Test Plan: Tested on the internal test suite runner.

Reviewed By: ajtulloch

Differential Revision: D33503694

fbshipit-source-id: a150f296033260acf8365d77b837ce0679f57361
(cherry picked from commit abf60ed97409265222915d8265aaabedd625fd93)
2022-03-15 19:28:52 +00:00
Edward Yang
0239284313 Relax dtype restrictions on torch.Tensor (#73850)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73850

Previously, torch.Tensor was treated as if it were torch.FloatTensor
(where Float is whatever the default dtype was).  This is not good
behavior for tensor subclasses, which inherit from torch.Tensor, want
to super()-call into it, and only notice later that float is the only
dtype that works.  So in this PR I relax the behavior
for this case to make the torch.Tensor constructor more useful for
subclasses.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D34707396

Pulled By: ezyang

fbshipit-source-id: a995d601007b6fcd0317d89f66ca7e08c4d6053e
(cherry picked from commit e8d0d7b3e8b17681b931cbe4f5729de2e80cf3de)
2022-03-09 15:45:24 +00:00
anjali411
086645ad77 Update __torch_dispatch__ to return op overload instead of the opoverload packet function (#72673)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72673

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D34627164

Pulled By: anjali411

fbshipit-source-id: 3cb6406a392d530bf9da36b4d8e0a62b30e6497e
(cherry picked from commit 65b85a0a67df4d0f16ac8964e2b685d478a610fb)
2022-03-07 22:38:42 +00:00
Edward Z. Yang
35cfa74f97 Add a default implementation of __torch_dispatch__
I was working on an explanation of how to call into the "super"
implementation of some given ATen operation inside of __torch_dispatch__
(https://github.com/albanD/subclass_zoo/blob/main/trivial_tensors.py)
and I kept thinking to myself "Why doesn't just calling super() on
__torch_dispatch__ work?"  Well, after this patch, it does!  The idea
is that if you don't actually unwrap the input tensors, you can call
super().__torch_dispatch__ to get at the original behavior.

Internally, this is implemented by disabling PythonKey and then
redispatching.  This implementation of disabled_torch_dispatch is
not /quite/ right, and some reasons why are commented in the code.
There is then some extra work I have to do to make sure we recognize
disabled_torch_dispatch as the "default" implementation (so we don't
start slapping PythonKey on all tensors, including base Tensors),
which is modeled the same way as how disabled_torch_function is done.
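
The pattern this unlocks, as a tiny sketch (class name illustrative):

```
import torch

# A "trivial" subclass: inputs are not unwrapped, so calling
# super().__torch_dispatch__ falls through to the default behavior.
class TrivialTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(func)  # observe the call, then defer to the base behavior
        return super().__torch_dispatch__(func, types, args, kwargs or {})

t = torch.randn(2).as_subclass(TrivialTensor)
t + t  # logs the aten ops while behaving like a plain tensor
```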

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73684

Approved by: albanD
2022-03-03 20:19:33 +00:00
Joel Benjamin Schlosser
30653d164d Fix serialization and deepcopying for wrapper subclasses
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73078
2022-02-24 18:21:25 +00:00
Alban Desmaison
1d6b156c3a Reland fix dispatch (#73231)
Summary:
Reland of https://github.com/pytorch/pytorch/issues/73045

Tweak class visibility to avoid windows linking issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73231

Reviewed By: bdhirsh

Differential Revision: D34402767

Pulled By: albanD

fbshipit-source-id: 50aaadf5389ca516fa6a5034d42eee56abe3c7f7
(cherry picked from commit 0fe53bdfb7)
2022-02-23 15:28:15 +00:00
Nikita Shulga
9a96604800 Revert D34318185: [pytorch][PR] Ensure that call before redispatch work well for PythonTLSSnapshot
Test Plan: revert-hammer

Differential Revision:
D34318185 (04c9e52ecc)

Original commit changeset: abc30fe69176

Original Phabricator Diff: D34318185 (04c9e52ecc)

fbshipit-source-id: ba40c2e1eceb1c4b71ac6edefc64d01e174d9524
(cherry picked from commit f47961904d)
2022-02-22 18:31:13 +00:00
Alban Desmaison
04c9e52ecc Ensure that call before redispatch work well for PythonTLSSnapshot (#73045)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73045

Reviewed By: zou3519

Differential Revision: D34318185

Pulled By: albanD

fbshipit-source-id: abc30fe69176ba474e28bb045406a410e17cfd79
(cherry picked from commit 4d9a305d3a)
2022-02-22 15:30:07 +00:00
Alban Desmaison
a7cac05ca6 Add new tls snapshot feature (#72832)
Summary:
Reland of https://github.com/pytorch/pytorch/pull/72623, which was reverted; the tls cleanup has been removed for this reland.

From close inspection of the counting of the number of available keys, I think there is one more than assumed, since the guard is actually one past the last usable key. With this updated assert, the last usable key will still be <= 63, which fits just fine.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72832

Reviewed By: H-Huang

Differential Revision: D34228571

Pulled By: albanD

fbshipit-source-id: ce5e10a841ea87386727346cfc8d9327252574c4
(cherry picked from commit 59d3b86353)
2022-02-15 19:02:05 +00:00
Brian Hirsh
f1a9650e4f Revert D34214953: Add new tls snapshot feature
Test Plan: revert-hammer

Differential Revision:
D34214953 (6199b5231f)

Original commit changeset: 7aa5d5e3540a

Original Phabricator Diff: D34214953 (6199b5231f)

fbshipit-source-id: 5d271e9a5ab021b8202402630dbf917b43c55421
(cherry picked from commit a12c630198)
2022-02-14 23:14:19 +00:00
Alban Desmaison
6199b5231f Add new tls snapshot feature (#72623)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72623

Test Plan: Imported from OSS

Reviewed By: samdow

Differential Revision: D34214953

Pulled By: albanD

fbshipit-source-id: 7aa5d5e3540a45a0ae70c5af3a4495c755908aa9
(cherry picked from commit dc0a1ab54a)
2022-02-14 20:46:54 +00:00
Alban Desmaison
584f13967b Add wrapped Tensor autograd test (#72622)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72622

This contains a version of the test from the next PR that doesn't work, to show the change in behavior more easily.

Test Plan: Imported from OSS

Reviewed By: samdow

Differential Revision: D34214954

Pulled By: albanD

fbshipit-source-id: 4d72f2d20e12c57ca7b63852ffe0c8aa61aa593b
(cherry picked from commit b5d792d103)
2022-02-14 20:13:30 +00:00
Alban Desmaison
3c33f0bdcd Clean up LoggingTensor semantic (#72620)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72620

Clarify how LoggingTensor works with autograd.
The updated comment should cover the semantic changes.

Test Plan: Imported from OSS

Reviewed By: samdow

Differential Revision: D34214956

Pulled By: albanD

fbshipit-source-id: 730d0a68f4228d2a84758e6807d869a34cbc1b31
(cherry picked from commit 66110bf16b)
2022-02-14 20:13:30 +00:00
Richard Zou
5735f2f875 Make detach redispatch like a regular PyTorch operator (#71707)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71707

Why?
- detach should behave like jax.stop_gradient in functorch. Because it
does not detach all the way through, functorch (as well as a Tensor
Subclass wrapping a Tensor subclass) won't see it after the first
layer/subclass handles it.

How?
- This PR changes detach to dispatch all the way through to the backend.
- This PR also modifies native::detach to call shallow_copy_and_detach
instead of native::alias. This is because today, the semantics of detach
and alias are different -- they differ only by
allow_tensor_metadata_change. In the future, we may choose to deprecate
this flag.
- NB: Before and after this PR, detach() shows up twice in
torch_dispatch: https://github.com/pytorch/pytorch/issues/71725. This is
not a regression so I didn't want to fix it in this PR because it is
weird to fix.

Test Plan: - added new tests; run existing tests

Reviewed By: albanD

Differential Revision: D33752860

Pulled By: zou3519

fbshipit-source-id: 40cc2dc8232e75a02586a4ba5b0ef5f16cb76617
(cherry picked from commit f88aae426e)
2022-01-28 16:13:36 +00:00
Can Balioglu
80b19c4c8c Enable Python bindings for UntypedStorage (#68945)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68945

This PR enables the Python conversion functions for `Storage` (specifically `UntypedStorage`) and also cleans up some remnants of the deprecated typed storages from `DynamicTypes.cpp`.
ghstack-source-id: 147245110

Test Plan: Run the existing unit and integration tests.

Reviewed By: albanD

Differential Revision: D32676505

fbshipit-source-id: 3a3f6db4fb0da5c78dd406c96ab70bdc37015521
(cherry picked from commit d6427b94cf)
2022-01-20 02:11:34 +00:00
Alban Desmaison
8b20dde932 add python dispatch test back to CI and fix typo in test (#69565)
Summary:
The error message was changed following a PR comment, and since the test doesn't run on CI, I forgot to update the test to catch the new error message.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69565

Reviewed By: mrshenli

Differential Revision: D32932982

Pulled By: albanD

fbshipit-source-id: a1da72b0ca735e72b481bc944039233094f1c422
2021-12-08 08:44:49 -08:00
Alban Desmaison
28c519961f Follow the undefined Tensor <-> None rule better in torch dispatch (#67793)
Summary:
As per title. In particular, this makes it easier to override backward functions for which the underlying backend returns `None`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67793

Reviewed By: zou3519

Differential Revision: D32242962

Pulled By: albanD

fbshipit-source-id: 6e114def90ee9499161e1303d301ba7fd003ff89
2021-12-02 07:46:56 -08:00
Richard Zou
3d504ae1b4 [RELAND] Fix Dispatching not considering List[Optional[Tensor]] for dispatch (#68073)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68073

Relanding the original PR. Its body was as follows:

Followup to https://github.com/pytorch/pytorch/pull/60787

It turns out that the original PR was wrong for unboxed kernels. We
recently ran into this in
https://github.com/facebookresearch/functorch/issues/124

For unboxed kernels, the correct type for a Tensor?[] argument is
actually `List<optional<Tensor>>`, not `ArrayRef<optional<Tensor>>`
ghstack-source-id: 144204580

Test Plan:
- assert that https://github.com/facebookresearch/functorch/issues/124
actually works

Reviewed By: gchanan

Differential Revision: D32313601

Pulled By: zou3519

fbshipit-source-id: 8028d5f34eecabc53d603bd54d6b6748b5db461a
2021-11-29 08:31:55 -08:00