Commit Graph

524 Commits

Author SHA1 Message Date
andrewor14
7b4f70eda5 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-08 15:07:15 +00:00
PyTorch MergeBot
b529c19bdf Revert "Batch Norm Consolidation (#116092)"
This reverts commit 5680f565d5.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/jeffdaily due to broke ROCm, PR signal was clean but trunk was not, the merge should have been blocked but wasn't ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1981373237))
2024-03-06 17:10:01 +00:00
Tugsbayasgalan Manlaibaatar
5680f565d5 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-06 04:50:46 +00:00
Peter Bell
eae9751e82 Fix linalg_eigvals invalid use of composite dispatch key (#121142)
`linalg_eigvals_out` calls into a dispatch stub, so only supports CPU and CUDA
strided tensors but incorrectly claimed to be a composite op. `linalg_eigvals`
also shouldn't defer to the out variant inside a `CompositeImplicitAutograd` op
as not all types support out variants. Instead, I add a new helper
`_linalg_eigvals` which does the same thing in a non-composite operator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121142
Approved by: https://github.com/lezcano
2024-03-05 21:13:27 +00:00
Xu Zhao
7a64eb65e4 Fix Dynamo tests failing with "Failed running call_function <built-in function linalg_norm" (#120993)
When iterating the ord value through an array, we are sharing the same torchdynamo context. This makes dynamo treat the `ord` variable as dynamic shape, causing problems.

In the `vector_norm` decomposition, casting the int type ord to float will fix this problem.

Fixes https://github.com/pytorch/pytorch/issues/119795
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120993
Approved by: https://github.com/lezcano
2024-03-01 20:27:45 +00:00
Edward Z. Yang
f94933ed42 Refine value ranges on inequalities (#120800)
This is basically done the obvious way. For better or worse, I jammed this into what used to be `_maybe_guard_eq` but now is `_maybe_guard_rel`. I was careful to test all the off by one conditions, and each permutation. Let me know if you think I missed anything. Importantly, this now works for unbacked SymInts.

While testing, I noticed we are silently duck sizing all symbolic variables in `test_dynamic_shapes.py`. This may or may not be covering up bugs.

Along the way, I had to fix a bug in export constraints, where we weren't checking that the final var_to_range was consistent with what the user requested at top level.

After I implemented all this, I realized that applying this to non-unbacked SymInts was duplicative with @ysiraichi's previous work on https://github.com/pytorch/pytorch/pull/97963 . The upside is I now understand what Yukio was trying to do in the original PR, and I think my new logic is simpler and less error prone. In Yukio's earlier diff, Yukio tried very hard to avoid changing what guards we actually issue (since this would cause tests to wobble). Thus, when he refined a range, he also saved the guard that actually caused the range to refine. In this PR, I don't bother saving these guards; instead I just tighten var_to_range directly and rely on generating guards on this to be correct. The key insight is that if I assert `x < y`, it's always safe to emit (potentially) more restrictive range guards, because this won't invalidate our guards, it will just make them a little too strong (but actually, I think we are precise along the way.) If these guards make it unnecessary to test `x < y`, because now the ranges for x and y are disjoint, this is fine, we've subsumed the x < y guard and can just not bother testing it. If I've gotten it right, TV will agree with me.

In fact, I had a bug in this PR which TV didn't catch, which is that when we have a recorded var_to_guards for a symbol, we unconditionally never generate the range guard for it, even if the var_to_guards is potentially inconsistent with var_to_range (because var_to_range was updated separately). With var_to_guards removed, I don't have to worry abou this inconsistency.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120800
Approved by: https://github.com/Skylion007, https://github.com/avikchaudhuri, https://github.com/ysiraichi
2024-02-29 19:41:51 +00:00
Isuru Fernando
435063aa89 Decomposition for upsample_linear{1d, 3d} (#114774)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114774
Approved by: https://github.com/lezcano, https://github.com/vfdev-5, https://github.com/peterbell10
2024-02-27 11:57:45 +00:00
Isuru Fernando
b7df3bba62 add decomposition for frexp (#119217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119217
Approved by: https://github.com/peterbell10
ghstack dependencies: #119284, #120027
2024-02-23 21:52:42 +00:00
Isuru Fernando
f7e79299c7 register torch.return_types in torch.fx._pytree (#120027)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120027
Approved by: https://github.com/lezcano, https://github.com/zou3519, https://github.com/XuehaiPan
ghstack dependencies: #119284
2024-02-23 21:52:42 +00:00
Aaron Meurer
5ce305270b Add a decomposition for isin() (#115390)
Co-authored-by: Peter Bell <peterbell10@live.co.uk>
Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115390
Approved by: https://github.com/peterbell10
2024-02-14 03:03:42 +00:00
Sergii Dymchenko
bd9db6a9c7 Update to TorchFix 0.4.0 (#119424)
`torch.library.Library` updated to `torch.library._scoped_library` in files with many tests where it seems obvious to do, otherwise `noqa: TOR901` added - see https://github.com/pytorch/pytorch/pull/118318 for more context.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119424
Approved by: https://github.com/zou3519
2024-02-12 23:30:12 +00:00
Pearu Peterson
2c91e13afc Add lowerings to special functions (#119187)
As in the title.

In addition, the PR introduces infrastructure for lowerings of pointwise functions that have both cpp and triton implementations available.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119187
Approved by: https://github.com/peterbell10
2024-02-11 16:35:40 +00:00
Edward Z. Yang
2349e473f1 Forward fix for same_shape oblivious guard (#119383)
Fixes internal test

```
buck2 test '@fbcode//mode/opt' fbcode//accelerators/workloads/models/slimdsnn:slimdsnn_test -- --exact 'accelerators/workloads/models/slimdsnn:slimdsnn_test - test_generate (accelerators.workloads.models.slimdsnn.test_slimdsnn.SlimDSNN)'
```

And I added an OSS test that approximates the internal situation.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Differential Revision: [D53544208](https://our.internmc.facebook.com/intern/diff/D53544208)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119383
Approved by: https://github.com/atalman, https://github.com/albanD
2024-02-08 02:11:46 +00:00
Elias Ellison
d0ca849fdf Refactor Symint Deduping to separate pass (#118938)
Previously Symint Deduping was done during proxy tracing which made it more difficult to reason about. This refactors the deduping to a separate pass.

We only dedupe symints which are resolvable from input symint nodes so as to avoid inducing a dependency on the backward in the forward.

potential fix for : https://github.com/pytorch/pytorch/issues/118224

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118938
Approved by: https://github.com/ezyang
2024-02-06 23:07:31 +00:00
Isuru Fernando
3e79ef6db8 Complete decomposition for aten.round (#118635)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118635
Approved by: https://github.com/peterbell10
2024-02-01 17:14:44 +00:00
Isuru Fernando
81d12846dc Add decomp for pixel_shuffle/unshuffle (#118239)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118239
Approved by: https://github.com/peterbell10
2024-01-31 18:34:21 +00:00
Edward Z. Yang
220cf46c2a Always accept 0-d scalar tensors as int, even if __index__ fails (#117451)
Fixes https://github.com/pytorch/pytorch/issues/117288

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117451
Approved by: https://github.com/yanboliang
2024-01-14 14:39:46 +00:00
Elias Ellison
e3d4f4d14b [ProxyTensor] dedupe symbolic shapes in tracing (#116158)
Dedupes symbolic shapes in proxy tensor tracing. Reusing the existing sym shape avoids inserting spurious sym_size calls, which can interfere with pattern matching and graph passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116158
Approved by: https://github.com/ezyang
2024-01-11 07:15:11 +00:00
Tugsbayasgalan Manlaibaatar
dfc898ede4 Don't decompose functional ops in predispatch functionalization (#116383)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116383
Approved by: https://github.com/bdhirsh
ghstack dependencies: #115188, #115210
2023-12-28 11:54:04 +00:00
Tugsbayasgalan Manlaibaatar
76b1d44d57 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-25 04:51:21 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
60f4114769 Support nn_module_stack in non_strict mode (#116309)
Summary: Title

Test Plan: CI

Differential Revision: D52382672

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116309
Approved by: https://github.com/zhxchen17
2023-12-23 03:34:58 +00:00
PyTorch MergeBot
ec6c4fed3f Revert "Support nn_module_stack in torch.export(strict=False) (#115454)"
This reverts commit 6730b5bcb4.

Reverted https://github.com/pytorch/pytorch/pull/115454 on behalf of https://github.com/jeanschmidt due to Breaking internal tests recycle_bin_citadel and executorch, check internal diff to see more details ([comment](https://github.com/pytorch/pytorch/pull/115454#issuecomment-1866315233))
2023-12-21 14:05:43 +00:00
PyTorch MergeBot
0567f71ac6 Revert " pre_dispatch aot_export (#115188)"
This reverts commit a267d67350.

Reverted https://github.com/pytorch/pytorch/pull/115188 on behalf of https://github.com/jeanschmidt due to sadly, it is required to revert this commit in order to revert https://github.com/pytorch/pytorch/pull/115454 ([comment](https://github.com/pytorch/pytorch/pull/115188#issuecomment-1866310014))
2023-12-21 14:03:18 +00:00
Tugsbayasgalan Manlaibaatar
a267d67350 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-20 21:36:25 +00:00
Tugsbayasgalan Manlaibaatar
6730b5bcb4 Support nn_module_stack in torch.export(strict=False) (#115454)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115454
Approved by: https://github.com/suo, https://github.com/bdhirsh
2023-12-20 01:43:39 +00:00
Tugsbayasgalan Manlaibaatar
d85314c95c Support Predispatch functionalization (#113728)
In this PR, we are implementing Functionalization on pre-dispatch graph. Today, every dispatch key except for Dispatchkey.Python has a dedicated mode stack in python. PreDispatch tracing relies on this behaviour by pushing ProxyTorchDispatchMode to Dispatchkey.PreDispatch mode stack and handle the dispatching logic in python. To make pre-dispatch functionalization work, we now need to push FunctionalTensorMode on DispatchKey.PreDispatch mode stack and make sure it runs before ProxyTorchDispatchMode. (this is very similar to how post-dispatch tracing work). Here are some design decisions we made for this flow to work:

1. FunctionalTensorMode internally calls C++ functionalize key. Since C++ functionalization goes after PreDispatch, if we are not careful, we will keep re-entering into PreDispatch key. We solve this by directly dispatching to C++ Functionalize key.

2. We delete mode_stack_per_key logic because the only realistic time it is exercised is for PreDispatch and it is in general not safe to have a plain list because FunctionalTensorMode and ProxyTorchDispatchMode ordering matter and it is hard to enforce it on plain list. Instead, now we have a private class that tracks PreDispatch mode stack.

3.  We will still run CompositeImplicitAutograd decomps in this PR, and disable this logic later as a followup.

Some missing bits after this PR:
1. Preserving autograd ops in a functional form. Right now they still show up in the graph but in a "non-functional" way.
2. Turn off CompositeImplicitAutograd decomps
3. Functionalizing HOO

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113728
Approved by: https://github.com/bdhirsh
2023-12-19 20:28:35 +00:00
Joel Schlosser
22704426c3 Expand dynamic dims support for traceable subclasses (#114311)
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).

Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
    * Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
    * Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
    * Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
    * Signatures now:
    ```python
    # attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
    # ctx is anything useful for rebuilding the class we want to guard on
    attrs, ctx = x.__tensor_flatten__()
    ...
    # inner_tensors is a dict of {attr -> tensor}
    # ctx is taken unmodified from flattening and (eventually) guarded on
    # outer_size is the expected size of the output; possibly symbolic
    # outer_stride is the expected strides of the output; possibly symbolic
    y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)

    # at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
    # the assert simplifies symbols when there are relationships between outer and inner symbols
    ```
    * Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
    * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
    * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
2023-12-05 21:09:25 +00:00
Edward Z. Yang
8c4812be80 Replace expect_int with guard_int (#113921)
The idea is that instead of erroring, we will just specialize at these sites.

Fixes https://github.com/pytorch/pytorch/issues/113142

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113921
Approved by: https://github.com/zou3519
2023-11-20 21:27:48 +00:00
rzou
af51c948ac Add mechanism for make_fx to not error on data-dependent-ops (#114129)
I'm looking for a make_fx(tracing_mode=real) that doesn't error out on
data-dependent operations. This PR adds a flag to do that. We use this
to help implement offline generation, but this is useful by itself:
sometimes we want to trace a function with real tensors and don't care
if we bake values in (because we just want to see what happened).

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114129
Approved by: https://github.com/ezyang
ghstack dependencies: #114128
2023-11-20 20:55:55 +00:00
Edward Z. Yang
fdaddec2c3 make_fx can now SymIntify int inputs (#113452)
This PR also contains a basket of fixes that were turned up by now testing more arguments with SymInt. I fixed as many of the easy ones as I could easily get earlier in this stack and a bunch here, but there are some more annoying ones I xfailed.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113452
Approved by: https://github.com/Chillee
ghstack dependencies: #113877, #113911
2023-11-18 06:39:09 +00:00
Edward Z. Yang
4979f9c0d7 [EASY] Support SymInt tracing on broadcast_shapes (#113877)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113877
Approved by: https://github.com/Skylion007
2023-11-17 04:43:57 +00:00
Edward Z. Yang
dfa9e7b511 Allow inferring divisibility on unbacked SymInts and do replacement trick (#113165)
We want something like torch.empty(i0, 12).view(4, -1, 12) to work.  Right now, it chokes on guards on data dependent accesses. It turns out we are very close to having it work based on experiments in https://github.com/pytorch/pytorch/issues/112347 if we do the replacement trick, setting i0 = i1 * 4 to explicitly encode in the divisibility; this is good enough for Sympy to be able to handle the rest.

There are two parts to this PR.

* First, we must discover that there is this divisibility constraint. The place where this happens on view is in `infer_size`; however, we are unable to discover the modulus test with `expect_true` because the condition is currently written with a Python boolean operator that forces guarding too early: `numel == newsize or (dim is not None and newsize > 0 and numel % newsize == 0)`. We rewrite this into an equivalent version which tests on dim being None or not first, before performing individual tests. The main nontrivial reasoning here is that I must show that my set of tests in the `dim is None` branch are sufficient when `numel == newsize`. However, if `numel == newsize`, then the modulus must pass. Thus this is equivalent.
* Given the modifications to `infer_size`, this suffices to produce a runtime assert `Eq(Mod(192*i0, 2304), 0)`. Now we must simply turn this into the replacement automatically. I wasn't really sure how to use Sympy to do this for me, so I just manually pattern matched for this particular expression form, and if it exists do the replacements.

Note that this is kind of only useful for export, because inductor chokes on views involving unbacked SymInts. That will be follow up.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113165
Approved by: https://github.com/lezcano, https://github.com/aakhundov
2023-11-10 21:28:02 +00:00
Edward Z. Yang
f49b8e9313 Register SymInt-aware meta function for mm out, symintify resize (#113202)
Fixes https://github.com/pytorch/pytorch/issues/112489

Fixes https://github.com/pytorch/pytorch/issues/112494

New OpInfo tests for out variants added, since these were not exercised previously.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113202
Approved by: https://github.com/albanD
2023-11-10 14:27:05 +00:00
Peter Bell
65ecb36621 Move ShapeEnv config out of dynamo (#112933)
Previously there was a circular dependency between fx and dynamo that happened
to work out since ShapeEnv didn't access the config at module init time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112933
Approved by: https://github.com/ezyang
2023-11-07 01:10:25 +00:00
PaliC
542fa4a2e7 Revert "Revert "Use OpOverload instead of OpOverloadPacket for size/s… (#113058)
Revert "Revert "Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)""

This reverts commit a1d1b73a7c.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113058
Approved by: https://github.com/izaitsevfb
2023-11-06 19:38:49 +00:00
PyTorch MergeBot
a1d1b73a7c Revert "Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)"
This reverts commit 2337d8d062.

Reverted https://github.com/pytorch/pytorch/pull/112119 on behalf of https://github.com/PaliC due to still breaking trt tests :( refer to diff ([comment](https://github.com/pytorch/pytorch/pull/112119#issuecomment-1795496395))
2023-11-06 17:01:50 +00:00
Edward Z. Yang
dfb26d5999 Reland "Symintify repeat_interleave (#109133)" (#112726)
This reverts commit 08dbfecdbd.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112726
Approved by: https://github.com/albanD
2023-11-04 05:15:55 +00:00
Edward Z. Yang
2337d8d062 Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112119
Approved by: https://github.com/yanboliang
2023-11-03 13:54:41 +00:00
PyTorch MergeBot
25e17f3522 Revert "Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)"
This reverts commit dd24e92949.

Reverted https://github.com/pytorch/pytorch/pull/112119 on behalf of https://github.com/ZainRizvi due to Breaking internal tests. See D50912326 ([comment](https://github.com/pytorch/pytorch/pull/112119#issuecomment-1791072363))
2023-11-02 16:32:25 +00:00
Edward Z. Yang
a1ab22b81d Reland "Trigger specialization when you call size()/stride() from C++ (#111935)" (#112605)
This reverts commit 22221c6d60.

Differential Revision: [D50886564](https://our.internmc.facebook.com/intern/diff/D50886564)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112605
Approved by: https://github.com/voznesenskym
2023-11-02 13:27:31 +00:00
Edward Z. Yang
258874888b Refine replacements with equality tests on runtime asserts (#112156)
Just poppin' off some TODOs.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112156
Approved by: https://github.com/albanD, https://github.com/aakhundov
ghstack dependencies: #112155
2023-11-01 23:02:17 +00:00
Edward Z. Yang
793c62b79c Allow binary pointwise operations to cause refinement on unbacked SymInts (#112155)
To do this, there is a little detour to remove hint caching for unbacked
SymInts; now, we just always attempt to update the hint (using
maybe_evaluate_static; this is much better than the replace we were
doing before) if we don't think we know it.

With this change, we now can generally infer that i0 == 1 is false for
a size-like unbacked SymInt.  So if we write the size match /
broadcasting test very carefully (see comment), we will eventually
end up expect_true(sizeA == sizeB), which is good enough to cause
refinement.  Phew!

I think I still want to setup a replacement if you do i0 == s0, but I'm
going to do that in a follow up.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112155
Approved by: https://github.com/aakhundov, https://github.com/voznesenskym
2023-11-01 23:02:17 +00:00
Edward Z. Yang
dd24e92949 Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112119
Approved by: https://github.com/yanboliang
2023-11-01 18:26:01 +00:00
Elias Ellison
08dbfecdbd Revert "Symintify repeat_interleave (#109133)" (#112245)
This reverts commit 41e5d410cf.

Differential Revision: [D50804696](https://our.internmc.facebook.com/intern/diff/D50804696)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112245
Approved by: https://github.com/eellison
2023-10-31 03:50:26 +00:00
PyTorch MergeBot
22221c6d60 Revert "Trigger specialization when you call size()/stride() from C++ (#111935)"
This reverts commit 5846705e36.

Reverted https://github.com/pytorch/pytorch/pull/111935 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/111935#issuecomment-1782107024))
2023-10-27 00:23:03 +00:00
lezcano
47ccf04885 Split SymNode into its own file (#112037)
This PR:

- Moves TrueDiv, LShift, RShift, IsNonOverlappingAndDenseIndicator to `_sympy.functions.py`
- Moves SymNode to `fx.experimental.sym_node`.
  - This file does not have any SymPy dependencies at import time
  - It installs the magic methods in Sym{Bool,Int,Float}.
  - N.b. With this split, we may be able to move Sym{Bool,Int,Float} to this file, and remove quite a few of the hacks around these classes
- Imports `sym_node` in `torch/__init__.py` rather than the whole `symbolic_shapes.py`.
  This breaks the import-time dependency between torch and SymPy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112037
Approved by: https://github.com/peterbell10
ghstack dependencies: #112035, #112036
2023-10-26 23:32:27 +00:00
Edward Z. Yang
5846705e36 Trigger specialization when you call size()/stride() from C++ (#111935)
This should be the last of the "it used to work with static shapes but
it doesn't work with dynamic shapes" hard errors.  Now we will just
specialize if you hit it from C++.

The strategy here is a bit clever.  We shunt the size() call to Python
binding if an error would have occurred.  Importantly, we already have
logic to make sure the newly allocated ints stay live for the duration
of the ArrayRef access.

storage_offset is intentionally omitted because there are some problems
with it.  I will fix them next.

This should let us get rid of the aotautograd_static test configuration.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111935
Approved by: https://github.com/zou3519
2023-10-25 16:17:55 +00:00
soulitzer
786c51d626 Symintify torch.diff (#111530)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111530
Approved by: https://github.com/bdhirsh, https://github.com/ezyang
ghstack dependencies: #111529
2023-10-19 20:38:57 +00:00
Edward Z. Yang
40c44c2307 Force specialization on INT_LIST (#111216)
Follow up on https://github.com/pytorch/pytorch/pull/95479

Fixes https://github.com/pytorch/pytorch/issues/111198

Fixes https://github.com/pytorch/pytorch/issues/111197

Fixes https://github.com/pytorch/pytorch/issues/111188

Fixes https://github.com/pytorch/pytorch/issues/111201

Fixes https://github.com/pytorch/pytorch/issues/111202

I can also do this for some other types, will do this stacked on top.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111216
Approved by: https://github.com/voznesenskym
2023-10-19 12:55:18 +00:00
Tugsbayasgalan Manlaibaatar
5614023f5e Move export.constrain_as_* to torch._constrain_as_* (#110757)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110757
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #109859
2023-10-12 05:37:44 +00:00
PyTorch MergeBot
6ce3a38050 Revert "Move export.constrain_as_* to torch._constrain_as_* (#110757)"
This reverts commit 5aee22e0e0.

Reverted https://github.com/pytorch/pytorch/pull/110757 on behalf of https://github.com/kit1980 due to Depends on https://github.com/pytorch/pytorch/pull/109859 that needs to be reverted ([comment](https://github.com/pytorch/pytorch/pull/110757#issuecomment-1758908371))
2023-10-12 04:53:29 +00:00
Tugsbayasgalan Manlaibaatar
5aee22e0e0 Move export.constrain_as_* to torch._constrain_as_* (#110757)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110757
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #109859
2023-10-11 02:37:55 +00:00
Edward Z. Yang
24bf9aeb6b Fix arange with dynamic end argument. (#110979)
Fixes https://github.com/pytorch/pytorch/issues/93468

There's a few extra tests that are sort of unrelated, but I ended up writing them while working on the fix and decided to keep them. The big idea here is to split the `_check` so that `expect_true` works; I could have probably also improved the symbolic reasoning but I'm lazy. One small logging fix too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110979
Approved by: https://github.com/Skylion007
2023-10-11 00:32:34 +00:00
Edward Z. Yang
f7c9ef88f5 Add masked_select abstract impl (#110103)
Fixes https://github.com/pytorch/pytorch/issues/109871

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110103
Approved by: https://github.com/bdhirsh
2023-09-27 04:07:58 +00:00
Edward Z. Yang
b07bebd4bd Add default arguments to sym_constrain_range_for_size (#109858)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109858
Approved by: https://github.com/williamwen42
2023-09-26 00:35:33 +00:00
Edward Z. Yang
09622d8d49 Allow inferring size-nature from sizes passed to empty constructor (#109720)
This removes the need for many constrain_as_size calls as we now
infer them from error checking for sizes.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109720
Approved by: https://github.com/aakhundov
2023-09-21 17:57:40 +00:00
Edward Z. Yang
2c1554a032 Make SymFloat behave symmetrically with float in torch.tensor (#109513)
Previously, SymFloat would force double precision.  That's wrong;
instead, we must respect default dtype.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109513
Approved by: https://github.com/voznesenskym
2023-09-19 01:52:41 +00:00
Jez Ng
7f3885137f Add meta function for _segment_reduce (#109359)
This fixes numerous tests which were xfailing. For instance, the
`_segment_reduce.lengths` OpInfo test, which was previously relying on
the fallback kernel to determine the shape of the meta tensor. The
fallback kernel would fail with

    segment_reduce(): Expected all rows of lengths along axis to sum to data.size(lengths.dim()-1) when !unsafe.

as it was trying to read the values of a meta tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109359
Approved by: https://github.com/ezyang
2023-09-16 13:31:03 +00:00
Guilherme Leobas
49e3d76684 Add SymInt support to torch.take_along_dim (#108879)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108879
Approved by: https://github.com/Skylion007, https://github.com/lezcano, https://github.com/Chillee
2023-09-13 23:13:09 +00:00
Edward Z. Yang
55f956f1d2 optests improvements based on torchvision usage on nms (#108929)
- Update cross-ref FakeMode test to use ShapeEnv.  Dynamic ops can now
  return an unbacked SymInt.  We always accept this as equal to whatever
  the real value was.
- Relax test so it works on all classes, not just unittest.TestCase
- Properly wrap the original method, so things like
  pytree.mark.parametrize are carried over
- Support dynamic shapes by default for make_fx `tracing_mode="fake"` without symbolifying everything else

Fixes https://github.com/pytorch/pytorch/issues/108927

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108929
Approved by: https://github.com/zou3519
2023-09-13 13:26:15 +00:00
Peter Bell
464f9c3725 [meta] Add meta implementation for aten.masked_scatter (#108802)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108802
Approved by: https://github.com/lezcano
2023-09-12 16:16:05 +00:00
Edward Z. Yang
2b9ad3d5c4 Fix setitem with SymInt (#108873)
Fixes https://github.com/pytorch/pytorch/issues/101939

Several fixes bundled together:

1. When we valueToTensor, we only handled non-symbolic inputs and not symbolic inputs. We support symbolic Scalar, so also handle symbolic values.
2. In the symbolic case, we MUST NOT lift_fresh, as you're not going to inline a constant into the graph, it's going to be from a `scalar_tensor` call (so no need to clone it to avoid mutations)
3. In indexing scalarToTensor, must not do the static, directly read out the scalar contents logic with the scalar is symbolic

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108873
Approved by: https://github.com/jansel
2023-09-10 06:44:22 +00:00
Edward Z. Yang
9b83402666 Add support for symbolic repeat_interleave (#108763)
Fixes https://github.com/pytorch/pytorch/issues/108195

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108763
Approved by: https://github.com/Chillee
2023-09-08 16:48:32 +00:00
David Watson
598babf017 Added normal op decomposition for specializations of the normal op (#106792)
This fixes running normal with the meta key.

```
import torch

t = torch.tensor(4.0, device='meta')
torch.normal(0.5, t)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106792
Approved by: https://github.com/lezcano
2023-08-25 16:18:28 +00:00
Edward Z. Yang
5673c0874c Use expect_true to make split with unbacked sizes work. (#106788)
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106788
Approved by: https://github.com/lezcano
ghstack dependencies: #106720
2023-08-15 20:31:30 +00:00
Tugsbayasgalan Manlaibaatar
20c5add133 [export] Refactor constrain_as_value and constrain_as_size (#106591)
Some notable changes:
1. `constrain_as_size` allows min value to be less than 2 as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add additional check to make sure max value is always greater than 2.
2. Previously, we used to runtime assert on the unbacked symint's val range which would be always between [2, max]. I modified this logic to assert on [0, max] unless user explicitly specifies the min range.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
2023-08-15 05:41:43 +00:00
Nikita Karetnikov
e7a3fb13e7 [pt2] add Python metas for special ops (#106683)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106683
Approved by: https://github.com/ezyang
2023-08-13 14:12:21 +00:00
Sam Larsen
e165938853 Implement decomposition for aten.rrelu_with_noise (#106812)
Test Plan:
* Primarily, added new test in test/test_decomp.py
* Updated existing tests, e.g., to NOT expect failure

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106812
Approved by: https://github.com/eellison
2023-08-11 19:18:29 +00:00
PyTorch MergeBot
745d29b0cc Revert "[export] Refactor constrain_as_value and constrain_as_size (#106591)"
This reverts commit 18989890bf.

Reverted https://github.com/pytorch/pytorch/pull/106591 on behalf of https://github.com/izaitsevfb due to Breaks inductor test on trunk ([comment](https://github.com/pytorch/pytorch/pull/106591#issuecomment-1675069091))
2023-08-11 16:37:47 +00:00
Tugsbayasgalan Manlaibaatar
18989890bf [export] Refactor constrain_as_value and constrain_as_size (#106591)
Some notable changes:
1. `constrain_as_size` allows min value to be less than 2 as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add additional check to make sure max value is always greater than 2.
2. Previously, we used to runtime assert on the unbacked symint's val range which would be always between [2, max]. I modified this logic to assert on [0, max] unless user explicitly specifies the min range.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
2023-08-11 05:29:22 +00:00
David Berard
393e9eed90 [inductor] modify index_reduce to pass opinfo tests (#106429)
1. add a python meta registration, to fix an issue with the forward pass. The problem was that previously, the C++ meta registration calls [numel()](7b14a14e27/aten/src/ATen/native/TensorAdvancedIndexing.cpp (L329)) which fails (LMK if it's better to fix the C++ implementation to not do this check)
2. Modify the backward to fix an issue in the backward. The backward is not a custom op - it's a custom manual backward implementation. In particular, there's some situations that don't support double backward; the check for whether double backward is allowed requires a .item() call. To fix the meta/fake tensor case, this PR will avoid setting the double backward error only if `GradMode::is_enabled()` - which shouldn't be turned on in PT2.
3. Update skips.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106429
Approved by: https://github.com/zou3519
2023-08-10 18:14:00 +00:00
Nikita Karetnikov
467a2e63f0 [pt2] add Python meta for triangular_solve (#106682)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106682
Approved by: https://github.com/ezyang
2023-08-09 18:50:54 +00:00
Peter Bell
ab6efb1649 [pt2] Add reference implementations of torch.{stft,istft} (#106400)
This allows symbolic shapes to be traced through `torch.stft` and `torch.istft`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106400
Approved by: https://github.com/lezcano
ghstack dependencies: #106319
2023-08-07 20:59:30 +00:00
Peter Bell
d4d090e2da [FakeTensor] Workaround FFT ops with incorrect meta strides (#106319)
Currently there are FFT operators which raise `UnsupportedOperatorException`
because their meta implementations sometimes give incorrect strides. This works
around the problem for static shapes by falling back to eager. Though we still
don't support calls with dynamic shapes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106319
Approved by: https://github.com/ezyang
2023-08-07 20:59:30 +00:00
Nikita Karetnikov
7215007f01 [pt2] add Python meta for polygamma (#106681)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106681
Approved by: https://github.com/ezyang
2023-08-07 00:59:14 +00:00
Nikita Karetnikov
19621a73c0 [pt2] add metas for grid_sampler_3d ops (#106261)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106261
Approved by: https://github.com/ezyang
2023-08-05 14:48:11 +00:00
Edward Z. Yang
91afefb55b Fix some fake mode confusion between inner/outer fake mode in export (#106515)
Fixes https://github.com/pytorch/pytorch/issues/106412

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106515
Approved by: https://github.com/voznesenskym, https://github.com/BowenBao, https://github.com/thiagocrepaldi
2023-08-04 15:42:23 +00:00
Nikita Karetnikov
1f734e03df [pt2] add metas for mode ops (#106273)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106273
Approved by: https://github.com/ezyang
ghstack dependencies: #106272
2023-08-03 13:11:10 +00:00
Nikita Karetnikov
70469e6f04 [pt2] add metas for median ops (#106272)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106272
Approved by: https://github.com/ezyang
2023-08-03 13:11:10 +00:00
Nikita Karetnikov
f23d755e1f [pt2] add meta for ormqr (#106278)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106278
Approved by: https://github.com/ezyang
2023-08-01 06:47:48 +00:00
Nikita Karetnikov
0ee3b84021 [pt2] add meta for cholesky_inverse (#106120)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106120
Approved by: https://github.com/ezyang
2023-07-29 17:16:20 +00:00
Nikita Karetnikov
80755884be [pt2] add meta for cholesky (#106115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106115
Approved by: https://github.com/Skylion007, https://github.com/ezyang
2023-07-29 17:16:20 +00:00
lezcano
36ae359655 Update matmul decomp to match eager (#105850)
The decomposition was not updated after https://github.com/pytorch/pytorch/pull/95261

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105850
Approved by: https://github.com/Chillee
2023-07-26 09:24:51 +00:00
Nikita Karetnikov
a4cffaae67 [pt2] add metas for _cholesky_solve_helper and cholesky_solve (#105867)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105867
Approved by: https://github.com/ezyang
2023-07-25 20:21:47 +00:00
Yukio Siraichi
0cd51b3df0 Reland: Value range refinement using multi-variate expressions (#105491)
Trying to re-land: #97964.

Test strategy:

```
buck2 test '@fbcode//mode/dev-nosan' fbcode//pye/model_inventory/inside_out_tracking_model:inside_out_tracking_model_test -- --exact 'pye/model_inventory/inside_out_tracking_model:inside_out_tracking_model_test - test_executorch_e2e_output_consistency_aten (pye.model_inventory.inside_out_tracking_model.InsideOutTrackingModelTest.InsideOutTrackingModelTest)'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105491
Approved by: https://github.com/ezyang
2023-07-20 02:38:39 +00:00
Edward Z. Yang
1152e86da1 Transmute refined SymInt into int (#104828)
Previously, x.size(0) could return a SymInt, even when the internal
sympy expression was actually already constant (e.g., due to an
introduced guard.)  We now allow to query the Python object with
maybe_as_int which allows us to transmute these objects back to
int when possible.

It is still possible to end up with a constant SymInt even after this
change, e.g., if you get out a SymInt and while holding onto it
specialize it, but casual users are more likely to get ints when they
want to.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104828
Approved by: https://github.com/Skylion007
2023-07-15 18:46:10 +00:00
PyTorch MergeBot
1c69f363c4 Revert "Transmute refined SymInt into int (#104828)"
This reverts commit 0f322a300e.

Reverted https://github.com/pytorch/pytorch/pull/104828 on behalf of https://github.com/ezyang due to executorch failure ([comment](https://github.com/pytorch/pytorch/pull/104828#issuecomment-1635997559))
2023-07-14 15:08:11 +00:00
Yukio Siraichi
8e01f75b1b New Mod class for SymPy expressions. (#104968)
This PR introduces a new `Mod` class to be used with SymPy expressions. The main reason
being due to SymPy simplification errors (#97792).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104968
Approved by: https://github.com/ezyang
2023-07-14 13:34:52 +00:00
Edward Z. Yang
0f322a300e Transmute refined SymInt into int (#104828)
Previously, x.size(0) could return a SymInt, even when the internal
sympy expression was actually already constant (e.g., due to an
introduced guard.)  We now allow to query the Python object with
maybe_as_int which allows us to transmute these objects back to
int when possible.

It is still possible to end up with a constant SymInt even after this
change, e.g., if you get out a SymInt and while holding onto it
specialize it, but casual users are more likely to get ints when they
want to.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104828
Approved by: https://github.com/Skylion007
2023-07-13 07:02:52 +00:00
PyTorch MergeBot
06a5df8d31 Revert "Transmute refined SymInt into int (#104828)"
This reverts commit 4694f54356.

Reverted https://github.com/pytorch/pytorch/pull/104828 on behalf of https://github.com/ezyang due to broke inductor ([comment](https://github.com/pytorch/pytorch/pull/104828#issuecomment-1633049980))
2023-07-12 18:57:58 +00:00
Edward Z. Yang
4694f54356 Transmute refined SymInt into int (#104828)
Previously, x.size(0) could return a SymInt, even when the internal
sympy expression was actually already constant (e.g., due to an
introduced guard.)  We now allow to query the Python object with
maybe_as_int which allows us to transmute these objects back to
int when possible.

It is still possible to end up with a constant SymInt even after this
change, e.g., if you get out a SymInt and while holding onto it
specialize it, but casual users are more likely to get ints when they
want to.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104828
Approved by: https://github.com/Skylion007
2023-07-12 16:40:21 +00:00
Yukio Siraichi
40b8d10d5e Re-land: Turn translation validation on for tests and accuracy runs by default. (#104467)
Re-landing: #103611

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104467
Approved by: https://github.com/malfet
2023-07-05 19:01:50 +00:00
Nikita Karetnikov
c00dd43e43 [pt2] add metas for multilabel_margin_loss ops (#104388)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104388
Approved by: https://github.com/ezyang
2023-07-05 13:42:22 +00:00
Nikita Karetnikov
a3aa4da154 [pt2] add metas for multi_margin_loss ops (#104236)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104236
Approved by: https://github.com/ezyang
2023-07-05 13:40:05 +00:00
Nikita Karetnikov
ad58aba932 [pt2] add metas for adaptive_max_pool ops (#104167)
Fixes #103892.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104167
Approved by: https://github.com/ezyang
2023-07-05 07:02:07 +00:00
Nikita Karetnikov
b1c31b1d26 [pt2] metas and SymInt support for max_pool ops (#103951)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103951
Approved by: https://github.com/Chillee, https://github.com/kulinseth
2023-07-01 01:33:35 +00:00
Nikita Karetnikov
c4a6f86062 [pt2] add metas for max_unpool2d and max_unpool3d (#103821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103821
Approved by: https://github.com/Skylion007, https://github.com/Chillee
2023-07-01 01:33:35 +00:00
PyTorch MergeBot
4de1ee6ba4 Revert "Value range refinement using multi-variate expressions. (#97964)"
This reverts commit 2642412207.

Reverted https://github.com/pytorch/pytorch/pull/97964 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but it is breaking an internal test ([comment](https://github.com/pytorch/pytorch/pull/97964#issuecomment-1615194524))
2023-06-30 21:08:05 +00:00
PyTorch MergeBot
a2a8b4d415 Revert "Turn translation validation on for tests and accuracy runs by default. (#103611)"
This reverts commit e311bed2a8.

Reverted https://github.com/pytorch/pytorch/pull/103611 on behalf of https://github.com/malfet due to Broke inductor tests ([comment](https://github.com/pytorch/pytorch/pull/103611#issuecomment-1614850276))
2023-06-30 15:54:18 +00:00
Yukio Siraichi
2642412207 Value range refinement using multi-variate expressions. (#97964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97964
Approved by: https://github.com/ezyang
2023-06-30 01:32:22 +00:00