Commit Graph

214 Commits

Author SHA1 Message Date
PyTorch MergeBot
c88c0e6c65 Revert "[Dynamo] Handle torch function subclass/mode dispatch on generic tensor methods (#137119)"
This reverts commit d255b34c0a.

Reverted https://github.com/pytorch/pytorch/pull/137119 on behalf of https://github.com/malfet due to Need to revert to be able to revert https://github.com/pytorch/pytorch/pull/136910 ([comment](https://github.com/pytorch/pytorch/pull/137119#issuecomment-2400401262))
2024-10-08 17:09:26 +00:00
Michael Lazos
d255b34c0a [Dynamo] Handle torch function subclass/mode dispatch on generic tensor methods (#137119)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137119
Approved by: https://github.com/williamwen42
ghstack dependencies: #137114, #137115, #137116, #137117, #137120, #137227
2024-10-07 18:55:26 +00:00
Michael Lazos
14eabd6915 [Dynamo] Handle extracted unbound tensor methods (#137227)
fixes2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137227
Approved by: https://github.com/williamwen42
ghstack dependencies: #137114, #137115, #137116, #137117, #137120
2024-10-07 18:55:26 +00:00
Michael Lazos
941be418d8 [Dynamo] Ensure torch function modes are dispatched on builtin ops (#137117)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137117
Approved by: https://github.com/yanboliang, https://github.com/williamwen42
ghstack dependencies: #137114, #137115, #137116
2024-10-07 18:55:26 +00:00
Salman Mohammadi
48c18ff850 [dynamo] Added support for tensor's is_inference method (#136450)
Fixes #135439

This PR adds support for the `is_inference` method on torch tensors which successfully compiles the following example fn without graph breaks:
```python
def fn_simple(x):
    if x.is_inference():
        return x.sum()
    else:
        return x.min()
```

I've also tried to add guards on the tensor to guard against  `is_inference`. I wasn't 100% sure where these should go so please don't hesitate to correct me.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136450
Approved by: https://github.com/ezyang
2024-09-27 09:15:07 +00:00
PyTorch MergeBot
9223c16208 Revert "Fix constant propagation in builtins and UserClasses (#131354)"
This reverts commit dd4a51b39a.

Reverted https://github.com/pytorch/pytorch/pull/131354 on behalf of https://github.com/atalman due to Breaks torchrec tests ([comment](https://github.com/pytorch/pytorch/pull/131354#issuecomment-2375417145))
2024-09-25 23:01:03 +00:00
Tom Ritchford
dd4a51b39a Fix constant propagation in builtins and UserClasses (#131354)
* Fixes https://github.com/pytorch/pytorch/issues/118675
* Replaces https://github.com/pytorch/pytorch/pull/118994

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131354
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-09-25 13:03:40 +00:00
Will Feng
a815611db9 [Traceable FSDP2][Partitioner] Must save AC output if output has a backward hook (#135727)
If node is AC region output and has a backward hook on it, we intentionally choose to save it.
This is to work around circular dependencies in Traceable FSDP2+AC.
Example:
```
out = fully_shard(utils.checkpoint(module))(x)
norm_out = layer_norm(out)
```
and there is a circular dependency:
1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`.
2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) in order to be recomputed.
3. `out`'s FSDP2 backward hook, as is the case for all eager backward hooks, depends on `out_grad`  -> circular dependency with (1)!

Solution: check whether `out` has a backward hook, and if so, intentionally save `out` in forward graph outputs. With this, we can break the above circular dependency.

----

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135727
Approved by: https://github.com/Chillee
2024-09-14 08:45:58 +00:00
Tom Ritchford
2c99f17a32 Implement VariableTracker.python_type() (#134215)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134215
Approved by: https://github.com/amjames, https://github.com/jansel
2024-09-05 16:35:47 +00:00
Michael Lazos
4e71418566 [dynamo] rewrite addcmul_ to remove graph break (#134168)
Context: Adding support for the beta parameters to be tensors

Details: Similarly to the previous two PRs addcmul_ is used with the tensor betas as the value argument. When this occurs, an item() call is invoked in the aten op. To avoid this graph break, addcmul_ is decomposed into its constrituent ops to avoid this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134168
Approved by: https://github.com/anijain2305
ghstack dependencies: #134166, #134167
2024-08-31 10:24:39 +00:00
Oguz Ulgen
6e79932543 Add basic mypy annotations to dynamo (#132415)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132415
Approved by: https://github.com/XuehaiPan, https://github.com/jamesjwu
2024-08-04 18:43:36 +00:00
PyTorch MergeBot
3558a8cf4a Revert "Add basic mypy annotations to dynamo (#132415)"
This reverts commit 71e22e0959.

Reverted https://github.com/pytorch/pytorch/pull/132415 on behalf of https://github.com/ZainRizvi due to Sorry, this PR has entered a weird state in the diff train. Trying to revert it to skip it, and then we can try relanding it ([comment](https://github.com/pytorch/pytorch/pull/132415#issuecomment-2267631785))
2024-08-04 18:39:29 +00:00
Chen Haifeng
50ed6ce277 Support built-in id function for TensorVariable on parameters (#130100)
Fixes #130087

This patch tries to provide a built-in id function implementation for TensorVariable when the id function is called on tensors like module parameters. The id function call on intermediate tensors is not supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130100
Approved by: https://github.com/anijain2305
2024-08-02 01:19:25 +00:00
Oguz Ulgen
71e22e0959 Add basic mypy annotations to dynamo (#132415)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132415
Approved by: https://github.com/XuehaiPan, https://github.com/jamesjwu
2024-08-01 20:14:25 +00:00
Yiming Zhou
ee09d066d3 [dynamo] Add line number to _warn_capture_scalar_outputs() (#132333)
Fixes #127667.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132333
Approved by: https://github.com/anijain2305
2024-08-01 16:11:21 +00:00
Xuehai Pan
e74ba1b34a [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129767
Approved by: https://github.com/anijain2305
2024-07-31 21:18:11 +00:00
Brian Hirsh
8bb9aa93a7 dynamo: mutations on .data should be invisible to autograd (#131403)
Fixes https://github.com/pytorch/pytorch/issues/121353

our handle for `.data` in dynamo today basically just converts `y = x.data` into `y = x.detach()`. The semantics of these two ops are not quite the same, because:

(1) any future mutations on `x.data` will be fully ignored by autograd
(2) any mutations on `x.detach()` will bump x's version counter

the linked model does a .data mutation that is hidden from autograd in eager, but ends up erroring during AOTDispatcher tracing.

I updated dynamo's handling so that:

(1) when dynamo sees a call to `getattr(tensor, "data")` and calls `.detach()` we set a flag on the returned `TensorVariable` indicating it came from `.data`

(2) on any tensor method that we call with an input `TensorVariable` with this flag turned on, we proxy autograd's `preserve_version_counter` logic into the graph, to properly reset the VC after the op is run.

One thing to note is that I don't actually do this on every op that we pass the tensor to: I only do it for tensor methods that appear to be mutations (by checking for a trailing underscore). My thought was that:

(1) I didn't want to do this for **every** op that you pass `y` into, since that will e.g. triple the number of nodes in the graph, and could cause compile time regressions if you use .data

(2) this situation is pretty rare in general, and I'm hoping that "tensor method mutations" cover most reasonable mutation cases. If we manage to miss a case, you will get a loud error during tracing anyway, so there is not a safety issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131403
Approved by: https://github.com/anijain2305, https://github.com/zou3519
2024-07-26 14:22:20 +00:00
Oguz Ulgen
7a42470bcb Annotate all InstructionTranslator (#131509)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131509
Approved by: https://github.com/zou3519
2024-07-24 23:45:53 +00:00
PyTorch MergeBot
5db5865614 Revert "Annotate all InstructionTranslator (#131509)"
This reverts commit eafbd20f23.

Reverted https://github.com/pytorch/pytorch/pull/131509 on behalf of https://github.com/clee2000 due to sorry need to revert this to revert something else, I think you only need to rebase and remerge ([comment](https://github.com/pytorch/pytorch/pull/131509#issuecomment-2249000843))
2024-07-24 22:29:49 +00:00
Oguz Ulgen
b56939dae1 Annotate more InstructionTranslator (#131680)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131680
Approved by: https://github.com/zou3519
ghstack dependencies: #131676
2024-07-24 22:14:29 +00:00
Oguz Ulgen
eafbd20f23 Annotate all InstructionTranslator (#131509)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131509
Approved by: https://github.com/zou3519
2024-07-24 05:31:01 +00:00
Michael Lazos
1b72cf0b09 Add hasattr for tensor variable (#131008)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131008
Approved by: https://github.com/anijain2305
ghstack dependencies: #131007
2024-07-19 12:43:27 +00:00
Michael Lazos
22388ffe03 Graph break on tostring for numpy remapping (#131007)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131007
Approved by: https://github.com/williamwen42
2024-07-18 17:23:41 +00:00
rzou
6ce0bd7d3b [HOP] Use user directed names for variables where possible (#130271)
Afaict the previous check was too strict. Removing it passes all the
mutation tests (mutation checks happen via the TensorVariable's mutable_local).

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130271
Approved by: https://github.com/Chillee, https://github.com/ydwu4
2024-07-10 13:59:20 +00:00
PyTorch MergeBot
3be4922a9d Revert "[HOP] Use user directed names for variables where possible (#130271)"
This reverts commit adb65682af.

Reverted https://github.com/pytorch/pytorch/pull/130271 on behalf of https://github.com/clee2000 due to broke inductor/test_flex_attention https://github.com/pytorch/pytorch/actions/runs/9863205414/job/27236960046 adb65682af Test not run on PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/130271#issuecomment-2218832643))
2024-07-09 22:24:39 +00:00
rzou
adb65682af [HOP] Use user directed names for variables where possible (#130271)
Afaict the previous check was too strict. Removing it passes all the
mutation tests (mutation checks happen via the TensorVariable's mutable_local).

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130271
Approved by: https://github.com/Chillee, https://github.com/ydwu4
ghstack dependencies: #130255, #130268
2024-07-09 19:42:52 +00:00
Yanbo Liang
551f3b92b2 [Dynamo] Add assertion for tensor unpack shape mismatch (#130077)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130077
Approved by: https://github.com/Chillee
2024-07-04 09:25:08 +00:00
William Wen
79aabaf626 [3.13, dynamo] codegen PUSH_NULL when callable is codegen'd (#129172)
Significant bytecode generation API change!

The new suggested convention to generating bytecode to call a function is now to wrap instructions that push a callable to the stack with `add_push_null`, then that callable is called with `create_call_function` with `push_null=False` (see diff for examples).

In Python 3.13, NULL is now expected to be pushed after the callable. In <=3.12, the NULL was pushed before the callable.  This change abstracts away the exact placement of the NULL, but the developer must be aware that a NULL may be needed when codegen'ing a callable.

This abstraction also reduces the need for the `push_null=True` option in `create_call_function`, which removes the need to rotate a NULL to the right place on the stack with a sequence of `SWAP` instructions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129172
Approved by: https://github.com/jansel
2024-06-22 17:25:23 +00:00
PyTorch MergeBot
48a54146e7 Revert "[dynamo] Support ndarray.dtype attribute access (#124490)"
This reverts commit 4adee71155.

Reverted https://github.com/pytorch/pytorch/pull/124490 on behalf of https://github.com/atalman due to Breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/124490#issuecomment-2152664749))
2024-06-06 14:21:29 +00:00
Andrew M. James
4adee71155 [dynamo] Support ndarray.dtype attribute access (#124490)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124490
Approved by: https://github.com/lezcano
ghstack dependencies: #125717
2024-06-05 17:20:01 +00:00
Animesh Jain
84f8cd22ac [dynamo][TensorVariable] Support "if param.grad_fn" usecase (#126960)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126960
Approved by: https://github.com/jansel
ghstack dependencies: #126922
2024-05-25 01:09:26 +00:00
Aart Bik
ff82e2e7cf [traced-graph][sparse] propagate sparsity metadata into traced graph (#117907)
Propagate sparsity metadata from sparse tensors of torch.sparse into the traced graph representation (with would be useful for a JIT backend that supports a "sparse compiler"). This is a first careful attempt, since the actual "meta" feature seem still incomplete for coo and completely lacking for csr/csc/bsr/bsc.

For background see forum postings (with examples):
  https://discuss.pytorch.org/t/connecting-pytorch-sparse-tensors-with-mlir/195145
  https://dev-discuss.pytorch.org/t/connecting-pytorch-sparse-tensors-with-mlir/1803

And feature request:
  https://github.com/pytorch/pytorch/issues/117188

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117907
Approved by: https://github.com/pearu, https://github.com/ezyang
2024-05-23 22:46:46 +00:00
Matthew Hoffman
81277baa0c Remove removed ruff rule TRY200 (#126256)
My TOML linter is complaining that "TRY200" is not acceptable for the `tool.ruff.lint` schema.

From the ruff docs: https://docs.astral.sh/ruff/rules/reraise-no-cause/

> This rule has been removed and its documentation is only available for historical reasons.
>
> This rule is identical to [B904](https://docs.astral.sh/ruff/rules/raise-without-from-inside-except/) which should be used instead.

and we are currently explicitly ignoring B904.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126256
Approved by: https://github.com/Skylion007
2024-05-17 16:31:05 +00:00
Edward Z. Yang
9c9d0c2fab Add VariableTracker.debug_repr (#126299)
Now you can print arbitrary values at compile time with
comptime.print()

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126299
Approved by: https://github.com/jansel
ghstack dependencies: #126292
2024-05-15 23:55:29 +00:00
Edward Z. Yang
2ba102f689 Implement native support for float inputs in Dynamo and ShapeEnv (#125325)
The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is controlled by `specialize_float` config variable when set to False.

The generated graph looks like this for the test `test_unspec_float_output`:

```
 def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
     l_x_ = L_x_
     l_y_ = L_y_

     # File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
     add: "f32[3]" = l_x_ + 1;  l_x_ = None
     item: "Sym(zf0)" = l_y_.item();  l_y_ = None
     mul: "Sym(2*zf0)" = item * 2;  item = None
     scalar_tensor: "f32[]" = torch.scalar_tensor(mul);  mul = None
     return (add, scalar_tensor)
```

The ingredients:

* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up into the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we will do the rest of the tracing with.  Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We now can generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols, and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs to the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable on the return stack for a float, we on the fly convert it (via `as_tensor`) to a TensorVariable, which is the true output. We then special case the output bytecode to call item() on it again. The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125325
Approved by: https://github.com/lezcano, https://github.com/jansel
2024-05-14 04:10:01 +00:00
Edward Z. Yang
e93b57a570 Add propagate_real_tensors mode for unbacked (#125115)
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or workaround it to see the next one.

This PR adds a new mode `torch._functorch.config.fake_tensor_propagate_real_tensors` which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning which you should consult to figure out what the crux points are.

I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensor and some other Dynamo missing features.) You get a list of things that would have errored like this:

```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```

Potential later follow ups:

* Improve the warning messages (in particular, should provide user frames)
* GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, equal to as if your GC was broken and every intermediate tensor was kept live

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125115
Approved by: https://github.com/IvanKobzarev
2024-05-02 15:28:26 +00:00
Avik Chaudhuri
746da8755c switch tests from constrain_as* to torch._check* (#125253)
To fix data-dependent errors we want to recommend that people use `torch._check*` APIs. The `constrain_as*` APIs should be fully subsumed by them, and in the future we should kill them entirely.

Differential Revision: D56774333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125253
Approved by: https://github.com/ezyang
2024-05-01 21:01:27 +00:00
YangQun1
91d565da0c [dynamo] Add support for tensor's is_complex method (#124927)
This PR is to add support for tensor's is_complex method in dynamo. Take the following code as an example:
```python
   def test_tensor_is_complex(x):
        if x.is_complex():
            return x + 1
        else:
            return x - 1
```
Before this fix, the is_complex() call will cause a graph break "torch.* op returned non-Tensor bool call_method is_complex". After this fix, the graph break can be avoided.

Fixes #122692

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124927
Approved by: https://github.com/ezyang
2024-04-26 18:28:14 +00:00
Edward Z. Yang
bebdbb63ce Introduce set_example_value and use it throughout Dynamo (#124176)
I'm going to setup some extra behavior when we set example value, so
I need a convenient place to interpose.  I cannot easily do it on
meta itself because its a generic dict with no interposition point.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124176
Approved by: https://github.com/oulgen
ghstack dependencies: #124105, #124059
2024-04-17 22:57:11 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Jason Ansel
f3fd280238 [dynamo] Relax strict_mode for autograd.Function forward inputs (#123910)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123910
Approved by: https://github.com/oulgen
2024-04-13 19:41:59 +00:00
Jason Ansel
70b8c58f84 [dynamo] Emit warning to turn on capture_scalar_outputs (#123896)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123896
Approved by: https://github.com/anijain2305
ghstack dependencies: #123700, #123705, #123786, #123790, #123803, #123804
2024-04-12 19:03:13 +00:00
Brian Hirsh
09be5800c8 dynamo: support placement kwargs for DTensor.to_local() (#119947)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119947
Approved by: https://github.com/wanchaol, https://github.com/yoyoyocmu
ghstack dependencies: #118803
2024-03-22 14:42:27 +00:00
Jason Ansel
477d154ffd [dynamo] Add missing _nonvar_fields annotations (#122219)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122219
Approved by: https://github.com/anijain2305
ghstack dependencies: #122218
2024-03-20 07:53:18 +00:00
Jason Ansel
153a01833b [dynamo] Optimize SourcelessBuilder (#122063)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 2.7s to 2.5s.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122063
Approved by: https://github.com/anijain2305
ghstack dependencies: #122039, #122043, #122055, #122058, #122060
2024-03-19 04:23:30 +00:00
Jason Ansel
8082adcf65 [dynamo] Only rename a proxy once (#122060)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 3.9s to 2.7s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122060
Approved by: https://github.com/oulgen
ghstack dependencies: #122039, #122043, #122055, #122058
2024-03-19 04:23:27 +00:00
Jason Ansel
2bec55c5f9 [dynamo] Remove VariableTracker.parents_tracker (#122058)
This is leftover from mutable variable tracker days and no longer needed.

Improves benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py
from 4.2s to 3.9s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122058
Approved by: https://github.com/oulgen, https://github.com/anijain2305
ghstack dependencies: #122039, #122043, #122055
2024-03-19 04:23:24 +00:00
Jason Ansel
769ff86b91 [dynamo] Optimize COMPARE_OP (#122039)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 5.6 to 5.1s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122039
Approved by: https://github.com/Skylion007, https://github.com/anijain2305
2024-03-19 04:23:14 +00:00
Jason Ansel
5d52b163d1 [dynamo] Optimize load/store/const op handling (#122038)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 6.7s to 5.6.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122038
Approved by: https://github.com/Skylion007
ghstack dependencies: #122032, #122033, #122034, #122035
2024-03-18 18:08:06 +00:00
Jason Ansel
6ca0323615 [dynamo] Optimize VariableTracker.__post_init__ (#122034)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 8.6s to 7.3s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122034
Approved by: https://github.com/Skylion007
ghstack dependencies: #122032, #122033
2024-03-18 18:08:06 +00:00