Resolves#164972
- #164972
All `torch.utils._cxx_pytree` functions are based on `optree` functions with hardcoded `none_is_leaf=True` and `namespace="torch"`. This PR changes the polyfills to generic `optree` functions with those arguments unhardcoded. This means `torch.utils._cxx_pytree` functions are still traceable while the community `optree` usages can get dynamo support additionally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165860
Approved by: https://github.com/Lucaskabela
Fixes#165911
- Add message to Attribute error so we see ` Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])` instead of just `Developer debug context: raised exception AttributeError([])`
- Add stack trace in `ObservedException` so we display the inner most error stack trace back to user code
Output:
```
/data/users/shangdiy/pytorch/torch/__init__.py:2641: UserWarning: You are calling torch.compile inside torch.export region. To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)
warnings.warn(
Traceback (most recent call last):
File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1385, in var_getattr
subobj = self._getattr_static(name)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1256, in _getattr_static
subobj = type(self.value).__getattribute__(self.value, name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Linear' object has no attribute 'w'
During handling of the above exception, another exception occurred:
torch._dynamo.exc.ObservedAttributeError: 'Linear' object has no attribute 'w'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/shangdiy/pytorch/test.py", line 34, in <module>
mod = torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 481, in inner
out = fullgraph_capture(
^^^^^^^^^^^^^^^^^^
File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1053, in fullgraph_capture
return _fullgraph_capture_frame(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1115, in _fullgraph_capture_frame
raise e.with_traceback(None) from e.__cause__ # User compiler error
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Observed exception
Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html
from user code:
File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 171, in forward
res = self._export_root(*args, **kwargs)
File "/data/users/shangdiy/pytorch/test.py", line 31, in forward
weight = self.linear.w
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165930
Approved by: https://github.com/anijain2305
Reapply of https://github.com/pytorch/pytorch/pull/163260
AOTI utils expect free function sometimes so adjust export API to handle that, haven't seen any methods getting exported. Some AOTI flows also require we populate dynamo_flat_name_to_original_fqn so i just copy how it is done in eval_frame.py. I also cleaned up how we get rid of export_root and fixed some overcomplicated nn_module_stack handling in export code. The logic is simpler now thanks to @anijain2305 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165582
Approved by: https://github.com/anijain2305
This PR enables all PIE rules on ruff, there are already some enabled rules from this family, the new added rules are
```
PIE796 Enum contains duplicate value: {value}
PIE808 Unnecessary start argument in range
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
This PR enables all PIE rules on ruff, there are already some enabled rules from this family, the new added rules are
```
PIE796 Enum contains duplicate value: {value}
PIE808 Unnecessary start argument in range
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
For a custom op
```
@torch.library.custom_op("my_lib::foo", mutates_args={})
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
```
ppl could call `torch.ops.my_lib.foo()` or directly call `foo()` in the `forward` of an `nn.Module`
These two calling conventions will lead to the same node in the output graph, but different stack traces.
When directly calling `foo()`, the displayed stack_trace in the graph will be
```
# File: .../pytorch/torch/_library/custom_ops.py:687 in __call__, code: return self._opoverload(*args, **kwargs)
```
This is not useful so we filter it out.
```
python test/functorch/test_aot_joint_with_descriptors.py -k test_custom_op_stack_trace
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165693
Approved by: https://github.com/SherlockNoMad, https://github.com/williamwen42
Three fixes:
1. When doing t[u0] +=1 if u0 is unbacked we could allocate a new unbacked symbol during the the indexing of t[u0] (when we fake trace setitem), namely because meta_select does allocate a new unbacked symbol for the storage offset when we do not know if u0>=0 or u0<0. but the output size/stride of setitem(), does not depend on that new symbol. it's self consumed in setitem so we shall ignore it.
2. Also when we trace through generalized_scatter the applications of the views could allocate unbacked symints
but those do not effect final output, we also shall ignore them.
3.Before accessing strides in lowering we shall materialize.
Address https://github.com/pytorch/pytorch/issues/114293 and https://github.com/pytorch/pytorch/issues/131911
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164341
Approved by: https://github.com/bobrenjc93
Eager AC/SAC reapplies the mutations (like global dict mutations) in the backward during the recomputation of forward. torch.compile has no easy way to reapply python mutations in the backward. But many users might be ok to skip reapplication of side effects in the backward. They can set this config flag to accept this eager and compile divergence.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165775
Approved by: https://github.com/zou3519
ghstack dependencies: #165734
Summary: This starts writing the compiler_config metadata into logger
Test Plan:
Modified existing test case to make sure this is not null.
(Also eyeballed what we're logging tomake sure it's reasonable
Reviewed By: masnesral
Differential Revision: D84014636
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165581
Approved by: https://github.com/masnesral
Summary:
This stores information on where fx graphs come from, which makes it
significantly easier to debug.
One outstanding question
1) I only stored the kernel stack traces, do we also want the node mappings?
Test Plan:
I wrote a explicit logging test which makes a module, fx traces it, compiles it, and makes sure the logging infomration shows up.
```
clr@devvm17763 ~/fbsource/fbcode/caffe2/test/dynamo
% buck2 test @//mode/opt fbcode//caffe2/test/dynamo:test_dynamo -- test_utils
File changed: fbsource//xplat/caffe2/test/dynamo/test_utils.py
File changed: fbcode//caffe2/test/dynamo/test_utils.py
Buck UI: https://www.internalfb.com/buck2/528dea32-2416-4a62-a1ec-39f3c0efdd2e
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13229324015574003
Network: Up: 0B Down: 0B
Executing actions. Remaining 0/2
Command: test.
Time elapsed: 17.3s
Tests finished: Pass 16. Fail 0. Fatal 0. Skip 0. Build failure 0
```
Rollback Plan:
Differential Revision: D82037582
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162669
Approved by: https://github.com/yushangdi
Not sure what exactly we want to have in the message, but that's easy to adjust. I tried to find a reliable test to reproduce this message (happens only when a guard fails right after it's created), but I ended up mocking a `guard_manager.check` function to return `False` to trigger this behavior. I think that's fine, because any other case that we pick (like datetime.now()), we want to patch one day anyway, so every time we make the next patch, will need to chase for another repro test
@williamwen42
Fixes#164990
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165242
Approved by: https://github.com/williamwen42
Summary: If a function is wrapped with functools, we should not look at the wrapped function signature but rather the wrapper, since we need to construct the frame for the top level function here.
Test Plan: test_decorated_function_with_functools_wrap_aot
Differential Revision: D84626752
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165454
Approved by: https://github.com/yiming0416
Creates the fork/join stream ops. These ops are passthrough ops which mutate all of their args (without actually performing any computation on them) so that during functionalization, implicit dependencies are added on all of their args. This allows us to prevent reordering during our pre/post grad graph passes.
Make custom ops inplace
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162900
Approved by: https://github.com/anijain2305
ghstack dependencies: #163027, #162899, #163028
Stores streams in a global object look table that maps a dynamo selected index to objects. This index is generated during tracing, and at runtime, a helper function is called from the bytecode to populate this map.
This differs from the previous implementation that simply mapped IDs to the associated objects. This required specialization on the IDs of the specific objects, while this new approach does not.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162899
Approved by: https://github.com/anijain2305
ghstack dependencies: #163027
Verify the deterministic mode with torch.compile benchmark scripts.
Here is what my testing script does (pasted in the end):
- run a model in default mode, save it's result
- run the model again in default mode, but distort the benchmarking results. Compare it with the saved result.
- Do the above again in deterministic mode.
I tried to test a few modes
- BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchnmark result in the default mode. The non-determinism is gone in the deterministic mode
- DistillGPT2: I can not repro the numeric change by distorting the benchmarking result in the default mode. It does not surprise me much. Reduction order change does not always cause numeric change.
```
model=GoogleFnet
export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0
export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 # disable autotune cache
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/
export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1
# Non deterministic mode
# --float32 rather than --amp to make it easier to repro non-deterministic
echo "Save results for non-deterministic mode"
python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl
echo "Compare results with distorted benchmarking in non-deterministic mode"
TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl
echo "Save results for deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl
echo "Compare results with distorted benchmarking in deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904
Approved by: https://github.com/jansel, https://github.com/v0i0
https://github.com/pytorch/pytorch/issues/162858 The issue described the feature implemented.
This adds to the existing graph break log with the latest 20 (or viable user frame) bytecode instructions. The scenario is when the graph_break happens without errors. It happens during the case when user calling torch._dynamo.graph_break().
Meanwhile, in the testing, one can find that the generated frame based on step() is not deterministic as sometimes it reached the maximum amount, sometimes it generated the less than that. The bytecode generation is python version dependent. Thus, the testing plan excludes the bytecode output but generated the total bytecode line count.
This is a helpful process to understand bytecode transformation, symbolic convert, and convert frame. It is a helpful task to provide hands-on experience with dynamo workflow.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164422
Approved by: https://github.com/williamwen42, https://github.com/mlazos
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This is follow-up of #165037. It generally recommended to use `is/is not` to compare types. Therefore this series of changes apply this suggestion in the code base, and it aims to finally enabling related linter checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165142
Approved by: https://github.com/albanD
It generally recommended to use `is/is not` to compare types. Therefore this series of changes apply this suggestion in the code base, and it aims to finally enabling related linter checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
Reviewed GPT5 summary:
**Summary / Goal**
Improve error reporting when local_map subgraph input/output counts mismatch placement info.
**Details**
- Adds descriptive runtime error messages.
**Motivation**
Helps debug local_map misalignments.
```python
AssertionError: Expecting 2 inputs to local_map function based on placements, but found 1. If the count matches for eager, Dynamo may have flattened inputs to the function or found additional tensors used via closures. Please adjust the input placements to match what the traced graph sees:
class GraphModule(torch.nn.Module):
def forward(self, l_args_0_: "f32[8, 8, 16]"):
# File: /home/xmfan/core/a/pytorch/test/higher_order_ops/test_local_map.py:523 in mismatch_input, code: return x + scalar, scalar
child: "f32[8, 8, 16]" = l_args_0_ + 10; l_args_0_ = None
return (child,)
.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164321
Approved by: https://github.com/ezyang, https://github.com/mlazos
ghstack dependencies: #164296
Verify the deterministic mode with torch.compile benchmark scripts.
Here is what my testing script does (pasted in the end):
- run a model in default mode, save it's result
- run the model again in default mode, but distort the benchmarking results. Compare it with the saved result.
- Do the above again in deterministic mode.
I tried to test a few modes
- BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchnmark result in the default mode. The non-determinism is gone in the deterministic mode
- DistillGPT2: I can not repro the numeric change by distorting the benchmarking result in the default mode. It does not surprise me much. Reduction order change does not always cause numeric change.
```
model=GoogleFnet
export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0
export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 # disable autotune cache
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/
export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1
# Non deterministic mode
# --float32 rather than --amp to make it easier to repro non-deterministic
echo "Save results for non-deterministic mode"
python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl
echo "Compare results with distorted benchmarking in non-deterministic mode"
TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl
echo "Save results for deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl
echo "Compare results with distorted benchmarking in deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904
Approved by: https://github.com/jansel, https://github.com/v0i0
ghstack dependencies: #164801, #164532
In https://github.com/pytorch/pytorch/pull/106824, export decided to slow-path for MultiHeadAttention module (look into the PR description as to why). But that PR eventually caused a divergence between Dynamo and export.
Today, strict-export does not inline into builtin modules (like MultiHeadAttention), and therefore make_fx sees the original nn.Module and takes the slow path. But compile inlines into the nn module, and at this time the condition `_is_make_fx_tracing` is False. As a result, Dynamo takes a fast path, resulting in a different op being called.
This divergence is undesirable. There are 2 ways to fix it
1) Make export take the fast path - As explained in the https://github.com/pytorch/pytorch/pull/106824 , this might be difficult. So, we go to (2)
2) Make compile as well take the slow path - This is easy to implement. The con here is that Pytorch eager and compile will use different operators, which can cause numerics issues etc.
Since (2) is easy to do, we will follow this path. We are tracking the issue in https://github.com/pytorch/pytorch/issues/164062
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164721
Approved by: https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
Builds on top of https://github.com/pytorch/pytorch/pull/163673 and https://github.com/pytorch/pytorch/pull/164174. This will be used in the followup PRs to apply regional inductor compilation.
The existing implementation let Dynamo trace into the `torch.fx.traceback.annotate`, but thats not what we want. We want Dynamo to essentially run the torch.fx.traceback.annotate function in eager, so that every Fx node created in Dynamo Fx graph has the custom meta node.
What does not work?
* We still have to set the context manager `torch.fx.traceback.preserve_node_meta()` in the user code because CI was unhappy. This can be fixed but with some perseverance.
* This does not work with graph breaks yet. But we can solve that problem, if needed, in a separate PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164678
Approved by: https://github.com/SherlockNoMad, https://github.com/jansel, https://github.com/xmfan
Turns out codegen'ing a nested step graph break is significantly more complicated than first thought. The optimized function should actually do:
- call graph/load values/do side effects etc.
- call into the leaf's resume function, but skipped (this essentially step graph break function for just the leaf function)
- call into all the other resume functions, traced.
This PR also adds `torch._dynamo.step_unsupported()`, which can be used for internal testing purposes to better test step graph break handling.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162737
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #160601
This is needed because if we codegen cells for nested frames AFTER side effects, then reconstruction could get messed up. From below:
>The added test case demonstrates the reconstruction failure if we kept cell codegen at the original place (only happens with nested graph breaks since we reconstruct nested frame cells from VariableTracker rather than directly using LOAD_CLOSURE).
>At a high level, what happened before this change was that side_effects was pruning the cells (I don't recall exactly why this happens), and because cells were codegen'd after the side effects were applied, we were unable to properly reconstruct the cell. The error I was seeing was a list/tuple IndexError.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160601
Approved by: https://github.com/mlazos
## Summary
- add a CuBLASReductionOption enum so the CUDA context can track reduced-precision and split-K options
- extend the Python bindings, backend helpers, and docs to accept an optional allow_splitk argument for fp16/bf16 matmul controls
- update cuBLAS/cuBLASLt call sites plus dynamo guards and tests to respect the new combinations
## Testing
- python test/test_cuda.py TestCuda.test_cublas_allow_fp16_reduced_precision_reduction_get_set -v *(fails: ModuleNotFoundError: No module named 'psutil')*
------
https://chatgpt.com/codex/tasks/task_e_68e404623178832f8a3e1d34e1e175da
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164766
Approved by: https://github.com/malfet, https://github.com/albanD
Adding source_get_cache also to AOT compile case. Since the guard manager loader code can be shared between AOT and caching, we added a new function load_guard_manager to avoid code duplication between two workflows, for loading guards.
Test Plan: test_guard_serialization.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164773
Approved by: https://github.com/yiming0416, https://github.com/dolpm
This PR is to temporarily unblock various experiments to re-use dynamo create fake mode. Note that this is still not what we want as the end state. The end state should look sth like:
```
out = fulllgraph_capture(mod, inputs)
fake_mode = out.backend_inputs.fake_mode
gm = out.module()
```
This doesn't work today because export requires we need to wrap the original module to setup a flat module to trace for easier handling of pytree. As a result, we would need to carry export specific flag in fullgraph_capture which seems not ideal.
Regardless, the end state is that we need to give downstream user a graph module and a fake mode in some form, so I think _dynamo_graph_capture_for_export returning a fake mode within graph module itself via gm.meta
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164730
Approved by: https://github.com/avikchaudhuri
Summary: Added a set of fixes triggered by fm training job. Overall the theme here is that we should get rid of saved objects as much as possible when they are not used in guard reconstruction. Sometimes for objects that cannot be saved (like local functions) we still try our best to save their closures.
Test Plan:
test_guard_serialization.py
test_lazy_awatiable.py
Differential Revision: D83766926
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164490
Approved by: https://github.com/jamesjwu
So this fixes at least two issues:
1) When we are invoking inductor backend, we apply pre-grad passes which try to find correct fake mode to use. In the nested case, we will run into clash when there is closure variable in the inductor region because non-strict would have fakified this variable before hand and inner torch.compile would have created a new fresh fake mode. This is not a problem in regular torch.compile because inner torch.compile gets ignored. I don't know if we are supposed to inherit fake mode from parent context in this case. But we can avoid this problem if we just default to eager backend which is fine in this case because the point of export is to capture aten operators. Going to inductor would mean we will lose inner torch.compile ops.
2) There is custom torch function modes in export that track number of torch fns executed and inner compile itself doesn't work because of guard failure as this mode state gets changed. I noticed torch.cond fixes this problem by carefully stashing the torch function mode and defer it in the backend. So the correct thing to do here is just re-use torch.cond implementation unconditionally.
So the things i did for fixing above were:
1) Always default to eager backend when compile is invoked inside export. I needed to make how torch.cond sets up the fresh tracing env into an util that can be shared.
2) The previous eager backend for torch.cond was wrong because the context managers didn't actually persist until the backend is invoked.
3) torch.cond used only disable TorchFunctionMetadata tf mode and stash it for later, but in fact, we should do both TorchFunctionMetadata and PreDispatchTorchFunctionMode.
With above fixes, we are able to export flex attention in export.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164171
Approved by: https://github.com/ydwu4
Actually we would like to not graph break even in the case of Dynamo. But there is a weird-unsolved bug with Kineto + Dynamo when there are distributed jobs that lead to NCCL timeouts. This bug is a rare edege case, but we have not been able to root cause it yet.
But for export, we do not anticipate JIT tracing in distributed job training and therefore this PR is safe for export.
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164418
Approved by: https://github.com/StrongerXi, https://github.com/williamwen42
Summary:
Under circumstances it seems reasonable to return a callable directly without guard check when user use aot_compile on a function with single compilation result.
When having multiple entries (aot_compile_module), we should start enabling guard check to differetiate different compiled functions apart.
Test Plan: CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163432
Approved by: https://github.com/dolpm, https://github.com/mlazos
In comm-compute overlap we will have a graph with:
```
def foo(...):
ag = all_gather(...)
hiding_compute = mm(...)
wait(ag)
```
There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.
Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.
This pr adds `AugmentedGraphHelper` that adds the apis, and allows querying for dependency with this augmented graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163959
Approved by: https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754
In comm-compute overlap we will have a graph with:
```
def foo(...):
ag = all_gather(...)
hiding_compute = mm(...)
wait(ag)
```
There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.
Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.
This pr adds `AugmentedGraphHelper` that adds the apis, and allows querying for dependency with this augmented graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163959
Approved by: https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754