Commit Graph

143 Commits

Author SHA1 Message Date
Edward Z. Yang
d690a596dc Fast path binary ops in fake tensor (#94047)
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup.

Before:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:53.97591 backend_compile:33.60832
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

After:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:40.18931 backend_compile:25.28828
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit#

This is not the "most" optimized version of the code; compared with Horace/Voz roofline experiment:

```
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e3bf545f3b8..395942c6ffe 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode):
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}

+        with no_dispatch():
+            if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}:
+                return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda')
+
         if func == torch.ops.prim.device.default:
             assert len(args) == 1 and isinstance(args[0], FakeTensor)
             if args[0].fake_mode.in_kernel_invocation:
```

I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu.)

The implementation here is based off of https://github.com/pytorch/pytorch/pull/93118/ but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences:

* Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or if all the tensors are channels last).
* I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right.

Some evidence that this heuristic is correct is here in: https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758 I exhaustively test all dim=3 tensors with sizes [1, 2] and show that we get the same significant strides between PrimTorch and the new algorithm. In fact, there ARE differences between this algorithm and PrimTorch, but in fact this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1))

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94047
Approved by: https://github.com/eellison
2023-02-07 18:34:24 +00:00
Yanbo Liang
605b661805 FakeTensor should constant propagate through ops that allow numbers as scalars (#94145)
Fixes #92655

Thanks @eellison for the code change suggestion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94145
Approved by: https://github.com/eellison
2023-02-07 06:20:35 +00:00
Edward Z. Yang
2481fc0df4 Add count to FakeTensorMode.__torch_dispatch__ (#93936)
Most calls to fake tensor never hit `FakeTensor.__torch_dispatch__`

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93936
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2023-02-03 14:21:11 +00:00
Edward Z. Yang
12f22655b1 Short circuit device property access on FakeTensor (#93946)
Before:

```
(/home/ezyang/local/a/pytorch-env) [ezyang@devgpu020.ftw1 ~/local/a/pytorch (ab0e3db0)]$ python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:54.19504 backend_compile:33.86702
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:72549 | FakeTensorMode.__torch_dispatch__:115542 | ProxyTorchDispatchMode.__torch_dispatch__:3103
```

After

```
(/home/ezyang/local/a/pytorch-env) [ezyang@devgpu020.ftw1 ~/local/a/pytorch (ab0e3db0)]$ python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:53.97591 backend_compile:33.60832
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

It doesn't really help end-to-end wall time all that much, but it does cut the number of calls to FakeTensor.__torch_dispatch__ by an order of magnitude, which hopefully has other positive effects.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93946
Approved by: https://github.com/eellison, https://github.com/albanD
2023-02-03 14:20:30 +00:00
Elias Ellison
e4f11e01bd [Fake Tensor] Allow fake meta by default, delete unused ctor args (#93993)
Two small changes that I'm bundling together because one of them needs to touch fbcode and I'm not sure how to do stacked diffs + internal changes + land before release cut.

Remove allow_meta from ctor, and allow by default: we should be able to trace through meta with fake tensors, so in some senses it's a bit weird to expose to user to disallow this. However, it's still useful debug wise to error from time to time, so I've added an option to the config that will get back previous behavior.

Remove `throw_on_data_dependent_ops=True`: this was intended as a temporary behavior as we were smoothing things turning on the erroring. There are no uses anywhere of `throw_on_data_dependent_ops=False` I could find.

These are technically backward-incompatble, but fake tensor is new since the last release / in a private namespace, and I don't want to release it with baggage that would be hard to remove later.

Fix for https://github.com/pytorch/pytorch/issues/92877.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93993
Approved by: https://github.com/bdhirsh, https://github.com/ezyang
2023-02-03 09:23:38 +00:00
Jason Ansel
8c09a005c5 [inductor] Pattern matching engine (copy) (#93291)
This is an exact duplicate of https://github.com/pytorch/pytorch/pull/90739

The fbcode workflow for landing that diff seems buggy.  The github-export-checks task is failing with credentials errors.  Plan to try to land it using GH1.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93291
Approved by: https://github.com/desertfire
2023-01-31 04:51:00 +00:00
Michael Voznesensky
d322f82b05 Add @count util to torch, use it to track benchmark stats (#93013)
<img width="1333" alt="image" src="https://user-images.githubusercontent.com/4755252/214687911-f766f072-c162-4298-9aed-c889f1375336.png">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93013
Approved by: https://github.com/ezyang
2023-01-26 03:09:12 +00:00
Edward Z. Yang
1237cf6b6c Allow direct Tensor constructor to return preexisting PyObject (#92754)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92754
Approved by: https://github.com/albanD, https://github.com/voznesenskym
2023-01-23 20:20:43 +00:00
PyTorch MergeBot
db466ae057 Revert "[Modes] Add assert that the mode isn't already on the stack (#90770)"
This reverts commit 702838637d.

Reverted https://github.com/pytorch/pytorch/pull/90770 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-01-12 16:44:29 +00:00
samdow
702838637d [Modes] Add assert that the mode isn't already on the stack (#90770)
Redo of #89726 on a clean PR, thanks @voznesenskym for the first draft!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90770
Approved by: https://github.com/ezyang
2023-01-11 15:19:43 +00:00
Samantha Andow
a7749ae177 [reland] rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218) (#89221)
Summary: First half of #87990. This doesn't change any of the behavior and is just a rename

#88218 got reverted for internal breakages. This is the reland of started from internal

Differential Revision:
D41268423

LaMa Project: L1098534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89221
Approved by: https://github.com/meliy-meyada, https://github.com/zou3519
2023-01-04 18:32:49 +00:00
Edward Z. Yang
bcf15cd93b Store source, not sname, in Symbol (#91057)
I'm going to need this in the follow up PR. Instead of storing only Source.name() in Symbol, I now store a full on Source. Lots of replumbing reoccurs. In particular:

- Move Source to torch._guards to break cycles
- I have to add TensorPropertySource and NegateSource to handle x.size()[0] and -x codegen that I was doing with string manipulation previously
- I tighten up invariants so that I never pass source=None; instead I pass ConstantSource (these are constant sources right) and test for that rather than source being missing. I think this is more parsimonious
- Some mypy wobbles from new imports

I didn't move LocalSource and friends to torch._guards, but I ended up needing to access them in a few places. The main annoyance with moving these is that then I also need to move the bytecode codegen stuff, and that's not so easy to move without bringing in the kitchen sink.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91057
Approved by: https://github.com/albanD, https://github.com/voznesenskym, https://github.com/zou3519
2022-12-30 05:56:56 +00:00
Nikita Shulga
fd3a7264ae [MPS] Add group_norm[fwd+backward] and mean_var (take 2) (#91190)
Use Prims to implement group_norm, group_norm_backward and mean_var

Use `torch._ops.ops` instead of `torch.ops` in numerous subpackages in
order to be able to make them importable from `torch/backend/mps/__init__.py` as this alias is defined in
15af4b1cee/torch/__init__.py (L1095)
is executed last during init process.

Add `__all__` to `torch/backends/mps/__init__.py` as well as alias all imports as private

Add `TestNNMPS.test_group_norm_backward` that validates no NaNs are generated during the backward pass

Fixes https://github.com/pytorch/pytorch/issues/88331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91190
Approved by: https://github.com/albanD
2022-12-22 08:54:37 +00:00
PyTorch MergeBot
b68fd7e319 Revert "Store source, not sname, in Symbol (#91057)"
This reverts commit 88c581be87.

Reverted https://github.com/pytorch/pytorch/pull/91057 on behalf of https://github.com/atalman due to causing internal build failures
2022-12-21 22:33:15 +00:00
PyTorch MergeBot
645eda0a00 Revert "[MPS] Add group_norm[fwd+backward] and mean_var (#91190)"
This reverts commit 371716eb36.

Reverted https://github.com/pytorch/pytorch/pull/91190 on behalf of https://github.com/kit1980 due to Broke test_correct_module_names because of underscore _ops
2022-12-21 19:37:43 +00:00
Nikita Shulga
371716eb36 [MPS] Add group_norm[fwd+backward] and mean_var (#91190)
Use Prims to implement group_norm, group_norm_backward and mean_var

Use `torch._ops.ops` instead of `torch.ops` in numerous subpackages in
order to be able to make them importable from `torch/backend/mps/__init__.py` as this alias is defined in
15af4b1cee/torch/__init__.py (L1095)
is executed last during init process.

Depends on https://github.com/pytorch/pytorch/pull/91203

Fixes https://github.com/pytorch/pytorch/issues/88331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91190
Approved by: https://github.com/albanD
2022-12-21 17:33:27 +00:00
Edward Z. Yang
88c581be87 Store source, not sname, in Symbol (#91057)
I'm going to need this in the follow up PR. Instead of storing only Source.name() in Symbol, I now store a full on Source. Lots of replumbing reoccurs. In particular:

- Move Source to torch._guards to break cycles
- I have to add TensorPropertySource and NegateSource to handle x.size()[0] and -x codegen that I was doing with string manipulation previously
- I tighten up invariants so that I never pass source=None; instead I pass ConstantSource (these are constant sources right) and test for that rather than source being missing. I think this is more parsimonious
- Some mypy wobbles from new imports

I didn't move LocalSource and friends to torch._guards, but I ended up needing to access them in a few places. The main annoyance with moving these is that then I also need to move the bytecode codegen stuff, and that's not so easy to move without bringing in the kitchen sink.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91057
Approved by: https://github.com/albanD, https://github.com/voznesenskym
2022-12-21 04:51:51 +00:00
Edward Z. Yang
0b22f5ae9f Deeply rework WeakIdKeyDictionary (#90825)
In the prior patch, I just YOLOed a mutable mapping implementation.
Many edge cases were not handled correctly.  In this PR, I just
copy paste the WeakKeyDictionary from CPython and the hacked it up
to use WeakIdRef instead of weakref.ref.  You can see each line
I changed with the comment CHANGED; there aren't many.

Being exactly API compatible with WeakKeyDictionary means I can also
rob all of the tests from CPython, which I also did for
test/test_weak.py

How to review?  You could either try taking the delta from CPython
(recommended), or review everything from scratch (not recommended).
Can post diff representing delta on request.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90825
Approved by: https://github.com/albanD
2022-12-15 08:43:08 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
1aab755320 Fakify params and weights under private config (#90417)
Previously, we planned to lift the parameters and weights while exporting and implement our own transformer to "unlift" the lifted weights and params back to the graph as attributes. But this is bit challenging because:

- We need to maintain correct ordering for weights and parameters that are passed as inputs so that we know how to map them back.
- Some weights are unused in the graph, so our transformer needs to be aware of which weights and parameters are not used in the graph. And we need to distinguish which are real user input and which are parameters.
- There can be more edge cases we haven't seen in other models yet.

I am aware that @Chillee  and @bdhirsh mentioned that functionalization won't work with fake-tensor attributes but this is fine for the short term as we don't expect users to be modifying weights and params in inference mode. In fact, we explicitly disable attribute mutation in torchdynamo export mode right now.

Given above condition, it might be ok to just fakify params when we need. I use a flag to guard against this change.

Differential Revision: [D41891201](https://our.internmc.facebook.com/intern/diff/D41891201)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90417
Approved by: https://github.com/eellison
2022-12-14 09:33:18 +00:00
Edward Z. Yang
f7365eca90 Add unbacked symints support; item works now (#90624)
The big idea is to add `create_unbacked_symfloat` and `create_unbacked_symint` to ShapeEnv, allowing you to allocate symbolic floats/ints corresponding to data you don't know about at compile time. Then, instead of immediately erroring out when you try to call local_scalar_dense on a FakeTensor, we instead create a fresh symint/symfloat and return that.

There a bunch of odds and ends that need to be handled:

* A number of `numel` calls converted to `sym_numel`
* When we finally return from item(), we need to ensure we actually produce a SymInt/SymFloat when appropriate. The previous binding code assumed that you would have to get a normal Python item. I add a pybind11 binding for Scalar (to PyObject only) and refactor the code to use that. There is some trickiness where you are NOT allowed to go through c10::SymInt if there isn't actually any SymInt involved. See comment.
* One of our unit tests tripped an implicit data dependent access which occurs when you pass a Tensor as an argument to a sizes parameter. This is also converted to support symbolic shapes
* We now support tracking bare SymInt/SymFloat returns in proxy tensor mode (this was already in symbolic-shapes branch)
* Whenever we allocate an unbacked symint, we record the stack trace it was allocated at. These get printed when you attempt data dependent access on the symint (e.g., you try to guard on it)
* Subtlety: unbacked symints are not necessarily > 1. I added a test for this.

These unbacked symints are not very useful right now as you will almost always immediately raise an error later when you try to guard on them. The next logical step is adding an assertion refinement system that lets ShapeEnv learn facts about unbacked symints so it can do a better job eliding guards that are unnecessary.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90624
Approved by: https://github.com/Skylion007, https://github.com/voznesenskym
2022-12-12 13:33:07 +00:00
Edward Z. Yang
b68dead20c Keep track of source name on all allocated SymInts (#90295)
Wow, I had to sweat so much to get this PR out lol.

This PR enforces the invariant that whenever we allocate SymInts as part of fakeification, the SymInt is associated with a Source, and in fact we store the string source name on SymbolWithSourceName. We use 'sname' as the shorthand for source name, as 'name' is already used by sympy to name symbols.

In order to store source names, we have to plumb source names from Dynamo to PyTorch. This made doing this PR a bit bone crushing, because there are many points in the Dynamo codebase where we are improperly converting intermediate tensors into fake tensors, where there is no source (and there cannot be, because it's a frickin' intermediate tensor). I've fixed all of the really awful cases in earlier PRs in the stack. This PR is just plumbing in source names from places where we do have it.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90295
Approved by: https://github.com/voznesenskym
2022-12-10 13:17:34 +00:00
Edward Z. Yang
3d4b92b171 Ensure that we fakeify tensor subclasses when they are initially tracked (#90009)
The old code didn't actually fakeify traceable tensor subclasses at the
time they are added as a GraphArg to the module; now we do, by ignoring
the subclass during fakeification and relying on Dynamo to simulate
the subclass on top.  See comments for more details.

BTW, this codepath is super broken, see filed issues linked on the
inside.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90009
Approved by: https://github.com/wconstab, https://github.com/voznesenskym
2022-12-06 22:36:32 +00:00
Michael Voznesensky
3b9a386d48 Add TORCH_FAKE_TENSOR_DEBUG use it to enable storage of traces on fake tensors at init time (#90215)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90215
Approved by: https://github.com/ezyang
2022-12-06 22:28:52 +00:00
Elias Ellison
1a33b7cbfa Make fake tensors preserve dense strides in type conversion (#89803)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89803
Approved by: https://github.com/ngimel
2022-11-30 01:28:51 +00:00
Edward Z. Yang
860bae49e4 Suppress guards on as_strided call only. (#89569)
See comment in meta_utils.py for the whole story.

This doesn't have a substantive impact yet, but will in the next
PR on the stack.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89569
Approved by: https://github.com/albanD
2022-11-24 14:01:12 +00:00
Edward Z. Yang
ea50549ce6 Suppress guards when creating fake tensors (#89349)
When we create fake tensors, we may call operators that introduce
guards, to accurately reconstruct views.  But these guards are spurious:
if a user is able to present a tensor that "looks the same", they have
implicitly fulfilled the contract that the view is creatable.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89349
Approved by: https://github.com/voznesenskym
2022-11-21 23:14:20 +00:00
Sherlock Huang
caf3d5319f Symintify numel(), infer_size, prims.elementwise_meta (#88956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88956
Approved by: https://github.com/ezyang
2022-11-20 00:42:03 +00:00
PyTorch MergeBot
8ad39536d7 Revert "Symintify numel(), infer_size, prims.elementwise_meta (#88956)"
This reverts commit ce2f8700ba.

Reverted https://github.com/pytorch/pytorch/pull/88956 on behalf of https://github.com/ezyang due to somehow breaks torch.numel
2022-11-19 21:47:55 +00:00
Edward Z. Yang
5582001bd5 Reland 2 "Towards unifying symbolic and non symbolic fake tensor (#89038) (#89143)" (#89346)
This reverts commit 8e4c9828f4.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89346
Approved by: https://github.com/wconstab
2022-11-19 21:14:31 +00:00
PyTorch MergeBot
8e4c9828f4 Revert "Reland "Towards unifying symbolic and non symbolic fake tensor (#89038)" (#89143)"
This reverts commit e686b8c3ba.

Reverted https://github.com/pytorch/pytorch/pull/89143 on behalf of https://github.com/ZainRizvi due to This seems to be causing the test_make_fx_symbolic_exhaustive_rad2deg_cpu_float32 and test_make_fx_symbolic_exhaustive_inplace_rad2deg_cpu_float32 test to fail across multiple jobs
2022-11-17 17:02:36 +00:00
Edward Z. Yang
e686b8c3ba Reland "Towards unifying symbolic and non symbolic fake tensor (#89038)" (#89143)
This reverts commit cf6003f046.

Differential Revision: [D41363992](https://our.internmc.facebook.com/intern/diff/D41363992)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89143
Approved by: https://github.com/albanD
2022-11-17 13:55:06 +00:00
PyTorch MergeBot
cf6003f046 Revert "Towards unifying symbolic and non symbolic fake tensor (#89038)"
This reverts commit 37d54239c7.

Reverted https://github.com/pytorch/pytorch/pull/89038 on behalf of https://github.com/ezyang due to executorch segfaults
2022-11-16 16:52:47 +00:00
Edward Z. Yang
37d54239c7 Towards unifying symbolic and non symbolic fake tensor (#89038)
Fake tensor behaves pretty differently depending on if you have
symbolic shapes or not.  This leads to bugs; for example, we
weren't getting correct convolution_backward strides because we
bypassed the correct stride logic in fake tensor on symbolic
shapes.

This PR attempts to unify the two codepaths.  I don't manage to
unify everything, but I get most of it.  The algorithm is delicate
and I'm still hosing down test failures.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89038
Approved by: https://github.com/anjali411
2022-11-16 14:02:43 +00:00
Sherlock Huang
ce2f8700ba Symintify numel(), infer_size, prims.elementwise_meta (#88956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88956
Approved by: https://github.com/ezyang
2022-11-16 03:36:00 +00:00
anjali411
b815f1fc50 Symintify view_as_complex and view_as_real (#89052)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #89052
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89052
Approved by: https://github.com/ezyang
2022-11-15 16:28:36 +00:00
Michael Voznesensky
06ce1338bc [dynamo] Port all pytorch/dynamo and test/dynamo pieces over from symbolic-shapes branch (#88768)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88768
Approved by: https://github.com/jansel, https://github.com/ezyang
2022-11-13 04:50:21 +00:00
PyTorch MergeBot
ba4d5aae06 Revert "rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218)"
This reverts commit 7f28be10e5.

Reverted https://github.com/pytorch/pytorch/pull/88218 on behalf of https://github.com/izaitsevfb due to BC-breaking change, D41211901
2022-11-11 19:13:05 +00:00
samdow
7f28be10e5 rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218)
First half of #87990. This doesn't change any of the behavior and is just a rename

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88218
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-11-10 14:51:13 +00:00
Edward Z. Yang
1b5373fc83 Mark as_strided_ as supporting SymInt in C++ (#88674)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88674
Approved by: https://github.com/anjali411
2022-11-08 18:45:05 +00:00
Kurt Mohler
ee28b865ee Deprecate TypedStorage, its derived classes, and all of their public methods (#85303)
Part of #85302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85303
Approved by: https://github.com/ezyang
2022-11-08 18:11:01 +00:00
Edward Z. Yang
825f4e602b Add support for symbolic shapes to sparse tensor (#88573)
Along the way, I undid making sparse/dense dim symint (they're
dimensions, so they should be static.)

Also symintify set_indices_and_values_unsafe

There is a little bit of a nontrivial infra change here: previously, we didn't populate the strides field on sparse tensors. It is now populated with "empty" strides, and this meant that sparse tensors were falsely reporting they were non-overlapping dense/contiguous. I added in a hack to work around this case.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88573
Approved by: https://github.com/anjali411
2022-11-08 03:13:42 +00:00
Edward Z. Yang
23a3eb37cf SymIntify _copy functionalization kernels (and _copy_out too) (#88572)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88572
Approved by: https://github.com/anjali411, https://github.com/bdhirsh
2022-11-07 21:40:10 +00:00
Edward Z. Yang
1e5d33b6df Reenable assert sanity testing with ADInplaceOrView reenable (#88102)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88102
Approved by: https://github.com/albanD
2022-11-01 14:29:00 +00:00
Edward Z. Yang
f2b247f0d8 Remove stale comment (#88135)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88135
Approved by: https://github.com/albanD
2022-10-31 22:29:07 +00:00
Edward Z. Yang
ff94494644 Revert "Revert "Unify meta tensor and fake tensor converter conversion (#87943)"" (#88045)
This reverts commit bc64999b83.

Check torch/_subclasses/meta_utils.py for "This is very tricky" for the bugfix explanation.

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88045
Approved by: https://github.com/kit1980, https://github.com/Chillee
2022-10-31 17:50:14 +00:00
PyTorch MergeBot
bc64999b83 Revert "Unify meta tensor and fake tensor converter conversion (#87943)"
This reverts commit baa715e790.

Reverted https://github.com/pytorch/pytorch/pull/87943 on behalf of https://github.com/kit1980 due to Broke several inductor tests
2022-10-29 18:39:28 +00:00
Edward Z. Yang
baa715e790 Unify meta tensor and fake tensor converter conversion (#87943)
Meta tensor does a lot of work to make sure tensors "look" similar
to the original parts; e.g., if the original was a non-leaf, meta
converter ensures the meta tensor is a non-leaf too.  Fake tensor
destroyed some of these properties when it wraps it in a FakeTensor.

This patch pushes the FakeTensor constructor into the meta converter
itself, so that we first create a fake tensor, and then we do various
convertibility bits to it to make it look right.

The two tricky bits:

- We need to have no_dispatch enabled when we allocate the initial meta
  tensor, or fake tensor gets mad at us for making a meta fake tensor.
  This necessitates the double-callback structure of the callback
  arguments: the meta construction happens *inside* the function so
  it is covered by no_dispatch

- I can't store tensors for the storages anymore, as that will result
  in a leak.  But we have untyped storage now, so I just store untyped
  storages instead.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87943
Approved by: https://github.com/eellison, https://github.com/albanD
2022-10-29 15:01:07 +00:00
Edward Z. Yang
e72962a34d Force people to call from_meta_and_device directly (#87903)
It was pretty hard to tell at call site if I was doing device meta
convert or not.  This gets rid of the "dual" API and forces people
to call the method manually for the device case.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87903
Approved by: https://github.com/eellison, https://github.com/albanD
2022-10-28 21:05:13 +00:00
Elias Ellison
fc21b9db23 Use Eager Code To Determine Conv Layout (#87305)
The logic for determine conv backend and therefore output striding is very complex. It depends on build settings, input striding/contiguity, sizes, etc. Eventually we should port that logic to the meta impl for dynamic shapes but that will require a lot more work and keeping the implementations in sync. See https://github.com/pytorch/torchdynamo/issues/1701

This is a prerequisite to removing the inductor conv stride propagation and more general fake tensor for inductor propagation. In that PR, the meta impls for cpu conv give incorrect striding which led to test failures (https://github.com/pytorch/pytorch/pull/87083).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87305
Approved by: https://github.com/ezyang
2022-10-28 16:37:04 +00:00
Sherlock Huang
e8a97a3721 FakeTensorMode and Prims.add/sub/mul/div support scalar only inputs (#87759)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87759
Approved by: https://github.com/ngimel, https://github.com/mruberry, https://github.com/eellison
2022-10-28 04:34:25 +00:00