For tracing cond/while in eager, we trace the HOP with the eager backend with metadata torchfunction mode enabled. HOPs disallow the mutation that occurs in this torch function mode, so it is not able to be traced. As a result, we use a custom backend which enters this mode for tracing these HOPs. Thanks to @ydwu4 for the help with implementing this
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732
Approved by: https://github.com/ydwu4
All of the previous benchmarks are similar, ListOfLinears should be representative enough.
I copied the previous benchmarks from unit tests without an intention, was just trying to create a large
number of benchmarks to better observe noise.
This PR keeps only one, we can add more as we see value and regressions in the future.
Also this diff adds a GPU version.
```
collecting compile time instruction count for basic_modules_ListOfLinears_eager
compile time instruction count for iteration 0 is 6479525851
compile time instruction count for iteration 1 is 1024432680
compile time instruction count for iteration 2 is 1019417317
compile time instruction count for iteration 3 is 1013603566
compile time instruction count for iteration 4 is 1008853980
compile time instruction count for iteration 5 is 1009541481
compile time instruction count for iteration 6 is 1005025533
compile time instruction count for iteration 7 is 1004116323
compile time instruction count for iteration 8 is 1000828633
compile time instruction count for iteration 9 is 999788323
collecting compile time instruction count for basic_modules_ListOfLinears_inductor
compile time instruction count for iteration 0 is 40837529730
compile time instruction count for iteration 1 is 18411921909
compile time instruction count for iteration 2 is 18383665161
compile time instruction count for iteration 3 is 18348983522
compile time instruction count for iteration 4 is 18349276590
compile time instruction count for iteration 5 is 18353046274
compile time instruction count for iteration 6 is 18346818581
compile time instruction count for iteration 7 is 18340057998
compile time instruction count for iteration 8 is 18331267320
compile time instruction count for iteration 9 is 18328381338
collecting compile time instruction count for basic_modules_ListOfLinears_inductor_gpu
compile time instruction count for iteration 0 is 15408870979
compile time instruction count for iteration 1 is 10949520859
compile time instruction count for iteration 2 is 11058786167
compile time instruction count for iteration 3 is 11003606719
compile time instruction count for iteration 4 is 10896406770
compile time instruction count for iteration 5 is 10982875189
compile time instruction count for iteration 6 is 10931848275
compile time instruction count for iteration 7 is 10956345008
compile time instruction count for iteration 8 is 11045384499
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135730
Approved by: https://github.com/ezyang, https://github.com/anijain2305
Summary:
Move towards consolidating strobelight profiler implementations between OSS and fbcode. This change is a first step towards that.
- Created a new function to abstract out compile time profiling enablement. This function allows profiler to switch between different function profilers (e.g. Thrift based or CLI based)
- Both OSS and Fbcode now use one compile time profiler in torch/_strobelight
Test Plan:
Tested OSS with following commands:
```
python torch/_strobelight/examples/compile_time_profile_example.py
python torch/_strobelight/examples/cli_function_profiler_example.py
TORCH_COMPILE_STROBELIGHT=TRUE TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python benchmarks/dynamo/huggingface.py --ci --accuracy --timing --explain --inductor --device cuda --training --amp --only XLNetLMHeadModel
```
See test commands for fbcode in comments.
Differential Revision: D62444551
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135953
Approved by: https://github.com/laithsakka
If node is AC region output and has a backward hook on it, we intentionally choose to save it.
This is to work around circular dependencies in Traceable FSDP2+AC.
Example:
```
out = fully_shard(utils.checkpoint(module))(x)
norm_out = layer_norm(out)
```
and there is a circular dependency:
1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`.
2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) in order to be recomputed.
3. `out`'s FSDP2 backward hook, as is the case for all eager backward hooks, depends on `out_grad` -> circular dependency with (1)!
Solution: check whether `out` has a backward hook, and if so, intentionally save `out` in forward graph outputs. With this, we can break the above circular dependency.
----
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135727
Approved by: https://github.com/Chillee
During enablement of Traceable FSDP2 on internal models, sometimes the user only applies torch.compile to some of the FSDP2 instances but not all of them. Such mixed usage pattern is not supported by compiled autograd. Here we try to catch and throw error at such usage pattern, so that the user can fix the usage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135824
Approved by: https://github.com/awgu
When we measure compile time instruction count, probably we do want in most cases to measure gc instructions
disabling it here by default.
if it is needed we can add an option to allow it, or someone can use the regular total instruction count instead of compile time instruction count.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135768
Approved by: https://github.com/ezyang, https://github.com/anijain2305
https://github.com/pytorch/pytorch/pull/133012 caused a regression on ROCm causing pointwise scan tests to fail
```
ERROR: test_pointwise_associative_scan_tuple_reverse_True_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_tuple_reverse_False_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_complex_pytree_reverse_True_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_complex_pytree_reverse_False_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_binary_operator_reverse_True_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_binary_operator_reverse_False_combine_mode_pointwise_cuda
```
Skipping temporarily while triage is underway.
Full log: https://ossci-raw-job-status.s3.amazonaws.com/log/30067645445
```
File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1020, in call_function
out = lowerings[target](*args, **kwargs) # type: ignore[index]
File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 363, in wrapped
out = decomp_fn(*args, **kwargs)
File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 6245, in associative_scan
raise RuntimeError("Unable to generate code for associative_scan op")
torch._inductor.exc.LoweringException: RuntimeError: Unable to generate code for associative_scan op
```
NOTE: even "eager" backend fails
```
File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_higher_order_ops/associative_scan.py", line 338, in associative_scan_op_dense
raise NotImplementedError("associative_scan is not implemented for eager")
NotImplementedError: associative_scan is not implemented for eager
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135995
Approved by: https://github.com/malfet
This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode)
Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call
resume fn structure:
1. enter context
2. jump
...
3. exit context
The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack).
So for torch function modes the structure of our output code is this:
1. graph call
2. mutate tf mode stack to replay mutations
4. unsupported code
5. on exception restore stack
6. resume function
Then our resume fn looks like this:
1. no-op enter torch function mode
2. jump
3. exit tf mode
To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).
Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444
In preparation for tracing through DeviceContext (defb515306/torch/utils/_device.py (L66))
This PR adds support for calling the setattr of thread local objects. These objects have a slots impl, and since this doesn't appear to have any side effects, we call this setattr impl when replaying mutations, since calling `object.__setattr__` on these objects results in a type error.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135443
Approved by: https://github.com/anijain2305
ghstack dependencies: #134732, #133137
For tracing cond/while in eager, we trace the HOP with the eager backend with metadata torchfunction mode enabled. HOPs disallow the mutation that occurs in this torch function mode, so it is not able to be traced. As a result, we use a custom backend which enters this mode for tracing these HOPs. Thanks to @ydwu4 for the help with implementing this
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732
Approved by: https://github.com/ydwu4
This PR solves two problems with `sum()` support in NJT:
* `sum()` over a dim with `keepdim=True` returns the wrong shape (i.e. it'll keep the wrong dim). This is a long-standing bug from way back in #112519.
* Historically, we've only supported `sum()` over a dim and not a full reduction. This PR adds the full reduction form (forward only, backward still fails).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131945
Approved by: https://github.com/davidberard98, https://github.com/jananisriram
Summary:
Previously we only checked dtype and is_dynamic to decide if two quantization spec are equivalent
this may not work in some cases, e.g. when people use different qscheme or quant_min/quant_max
This PR added checks for other fields as well
Test Plan:
regression tests
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D62530974](https://our.internmc.facebook.com/intern/diff/D62530974)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135736
Approved by: https://github.com/sxu
there was a recent strange noise +5%, -5%.
using only compile time :
1) avoid gc time .
2) avoid other operations that are not what we try to measure by this. ==> less probable noise.
```
collecting compile time instruction count for sum_floordiv_regression
compile time instruction count for iteration 0 is 8899290248
compile time instruction count for iteration 1 is 1188830489
compile time instruction count for iteration 2 is 1180579615
compile time instruction count for iteration 3 is 1176263131
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135785
Approved by: https://github.com/avikchaudhuri, https://github.com/anijain2305