Commit Graph

115 Commits

Author SHA1 Message Date
weiyusheng
c3949b20a1 Opt model save and load (#126374)
## save&load support for OptimizedModule

[Issue Description](https://github.com/pytorch/pytorch/pull/101651)

English is not my native language; please excuse typing errors.

This pr is based on commit b9588101c4d3411b107fdc860acfa8a72c642f91\
I'll do something with the merge conflicts later

### test result for test/dynamo

Conclusion:\
It performs the same as before as far as I can see.

ENV(CPU only):\
platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0\
configfile: pytest.ini\
plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0

#### before this pr:

[before](https://github.com/pytorch/pytorch/files/15329370/before.md)

#### after this pr:

[after](https://github.com/pytorch/pytorch/files/15329376/after.md)

### some changes

1. add test_save_and_load to test/dynamo/test_modules.py with & without "backend='inductor'"
2. add \_\_reduce\_\_ function to OptimizedModule and derived classes of _TorchDynamoContext for pickling & unpickling
3. change the wrappers into wrapper classes ( including convert_frame_assert, convert_frame, catch_errors_wrapper in torch/_dynamo/convert_frame.py & wrap_backend_debug in torch/_dynamo/repro/after_dynamo.py )
4. change self.output.compiler_fn into innermost_fn(self.output.compiler_fn) in torch/_dynamo/symbolic_convert.py to get the origin compiler_fn and to avoid the "compiler_fn is not eager" condition

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126374
Approved by: https://github.com/msaroufim, https://github.com/jansel
2024-06-05 13:01:16 +00:00
Yanbo Liang
c1b90a4e8a [Dynamo] Treat integers stored on nn.Modules as dynamic (#126466)
Fixes #115711

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126466
Approved by: https://github.com/jansel
2024-05-21 03:31:20 +00:00
PyTorch MergeBot
71b6459edc Revert "[Dynamo] Treat integers stored on nn.Modules as dynamic (#126466)"
This reverts commit 6bb9d6080d.

Reverted https://github.com/pytorch/pytorch/pull/126466 on behalf of https://github.com/huydhn due to Sorry for reverting your change but the ONNX test failure looks legit, not flaky, as it starts failing in trunk 6bb9d6080d ([comment](https://github.com/pytorch/pytorch/pull/126466#issuecomment-2119078245))
2024-05-19 02:52:11 +00:00
Yanbo Liang
6bb9d6080d [Dynamo] Treat integers stored on nn.Modules as dynamic (#126466)
Fixes #115711

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126466
Approved by: https://github.com/jansel
2024-05-18 05:02:16 +00:00
Animesh Jain
bd63300bae [dynamo][inline-inbuilt-nn-modules] Add and update test_modules.py for nlining work (#126327)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126327
Approved by: https://github.com/williamwen42
ghstack dependencies: #126303, #126316, #126314
2024-05-16 01:35:09 +00:00
Animesh Jain
90461d4986 [dynamo] Detect monkeypatching on nn module forward method (#126203)
An alternative was https://github.com/pytorch/pytorch/pull/124975. Though it was safer because it was adding guards for every inlined function, it was causing guard overhead for a few models of > 20%.  The overhead of this PR is minimal for the common unpatched case.

Fixes an internal issue - [fb.workplace.com/groups/1075192433118967/permalink/1411067766198097](https://fb.workplace.com/groups/1075192433118967/permalink/1411067766198097/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126203
Approved by: https://github.com/ezyang
2024-05-15 20:41:13 +00:00
Animesh Jain
5ba777f46e [guards][cpp-guards] Optimize NN module getattr guards (#124522)
Improves the guard overhead of MobileBert model with nn module guards from 92000 units to 20000 units.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124522
Approved by: https://github.com/jansel
ghstack dependencies: #125439, #125421
2024-05-04 22:08:56 +00:00
Animesh Jain
a13a0a2479 [dynamo][easy] Simple fixes to prepare for nn module guards (#125316)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125316
Approved by: https://github.com/williamwen42
ghstack dependencies: #125275
2024-05-02 12:08:11 +00:00
Animesh Jain
1a0b247762 [dynamo] Bug fix for LOAD_GLOBAL and STORE_GLOBAL (#125002)
Earlier globals of inlined functions from other files were not handled correctly. We were not tracking mutations on them. They were colliding with the same global name in the parent function etc. This PR overrides the LOAD/STORE_GLOBAL for inline tx and tracks mutation on them separately.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125002
Approved by: https://github.com/jansel
ghstack dependencies: #125097, #125107
2024-04-28 15:24:17 +00:00
Will Feng
7a78534468 [Compile FSDP2][1/n] Support using user-defined object instance method as hook (#123399)
FSDP2 has this pattern of using user-defined object instance method as hook, and it will throw this error under compile:
`torch._dynamo.exc.Unsupported: call_function UserDefinedObjectVariable(_pre_forward) [FSDPManagedNNModuleVariable(), TupleVariable(), ConstDictVariable()] {}`

This PR adds support for it by always allowing to trace into a UserDefinedObjectVariable that's an instance method (i.e. `MethodType`).

Supersedes https://github.com/pytorch/pytorch/pull/123320.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123399
Approved by: https://github.com/jansel
2024-04-09 17:29:08 +00:00
Peter Bell
6939279a17 [dynamo] Forward OptimizedModule.__setattr__ to the wrapped module (#122098)
Fixes #114844

In the linked issue we have
```
compiled_module = torch.compile(module)
compiled_module.x = ...
compiled_module(...)  # Mutates self.x
```
Where since the module mutates `self.x` you would expect `compiled_module.x`
to be updated but actually `compiled_module.x = ...` sets an attribute "x"
on the `OptimizedModule` object while the forward method of the module mutates
`module.x`.

This gives the expected behavior by forwarding `compiled_module.__setattr__`
down to `module.__setattr__`. There is already a corresponding `__getattr__`
so now `compiled_module.x` becomes an alias for `module.x`.

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122098
Approved by: https://github.com/ezyang, https://github.com/lezcano
2024-04-01 14:30:44 +00:00
PyTorch MergeBot
f631586084 Revert "[dynamo] Forward OptimizedModule.__setattr__ to the wrapped module (#122098)"
This reverts commit b6982bf2b2.

Reverted https://github.com/pytorch/pytorch/pull/122098 on behalf of https://github.com/atalman due to Failing internally ([comment](https://github.com/pytorch/pytorch/pull/122098#issuecomment-2021233604))
2024-03-26 18:54:17 +00:00
Peter Bell
b6982bf2b2 [dynamo] Forward OptimizedModule.__setattr__ to the wrapped module (#122098)
Fixes #114844

In the linked issue we have
```
compiled_module = torch.compile(module)
compiled_module.x = ...
compiled_module(...)  # Mutates self.x
```
Where since the module mutates `self.x` you would expect `compiled_module.x`
to be updated but actually `compiled_module.x = ...` sets an attribute "x"
on the `OptimizedModule` object while the forward method of the module mutates
`module.x`.

This gives the expected behavior by forwarding `compiled_module.__setattr__`
down to `module.__setattr__`. There is already a corresponding `__getattr__`
so now `compiled_module.x` becomes an alias for `module.x`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122098
Approved by: https://github.com/ezyang, https://github.com/lezcano
2024-03-26 00:52:12 +00:00
PyTorch MergeBot
e5e0685f61 Revert "[dynamo] Forward OptimizedModule.__setattr__ to the wrapped module (#122098)"
This reverts commit 88ebdbc97c.

Reverted https://github.com/pytorch/pytorch/pull/122098 on behalf of https://github.com/huydhn due to Sorry for reverting your change but the distributed failure looks legit as it is also failing in trunk 88ebdbc97c ([comment](https://github.com/pytorch/pytorch/pull/122098#issuecomment-2008483316))
2024-03-20 01:12:24 +00:00
Peter Bell
88ebdbc97c [dynamo] Forward OptimizedModule.__setattr__ to the wrapped module (#122098)
Fixes #114844

In the linked issue we have
```
compiled_module = torch.compile(module)
compiled_module.x = ...
compiled_module(...)  # Mutates self.x
```
Where since the module mutates `self.x` you would expect `compiled_module.x`
to be updated but actually `compiled_module.x = ...` sets an attribute "x"
on the `OptimizedModule` object while the forward method of the module mutates
`module.x`.

This gives the expected behavior by forwarding `compiled_module.__setattr__`
down to `module.__setattr__`. There is already a corresponding `__getattr__`
so now `compiled_module.x` becomes an alias for `module.x`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122098
Approved by: https://github.com/ezyang, https://github.com/lezcano
2024-03-19 16:51:43 +00:00
Animesh Jain
0b11b0edd6 [dynamo][refactor] Use existing helper functions for CLOSURE_MATCH (#120145)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120145
Approved by: https://github.com/jansel, https://github.com/Fidget-Spinner
ghstack dependencies: #120132, #120140
2024-02-18 00:31:36 +00:00
Yanbo Liang
2a63dd8889 [Dynamo] Support lazy module with namedtuple/dict input (#119972)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119972
Approved by: https://github.com/jansel
2024-02-15 23:18:18 +00:00
Yue Dong
915f9db03c [Dynamo] Support kwargs for lazy module (#119445)
Summary:
Seems like `kwargs` is already support in `_infer_argument`, so we don't need the extra assertion `len(kwargs) == 0`.

This optimization ensures compatibility with torch.compile() for LazyModules with kwargs inputs, preventing graph breaks.

Test Plan: Unit tetst and CI

Differential Revision: D53558778

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119445
Approved by: https://github.com/yanboliang
2024-02-09 00:46:41 +00:00
lezcano
eb2bdfae88 Make variables in dict LazyTrackers (not lazily guarded yet) and avoid using DICT_KEYS guard (#117625)
Make variables in dict lazy and remove DICT_KEYS guard.

We build the keys of a dict depth-first and we rely on the guards of
each element in the dict to create the correct guards. This allows us to
remove the rather buggy DICT_KEYS guard and make the guard lazy.
The guards are not completely lazy yet, as we instantiate them in
`_HashableTracker._eq_impl` but it should be possible to make them
truly lazy.

Also, adding new types to the supported types within keys should be less
error prone.

This is marginally less efficient when we graph break, but in turn we
should graph break much less. It also  makes the dicts code easier to maintain
(removes `is_hashable_python_var`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117625
Approved by: https://github.com/jansel, https://github.com/peterbell10, https://github.com/anijain2305
ghstack dependencies: #117982, #118098, #117983
2024-02-02 14:38:08 +00:00
Aaron Gokaslan
bd10fea79a [BE]: Enable F821 and fix bugs (#116579)
Fixes #112371

I tried to fix as many of the bugs as I could, a few I could not figure out what the proper fix for them was though and so I left them with noqas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
2024-01-01 08:40:46 +00:00
Yingxin Kang
199b04fdbd Back out "Implement pass-through state_dict and load_state_dict for dynamo OptimizedModule (#113423)" (#116243)
Summary:
Original commit changeset: 2a9588cfd51b

Original Phabricator Diff: D52062368

Test Plan: In investigating S386328 and S382826, we found checkpoint loading succeed after backout D52062368: S386328_backout_1220_193648

Differential Revision: D52356011

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116243
Approved by: https://github.com/voznesenskym
2023-12-21 17:57:05 +00:00
Adrian Wälchli
38f890341d Implement pass-through state_dict and load_state_dict for dynamo OptimizedModule (#113423)
Fixes #113422
Fixes #94575

This is now possible:
```py
model = Model()
compiled_model = torch.compile(model)

model.load_state_dict(compiled_model.state_dict())  # previously key mismatch!
```

This also makes it much easier to checkpoint and load models that were wrapped like so:
```py
FSDP(torch.compile(model))
# or
DDP(torch.compile(model))
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113423
Approved by: https://github.com/msaroufim
2023-12-10 22:09:19 +00:00
Jason Ansel
4ee80fd7f4 [dynamo] Support UNPACK_SEQUENCE nn.ModuleList (#114959)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114959
Approved by: https://github.com/oulgen, https://github.com/yanboliang
2023-12-01 21:42:23 +00:00
PyTorch MergeBot
92e3f45f0e Revert "[dynamo] Refactor test cross importing (#113242)"
This reverts commit 4309d38f5d.

Reverted https://github.com/pytorch/pytorch/pull/113242 on behalf of https://github.com/huydhn due to Sorry for reverting your stack, but it is failing to list test internally with buck2 ([comment](https://github.com/pytorch/pytorch/pull/113242#issuecomment-1811674395))
2023-11-15 01:53:07 +00:00
Ken Jin
70064ac416 [Dynamo] Match closures by code ID (#109427)
Closes https://github.com/pytorch/pytorch/issues/107866

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109427
Approved by: https://github.com/ezyang, https://github.com/jansel
2023-11-12 08:20:14 +00:00
Jason Ansel
4309d38f5d [dynamo] Refactor test cross importing (#113242)
Having tests import tests is a bit annoying because fbcode/oss have different paths.  This moves that stuff into a helper function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113242
Approved by: https://github.com/yanboliang
2023-11-11 03:17:35 +00:00
PyTorch MergeBot
59592389fc Revert "[dynamo] Refactor test cross importing (#113242)"
This reverts commit 8858edad65.

Reverted https://github.com/pytorch/pytorch/pull/113242 on behalf of https://github.com/PaliC due to this diff appears to be causing inductor failures internally ([comment](https://github.com/pytorch/pytorch/pull/113242#issuecomment-1805132719))
2023-11-10 05:43:08 +00:00
Jason Ansel
8858edad65 [dynamo] Refactor test cross importing (#113242)
Having tests import tests is a bit annoying because fbcode/oss have different paths.  This moves that stuff into a helper function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113242
Approved by: https://github.com/yanboliang
2023-11-09 01:36:27 +00:00
Aaron Gokaslan
9c1fb2cbb3 [BE]: Enable ruff PIE794 and fix bugs it found in test suite (#112989)
Enables some tests that were incorrectly not being run and enables PIE794 globally. This rule checks if a classvar is defined twice as flags it as it is likely a bug. In fact, we found several cases where it was a bug. It does have a couple of false positives which I flagged upstream and replaced with noqas: https://github.com/astral-sh/ruff/issues/8497

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112989
Approved by: https://github.com/malfet
2023-11-05 22:11:53 +00:00
Kazuaki Ishizaki
9089242048 Fix typo under test directory (#112346)
This PR fixes typo in comments and messages under `test` directory. This PR also fixes related typo in messages under `torch` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112346
Approved by: https://github.com/kit1980, https://github.com/ezyang
2023-11-03 07:53:33 +00:00
Jon Chuang
6d78f34a06 fix regression which creates a new fake tensor (#111864)
Fixes regression identified here: ccd6b373b5 (r1369334484)

Now that `get_fake_value` will identify aliases, we should not try to wrap the fake value again.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111864
Approved by: https://github.com/eellison
2023-10-24 05:11:48 +00:00
Jon Chuang
47eed65481 [dynamo] Add is_ support for Tensors, force get_fake_value to reuse previously computed example_value if available (#111565)
Use FakeTensor id match as equivalent to object identity match

cc

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111565
Approved by: https://github.com/ezyang
2023-10-21 13:56:30 +00:00
Michael Voznesensky
1e7947b3e0 Revert "Reland 3rd try [finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#109323)" + Forward fixes + test (#110964)
This reverts commit f786fbdebd.

Forward fixes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110964
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2023-10-11 05:16:47 +00:00
Animesh Jain
ce8b4f56d8 [dynamo] Dont put nn module guards on torch inbuilt nn modules (#110230)
This is one way to fix https://github.com/pytorch/pytorch/issues/110048

Looking for feedback.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110230
Approved by: https://github.com/ezyang
2023-09-29 00:43:16 +00:00
PyTorch MergeBot
559d1f94a0 Revert "[Dynamo][Test] reland testcase with state (#109713)"
This reverts commit 5c897eacff.

Reverted https://github.com/pytorch/pytorch/pull/109713 on behalf of https://github.com/PaliC due to creates a out of memory error for macos tests ([comment](https://github.com/pytorch/pytorch/pull/109713#issuecomment-1728314478))
2023-09-20 19:34:07 +00:00
Kaichao You
5c897eacff [Dynamo][Test] reland testcase with state (#109713)
Reland the PR https://github.com/pytorch/pytorch/pull/108750 reverted by https://github.com/pytorch/pytorch/issues/108838 , since https://github.com/pytorch/pytorch/pull/108969 has been merged.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109713
Approved by: https://github.com/anijain2305
2023-09-20 18:19:18 +00:00
Animesh Jain
f786fbdebd Reland 3rd try [finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#109323)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109323
Approved by: https://github.com/huydhn, https://github.com/voznesenskym
2023-09-15 08:44:14 +00:00
PyTorch MergeBot
92de1d2d02 Revert "[Dynamo][Test]Add a testcase for module with training state (#108750)"
This reverts commit f90444cf0b.

Reverted https://github.com/pytorch/pytorch/pull/108750 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but it starts failing this test https://github.com/pytorch/pytorch/issues/108838 without https://github.com/pytorch/pytorch/pull/108883 and the latter has been reverted ([comment](https://github.com/pytorch/pytorch/pull/108750#issuecomment-1712708800))
2023-09-10 04:45:00 +00:00
PyTorch MergeBot
56c2386157 Revert "reland [finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#108883)"
This reverts commit d4230e5574.

Reverted https://github.com/pytorch/pytorch/pull/108883 on behalf of https://github.com/huydhn due to Per the discussion thread on D49122208, reverting this change ([comment](https://github.com/pytorch/pytorch/pull/108883#issuecomment-1712707853))
2023-09-10 04:40:02 +00:00
Michael Voznesensky
e4350d6d4e Functools partial support in dynamo (#108846)
The strategy for supporting functools partials is relatively straightforward.

There are 2 cases we need to support:

**1) Functools partials as input**
In this case, we are first seeing the functools partial and it is guaranteed to have a source. As such, the args, keywords, and func of the functools partial are passed through VariableBuilder. As this is the first time we are seeing these objects (as it is an input), we re-enter VariableBuilder with a source referencing the args, keywords, and func as attributes of the input to produce:

- func: A callable VariableTracker (UDF, TorchVariable, etc) depending on the value of `func`
- args: List[VariableTracker] - note, not ListVariableTracker!
- keywords: Dict[str, VariableTracker]

A major benefit of this structure is that it very elegantly matches the args to `call_function`.

We then compose a FunctoolsPartialVariable from the VariableTrackers made above.

**2) Functools partials created within compile**
In this case, we already have all the args as known VTs, and thus just compose a FunctoolsPartialVariable as we do for case (1).

For both (1) and (2) - we propagate all guards from the func, args, and keyword VTs to the FunctoolsPartialVariable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108846
Approved by: https://github.com/ezyang, https://github.com/jansel
2023-09-09 17:25:02 +00:00
Animesh Jain
d4230e5574 reland [finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#108883)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108883
Approved by: https://github.com/voznesenskym, https://github.com/huydhn
2023-09-09 03:12:31 +00:00
PyTorch MergeBot
72f24d0001 Revert "[dynamo][finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#108528)"
This reverts commit 34bb74c4cf.

Reverted https://github.com/pytorch/pytorch/pull/108528 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it has some nasty merge conflicts after the revert of D48910794. I need to revert this so the conflict could be resolved. Please help rebase this tomorrow and reland the change ([comment](https://github.com/pytorch/pytorch/pull/108528#issuecomment-1711034781))
2023-09-08 03:49:41 +00:00
youkaichao
f90444cf0b [Dynamo][Test]Add a testcase for module with training state (#108750)
Add the problem mentioned in https://github.com/pytorch/pytorch/issues/105653 into tests. This issue has been addressed by https://github.com/pytorch/pytorch/pull/108528 .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108750
Approved by: https://github.com/anijain2305
2023-09-08 02:39:42 +00:00
Zhengxu Chen
c75aec90d3 [dynamo] Record nn_module_stack also for unspecialized nn modules. (#108281)
Summary: Currently node metadata "nn_module_stack" is only being used by export. For some export model, we still want to retain nn_module_stack for unspecialized module for various purposes. This diff add a path to also record nn_module_stack when unspecialized module has a source available.

Test Plan: test_export_nn_module_stack_patched_module

Differential Revision: D48841193

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108281
Approved by: https://github.com/yanboliang, https://github.com/tugsbayasgalan
2023-09-07 15:38:46 +00:00
Animesh Jain
34bb74c4cf [dynamo][finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#108528)
**This PR is a 99% copy paste of Sam Gross** (@colesbury) work at https://github.com/pytorch/pytorch/pull/100642. Copied from there

--------
The NN_MODULE guard now subsumes guards on Module attributes. The check_fn will fail if the module attributes are changed (such as Module.training), parameters, submodules, and buffers are added or removed, and if fields are changed on the type itself.

This gives up specificity in the guard check -- if any field is changed the check_fn fails -- for faster overall checks.

-----

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108528
Approved by: https://github.com/ezyang
2023-09-07 01:45:47 +00:00
Jason Ansel
6d61d74545 [dynamo] Fix setattr nn.Module with new attribute (#108098)
This is one (but not all) issues in DALLE2_pytorch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108098
Approved by: https://github.com/eellison
ghstack dependencies: #108096, #108087
2023-08-29 02:58:48 +00:00
Animesh Jain
9d2ffc5dfa [reland][Dynamo] cache_size policy #107496 (#108069)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108069
Approved by: https://github.com/yanboliang
2023-08-28 22:06:54 +00:00
PyTorch MergeBot
b4c6c4da88 Revert "[Dynamo] cache_size policy (#107496)"
This reverts commit 4175a6e944.

Reverted https://github.com/pytorch/pytorch/pull/107496 on behalf of https://github.com/ZainRizvi due to Breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/107496#issuecomment-1693590121))
2023-08-25 16:07:14 +00:00
Animesh Jain
4175a6e944 [Dynamo] cache_size policy (#107496)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107496
Approved by: https://github.com/ezyang
ghstack dependencies: #107645
2023-08-24 21:50:00 +00:00
Wanchao Liang
9c2b4a35a3 [dtensor] group all dynamo tests together (#107487)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107487
Approved by: https://github.com/fduwjj
ghstack dependencies: #107472, #107473
2023-08-21 23:56:00 +00:00