Summary:
I found a confusing bug in the PassManager that only happens
when you instantiate one multiple times: it will use old passes and
constraints!
This occurs because the class-level declarations initialize it to an empty list,
but the problem is that class initializers only run once, and are creating class
variables. This means the same empty list was being reused every time, except
after the first time it isn't empty.
The empty list has to be created in `__init__` newly each time or else it'll be shared.
Note that this is the same type of bug as using an empty list as a default parameter, where
it'll reuse the same list pointer and not make it empty each time.
The better way to do this is with either:
* An immutable default parameter like an empty tuple, that you create a new list from: `self.passes = list(passes)`
* Use None and then create the empty list inside `__init__`
I chose the latter as it's less likely to cause a behavior change due to the changed default.
Note that for immutable values like `False` and `1` this doesn't apply as you can't mutate that
value for everyone.
Test Plan:
Added a test to ensure that the pass state is not saved.
Without my change, this test would fail as it would run all of the `2 * x` passes first,
then all of the `3 * x` passes.
Differential Revision: D41327056
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89108
Approved by: https://github.com/angelayi
Summary:
* Added an error message for when the result is not a PassResult
* Modified the error handling to capture exceptions that happen in the check() function
* consolidated inplace_wrapper and pass_result_wrapper
Test Plan: CI
Differential Revision: D40950135
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88520
Approved by: https://github.com/SherlockNoMad
Summary: Before the change, wrapped_fn should only take mutating passes, but we don't actually have any way to detect whether a pass is mutating before running it. To make this an abstraction without involving any precondition depending on PassManager run, we could just relax the precondition to take any kind of passes, and conditionally return the original pass based on the pass result.
Test Plan: eyes
Reviewed By: qihqi, angelayi
Differential Revision: D39086343
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84232
Approved by: https://github.com/angelayi
Example:
```
======================================================================
ERROR: test_pass_manager_error (fx.test_pass_infra.TestPassManager)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/angelayi/Projects/pytorch/torch/fx/passes/infra/pass_manager.py", line 285, in __call__
res = fn(module)
File "/Users/angelayi/Projects/pytorch/test/fx/test_pass_infra.py", line 164, in pass_fail
raise RuntimeError("bad")
RuntimeError: bad
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/angelayi/Projects/pytorch/test/fx/test_pass_infra.py", line 170, in test_pass_manager_error
pm(traced_m)
File "/Users/angelayi/Projects/pytorch/torch/fx/passes/infra/pass_manager.py", line 289, in __call__
raise RuntimeError(msg) from e
RuntimeError: An error occured when running the 'pass_fail' pass after the following passes: ['replace_add_with_mul_pass', 'replace_mul_with_div_pass']
```
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83933
Approved by: https://github.com/SherlockNoMad
PassManager is a class used to run multiple passes on a given graph module.
Class Attributes
* `passes: List[Callable]`: A list of callable passes
* `constraints: List[Callable]`: A list of constraints
* `run_checks_after_each_pass`: Flag for running checks each pass
Class Methods:
* `__call__(graph_module: DispatchGraphModule)`:
* Runs the passes based on the list of passes until the graph stops changes, or until `steps` number of times.
* Each time a pass is run, it will check that the graph module still maintains the required invariants by calling `check()` and will lint the graph to check that it’s well formed if the flag `run_checks_after_each_pass` is set.
* `check(graph_module: DispatchGraphModule)`: Runs various checks on the given graph module to make sure that it contains the needed data for passes
* `add_check(check: Callable)`: Adds the `check` function to the given pass manager instance
* `add_constraint(constraint: Callable)`: Adds a constraint to the current list of constraints
We can create a PassManager and run it by doing:
```
PassManager(passes=[pass1, pass2])(graph_module)
```
Differential Revision: [D37523159](https://our.internmc.facebook.com/intern/diff/D37523159)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80531
Approved by: https://github.com/SherlockNoMad