Commit Graph

219 Commits

Author SHA1 Message Date
Will Feng
d9dc6b56ec Support using SymInt shapes for torch.baddbmm no-broadcast case (#153112)
A typical `bmm` kernel in Helion needs to pass in symint shapes to `torch.baddbmm`. Currently `self.expand((dim1, dim2, dim3))` in baddbmm runs unconditionally and it doesn't work with symint shapes (it raises the following error):
```
Traceback (most recent call last):
  File "/home/willfeng/local/helion_yf225/helion/_compiler/type_propagation.py", line 699, in propagate_call
    CheckForIndexCalls.retry_call(self.value, proxy_args, proxy_kwargs),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion_yf225/helion/_compiler/tile_index_proxy.py", line 104, in retry_call
    return fn(*proxy_args, **proxy_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch/torch/_subclasses/fake_tensor.py", line 1338, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch/torch/_subclasses/fake_tensor.py", line 1986, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch/torch/_subclasses/fake_tensor.py", line 1450, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch/torch/_subclasses/fake_tensor.py", line 2645, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch/torch/_ops.py", line 806, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch/torch/_prims_common/wrappers.py", line 309, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch/torch/_meta_registrations.py", line 2172, in meta_baddbmm
    self = self.expand((dim1, dim2, dim3))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /home/willfeng/local/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp:5025: SymIntArrayRef expected to contain only concrete integers
```
This PR changes it so that we don't run `expand()` when not necessary, which makes the Helion use case (i.e. no broadcasting) work.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153112
Approved by: https://github.com/jansel
2025-05-08 21:34:24 +00:00
Laith Sakka
38a9a8b7f7 Fix: Consider input defined unbacked during inductor codegen for runtime asserts (#152231)
So when we use mark_unbacked the graph will have an unbacked inputs symInt. Right now,
deferred runtime assertions that uses those  is never generated.

This PR changes that, such that in the forward graph we consider those and generate the corresponding
runtime assertions of them. We still ignore them for backward which is not ideal

The way we generate runtime assertion is by emitting them when all the defined unbacked symbols used
in them are seen.

We previously skipped placeholder, because for backward we have a wacky approach were we
ignore input defined unbacked symbols and assumes assertions that uses them are already emitted
in forward and we try to emit all other runtime assertions again. see [Note [Backwards runtime asserts]

Doing that we ends up only emitting the runtime assertions that depends on things defined solely in backward, but we could miss checks that spans inputs defined in both backward and forward, i.e one symbol defined in forward passed as input to backward., and another that is defined in backward.) .This is not ideal an ideal approach could be something like this https://github.com/pytorch/pytorch/pull/151919 but it require more work .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152231
Approved by: https://github.com/aorenste
2025-05-02 07:01:48 +00:00
Laith Sakka
6ea2e6a2d2 Do not do proper const fold during tensorify_python_scalars (#151494)
Chatting with Bob the goal of this is to const fold the floats that where tensorified by calling
guard_scalar(val) on them and then replacing their usages by their values.
Hence we do not need to do this for nodes with no float symbols.

We do not want todo proper const folding because we need to preserve statements that deferred
runtime asserts depend on. (see the added test)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151494
Approved by: https://github.com/bobrenjc93
2025-04-21 22:39:50 +00:00
Tugsbayasgalan Manlaibaatar
adf5f38eae Don't specialize min/max (#151347)
address https://github.com/pytorch/pytorch/issues/149635
Differential Revision: [D73041489](https://our.internmc.facebook.com/intern/diff/D73041489/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151347
Approved by: https://github.com/bobrenjc93
2025-04-19 00:11:15 +00:00
Laith Sakka
b434322075 Fix has_free_symbols (#151492)
used to fail for
        self.assertFalse(has_free_symbols(sympy.S.true))

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151492
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #151170, #151171
2025-04-18 01:19:01 +00:00
Laith Sakka
0a489f924d Fix: missing () in generated runtime assert c++ code (#151171)
Address one of the issues in https://github.com/pytorch/pytorch/issues/151127
generated code used to be
not a==5 or b==5

should be
not (a==5 or b==5)

address one of the issues in the comments of Address one of the issues in https://github.com/pytorch/pytorch/issues/151127

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151171
Approved by: https://github.com/aorenste, https://github.com/eellison
ghstack dependencies: #151170
2025-04-16 08:10:17 +00:00
Laith Sakka
55595e0c85 Fix Issues in deferring runtime assertions. (#151170)
This PR fix two bugs:
1)  Update self.bound_unbacked_symbols before emitting runtime asserts :
set self.bound_unbacked_symbols before emitting runtime asserts to include runtime asserts depending on the current node

2) In the pass that remove unused graph inputs, we should not remove symbols that are used by runtime assertions.

Address some of the issues in https://github.com/pytorch/pytorch/issues/151127

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151170
Approved by: https://github.com/bobrenjc93, https://github.com/eellison
2025-04-16 08:10:17 +00:00
Laith Sakka
cd80778ac8 Fix issue in optimized_add issue: make_optimized should be called on non args only (#150955)
PR https://github.com/pytorch/pytorch/pull/149665 did a change to the optimized_add that is causing an issue internally.
In general make_optimized should be only be called with valid new_args,  new_args can become None
when elements already exists also, we should break out of the loop in that case.

Note that I also only maintained the optimized summation when both lhs and rhs lengths are <=2.
This is ok because the optimization is based on the inductive property of adding one symbol at a time.
the [2]+[2] here is serving as base case ( i feel we can also remove it ) .

Note that keeping it for all sizes while correct, I am not sure if tis as efficient (we will do N log(n) insertions).
there is no current justification for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150955
Approved by: https://github.com/Mingming-Ding, https://github.com/atalman, https://github.com/bobrenjc93
2025-04-10 03:00:21 +00:00
Laith Sakka
087e8587cd support backed_size_oblivious in guard_or_false/guard_or_true (#150231)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150231
Approved by: https://github.com/pianpwk
2025-04-09 21:47:20 +00:00
Pian Pawakapan
c6d79c163c [dynamic shapes] allow duck typing for 0/1 (#150222)
Fixes #150184

e.g. for config.backed_size_oblivious=True and compile

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150222
Approved by: https://github.com/laithsakka
2025-04-04 03:24:46 +00:00
Laith Sakka
f9f6c080d8 support guard or false/true in user code and add tests (#150178)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150178
Approved by: https://github.com/pianpwk
2025-04-04 01:19:14 +00:00
bobrenjc93
f649ee73ce Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache missing due to PGO. At a high level the problem looks as follows

Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic

Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....

Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this problem by hashing the source names. This ensures somewhat stable assignment. To prevent catastrophic symbol collisions, we use linear probing to ensure no collisions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-28 05:36:32 +00:00
PyTorch MergeBot
af7719a2fa Revert "Use source hashing to generate consistent symbolic ids (#149665)"
This reverts commit 1f92348dc6.

Reverted https://github.com/pytorch/pytorch/pull/149665 on behalf of https://github.com/malfet due to Broke trunk, see 6eb3c2e282/1 ([comment](https://github.com/pytorch/pytorch/pull/149665#issuecomment-2758578187))
2025-03-27 16:02:27 +00:00
bobrenjc93
1f92348dc6 Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache missing due to PGO. At a high level the problem looks as follows

Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic

Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....

Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this problem by hashing the source names. This ensures somewhat stable assignment. To prevent catastrophic symbol collisions, we use linear probing to ensure no collisions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-27 03:39:27 +00:00
Pian Pawakapan
a6459afb0e [dynamic shapes] add backed_size_oblivious option (#148696)
Adds option `torch.fx.experimental._config.backed_size_oblivious = True` to allocate `[0, inf]` instead of `[2, inf]` ranges for size backed symbols, and opting into size-oblivious semantics for them.

Helps in a number of cases like
- Keeps `[0, inf]` bounds for unbacked symbols, when we make a unbacked -> backed replacement
- More sound handling for 0/1 inputs at runtime when we lower from export
- Avoids ends-of-bounds, sys.maxsize constraint violations for exporting with named Dims (https://github.com/pytorch/pytorch/issues/146315, https://github.com/pytorch/pytorch/issues/146046)

May look towards turning this on globally for export.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148696
Approved by: https://github.com/bobrenjc93
2025-03-11 21:52:34 +00:00
Brian Hirsh
621dadd4ca partitioner: when materializing unbacked tensor intermediates, apply hint to symbol, not expr (#144097)
Fixes https://github.com/pytorch/pytorch/issues/144095

open to suggestions: the `hint_int(..., fallback=...)` API feels like a bit of a footgun, because:

(1) we use the same guess for every unbacked symint (both symbols, and compound expressions)
(2) the user may have established some relationship between some unbacked symints that we are not taking into account.

I'm not sure how real of an issue (2) is - is it common to e.g. generate two unbacked symints, and then add a runtime assert that they are unequal?

Instead I did something simpler that's just enough to fix the linked issue: if we have a sympy expression containing an unbacked symbol (e.g. `u0 + 1`), then the partitioner will now fill in the symbol with our guess instead of the expression (plugging in `u0=4096` gets us 4097). This was important for an internal custom op, that had some logic like this:
```
def custom_op(x: [u0], y: [u0 + 1]):
    assert x.shape[0] = y.shape[0] - 1
    ...
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144097
Approved by: https://github.com/laithsakka
2025-03-11 02:11:57 +00:00
Laith Sakka
454fbd5bbe realize stride symbols in estimate_runtime (#146752)
Unfortuanlty could not create a local repo, or unit test.
fix https://github.com/pytorch/pytorch/issues/146686

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146752
Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
2025-02-19 06:02:49 +00:00
bobrenjc93
6f07847efe Bail on checking internal overlap when dealing with unbacked symints (#145385)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145385
Approved by: https://github.com/ezyang
2025-01-23 22:31:31 +00:00
Tom Ritchford
d8c8ba2440 Fix unused Python variables in test/[e-z]* (#136964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964
Approved by: https://github.com/justinchuby, https://github.com/albanD
2024-12-18 23:02:30 +00:00
Laith Sakka
f684dbd002 Try to simplify FloorDiv axioms implications when needed during evaluations. (#141267)
Summary:
This very much the same solution proposed by bobrenjc93 except that it restrict it to expressions and axioms that have FloorDiv, since those are the only ones that could have became CleanDiv. and the only one that can changes as shape env changes.

This also does not break torchrec benchmarks, it might be worth it to know why the generalization of this does break the torchrec benchmarks, but we could just be hitting another bug or NYI situation.

ovearhead?
None on
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=1000
```

Differential Revision: D66307433

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141267
Approved by: https://github.com/ezyang
2024-11-28 15:35:35 +00:00
Isuru Fernando
44186a0a4e Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2024-11-26 18:11:00 +00:00
William Wen
ee7eaad5c3 [dynamo] add SymNode bitwise and/or (#138777)
Fixes [T203472723](https://www.internalfb.com/intern/tasks/?t=203472723)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138777
Approved by: https://github.com/ezyang
2024-11-22 23:36:16 +00:00
PyTorch MergeBot
f23621ec56 Revert "Move Sympy printers to torch/utils/_sympy/printers.py (#140597)"
This reverts commit c25b201583.

Reverted https://github.com/pytorch/pytorch/pull/140597 on behalf of https://github.com/huydhn due to Trunk is sad again after this lands, this looks like a landrace this time, so please do a rebase ([comment](https://github.com/pytorch/pytorch/pull/140597#issuecomment-2494052978))
2024-11-22 15:43:39 +00:00
Isuru Fernando
c25b201583 Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2024-11-22 02:04:36 +00:00
Laith Sakka
e39955e82f Avoid some max constructor optimizations when known not needed. (#139741)
Summary:
around 10% with 1K nodes
more than that with 2K features. 414.5735 -> 333 (20%)

This target optimizing patterns like this
```
 sym_max: "Sym(Max(u31 + u32, u33 + u34))" = torch.sym_max(sym_sum_6, sym_sum_7);  sym_sum_6 = sym_sum_7 = None
        sym_max_1: "Sym(Max(u31 + u32, u33 + u34, u35 + u36))" = torch.sym_max(sym_max, sym_sum_8);  sym_max = sym_sum_8 = None
        sym_max_2: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38))" = torch.sym_max(sym_max_1, sym_sum_9);  sym_max_1 = sym_sum_9 = None
        sym_max_3: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40))" = torch.sym_max(sym_max_2, sym_sum_10);  sym_max_2 = sym_sum_10 = None
        sym_max_4: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42))" = torch.sym_max(sym_max_3, sym_sum_11);  sym_max_3 = sym_sum_11 = None
        sym_max_5: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44))" = torch.sym_max(sym_max_4, sym_sum_12);  sym_max_4 = sym_sum_12 = None
        sym_max_6: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44, u45 + u46))" = torch.sym_max(sym_max_5, sym_sum_13);  sym_max_5 = sym_sum_13 = None
        sym_max_7: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44, u45 + u46, u47 + u48))" = torch.sym_max(sym_max_6, sym_sum_14);  sym_max_6 = sym_sum_14 = None
        sym_max_8: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44, u45 + u46, u47 + u48, u49 + u50))" = torch.sym_max(sym_max_7, sym_sum_15);  sym_max_7 = sym_sum_15 = sym_max_8 = None
```

<img width="496" alt="Screenshot 2024-11-05 at 11 00 35 AM" src="https://github.com/user-attachments/assets/455c06a3-e1bf-43cb-b880-9470ae6fb07f">
<img width="511" alt="Screenshot 2024-11-05 at 11 00 57 AM" src="https://github.com/user-attachments/assets/ff0d4236-9b5c-4a9a-8520-47b005bb3cb0">

Differential Revision: D65354971

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139741
Approved by: https://github.com/ezyang
2024-11-21 16:50:52 +00:00
PyTorch MergeBot
701e06b643 Revert "Move Sympy printers to torch/utils/_sympy/printers.py (#140597)"
This reverts commit aefcdb3c9f.

Reverted https://github.com/pytorch/pytorch/pull/140597 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it fails inductor/test_padding in trunk. This is a target determination miss and that failed test was not run in your PR ([comment](https://github.com/pytorch/pytorch/pull/140597#issuecomment-2489641453))
2024-11-20 22:13:57 +00:00
Isuru Fernando
aefcdb3c9f Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2024-11-20 20:26:49 +00:00
Laith Sakka
8d708090c0 Optimize increment summations [Latest Nov 15] (#140822)
Summary:
**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on).
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form
a+b+c...  (all unique symbols).
which are very common in torchrec models.

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know
the position of (d) and avoid optimizing the new expression by passing the new order.

**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Differential Revision: D66008482

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140822
Approved by: https://github.com/ezyang
2024-11-20 16:48:20 +00:00
PyTorch MergeBot
c1fe6be202 Revert "[dynamo] add SymNode bitwise and/or (#138777)"
This reverts commit c98ef0279e.

Reverted https://github.com/pytorch/pytorch/pull/138777 on behalf of https://github.com/ezyang due to triggering AssertionError: Guard check failed: 14/2: name 'BitwiseFn_bitwise_or' is not defined ([comment](https://github.com/pytorch/pytorch/pull/138777#issuecomment-2477477776))
2024-11-14 21:52:40 +00:00
William Wen
c98ef0279e [dynamo] add SymNode bitwise and/or (#138777)
Fixes [T203472723](https://www.internalfb.com/intern/tasks/?t=203472723)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138777
Approved by: https://github.com/ezyang
2024-11-13 18:31:06 +00:00
Edward Z. Yang
91ded0576d Add sym_log2 (#137980)
Internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1515595595745313/

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137980
Approved by: https://github.com/bobrenjc93
2024-10-28 17:03:14 +00:00
PyTorch MergeBot
2487a834a4 Revert "Add sym_log2 (#137980)"
This reverts commit 5d450d7fac.

Reverted https://github.com/pytorch/pytorch/pull/137980 on behalf of https://github.com/jeanschmidt due to lint broke from this onwards on main ([comment](https://github.com/pytorch/pytorch/pull/137980#issuecomment-2441570186))
2024-10-28 13:21:08 +00:00
Edward Z. Yang
5d450d7fac Add sym_log2 (#137980)
Internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1515595595745313/

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137980
Approved by: https://github.com/bobrenjc93
2024-10-28 03:09:11 +00:00
Laith Sakka
ed313a5ca2 Introduce torch.sym_add, variadic add (#138660)
Tested internally here: https://www.internalfb.com/diff/D64057744
This is a reland after previous internal failures.
main change is
```
 if min is None and max is None:
        torch._check_is_size(size)
        return
```

Partially addresses https://github.com/pytorch/pytorch/issues/128150

When you have big sums of values, we end up computing long chains of
binary addition in our FX graph representation.  Not only is this ugly,
it also is quadratic, as the sympy.Add constructor is O(N) in number
of arguments.  Instead, ensure that we maintain the summation as a
single FX node so we can do the entire addition all in one go.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138660
Approved by: https://github.com/ezyang, https://github.com/bobrenjc93
2024-10-23 17:42:41 +00:00
Edward Z. Yang
d9f4a7d3f9 Simplify find_localzeros (#133325)
Instead of doing an N^2 connected thing, only do simplifications for binary max/min, and for very simple situations.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Differential Revision: [D64135230](https://our.internmc.facebook.com/intern/diff/D64135230)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133325
Approved by: https://github.com/albanD
2024-10-10 00:52:50 +00:00
PyTorch MergeBot
16a2c2cfd4 Revert "Introduce torch.sym_sum (#136429)"
This reverts commit 90bed32b98.

Reverted https://github.com/pytorch/pytorch/pull/136429 on behalf of https://github.com/ezyang due to fails internal stuff ([comment](https://github.com/pytorch/pytorch/pull/136429#issuecomment-2403335147))
2024-10-09 20:08:01 +00:00
Edward Z. Yang
90bed32b98 Introduce torch.sym_sum (#136429)
Partially addresses https://github.com/pytorch/pytorch/issues/128150

When you have big sums of values, we end up computing long chains of
binary addition in our FX graph representation.  Not only is this ugly,
it also is quadratic, as the sympy.Add constructor is O(N) in number
of arguments.  Instead, ensure that we maintain the summation as a
single FX node so we can do the entire addition all in one go.

update_hint_regression benchmark, before and after:

```
update_hint_regression,compile_time_instruction_count,2648328980
update_hint_regression,compile_time_instruction_count,2563748678
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136429
Approved by: https://github.com/isuruf
2024-10-08 18:12:57 +00:00
imShZh
4ab232d0c4 Fix symbolic number's type and tensor's dtype mismatch bug in Tensor ctor (#135433)
Fixes #135432

In the current implementation, if we try to store a symbolic number in Tensor's constructor, it assumes that the tensor's dtype and the symbolic number's type are matched, which is not the case.

In other words, if we try to store a `SymInt`, current implementation assumes tensor's dtype is `torch.int32`, `torch.int64` or something. And if we try to store a `SymFloat`, it assumes tensor's dtype is `torch.float32` or `torch.float64`. However, the tensor's dtype could also be `torch.float32` or something else when we try to store `SymInt`, which would be wrong.

This PR stores symbolic numbers by tensor's scalar type by wrapping `SymInt` and `SymFoat`'s guarded number into a PyObject.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135433
Approved by: https://github.com/ezyang
2024-09-09 19:32:18 +00:00
Edward Z. Yang
6c5669903f Fix Invalid NaN comparison due to infinity-zero multiply on latest sympy (#135044)
Fixes https://github.com/pytorch/pytorch/issues/133735

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135044
Approved by: https://github.com/zou3519
2024-09-04 14:13:09 +00:00
Avik Chaudhuri
b454c51060 remove dynamic_dim (#134211)
Summary: As promised in https://github.com/pytorch/pytorch/pull/134045.

Test Plan: existing

Differential Revision: D61646937

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134211
Approved by: https://github.com/angelayi
2024-08-23 04:13:03 +00:00
Avik Chaudhuri
695d7db2d6 remove dead code for suggesting legacy dynamic shapes fixes (#133700)
Summary: `dynamic_dim` based dynamic shapes are long gone, so pretty-printing suggested fixes for them is dead code.

Test Plan: existing tests

Differential Revision: D61398303

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133700
Approved by: https://github.com/zhxchen17
2024-08-17 01:59:34 +00:00
Xuehai Pan
4226ed1585 [BE] Format uncategorized Python files with ruff format (#132576)
Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #132574
2024-08-04 17:13:31 +00:00
Animesh Jain
ddde9dd25c [dynamo][automatic_dynamic] Trigger dynamism on stride changes (#130232)
Fixes https://github.com/pytorch/pytorch/issues/129798

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130232
Approved by: https://github.com/ezyang
2024-07-21 03:45:54 +00:00
Xuehai Pan
ba48cf6535 [BE][Easy][6/19] enforce style for empty lines in import segments in test/ (#129757)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129757
Approved by: https://github.com/ezyang
2024-07-17 06:42:37 +00:00
Edward Z. Yang
408c921d96 Make hashing a SymInt raise an error again (#130548)
See https://github.com/pytorch/pytorch/issues/130547

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130548
Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/lezcano
2024-07-16 18:30:30 +00:00
PyTorch MergeBot
2b1df24877 Revert "Make hashing a SymInt raise an error again (#130548)"
This reverts commit 3100455b8e.

Reverted https://github.com/pytorch/pytorch/pull/130548 on behalf of https://github.com/clee2000 due to broke inductor/test_triton_kernels.py https://github.com/pytorch/pytorch/actions/runs/9908970127/job/27377960411 3100455b8e. Not run on PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/130548#issuecomment-2225912018))
2024-07-12 16:20:12 +00:00
Edward Z. Yang
3100455b8e Make hashing a SymInt raise an error again (#130548)
See https://github.com/pytorch/pytorch/issues/130547

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130548
Approved by: https://github.com/Skylion007, https://github.com/albanD
2024-07-12 13:49:56 +00:00
Edward Z. Yang
10c831567b Make sympify'ing SymInt/etc produce their sympy expression (#130166)
There is one huge problem this fixes: today, sympify(symint)
produces a float(!!) because Sympy attempts to see if you can
coerce the symint to float in sympify and of course this works on
SymInt.

However, this also has another nontrivial effect: anywhere in Inductor
where sympy expressions are passed around, it is also valid to pass
around a SymInt now.  I'm ambivalent about this: it's currently a
mistake to be passing around a SymInt when a sympy expression is
expected.  But maybe this is fine?

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130166
Approved by: https://github.com/yf225
2024-07-06 03:56:45 +00:00
Edward Z. Yang
35600bcaad Print float with full precision, don't truncate (#130027)
Fixes https://github.com/pytorch/pytorch/issues/119338

Exercised in https://github.com/pytorch/pytorch/pull/118448

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130027
Approved by: https://github.com/lezcano, https://github.com/Skylion007
2024-07-03 17:20:19 +00:00
Edward Z. Yang
d7680a564b Bug fixes for disabling 0/1 specialization on plain int (#129961)
These bug fixes will be exercised in
https://github.com/pytorch/pytorch/pull/128327 but I separate them from
the actual policy change (which is more risky)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129961
Approved by: https://github.com/lezcano
2024-07-02 23:19:48 +00:00