Serialize BUILTIN_MATCH since they are all stored in __builtin__ dict.
Also fixed an issue that the wrong global scope is passed to CheckFunctionManager while loading guards. Previously we can always reuse the compile-time global scope for evaluating guards because the compile-time and runtime global scope are always the same.
For precompile, we need to serialize the compile-time global scope for loading only. We need to point the CheckFunctionManager to the new global scope after loading is finished for evaluating guards.
Differential Revision: [D77159313](https://our.internmc.facebook.com/intern/diff/D77159313/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157016
Approved by: https://github.com/jansel, https://github.com/jamesjwu
Currently, every time we construct a GLOBAL_STATE guard, we always create a fresh guard based on the current global state. For precompile, we want to create a GLOBAL_STATE guard always based on some external sources, e.g. serialized global states. This can also be applied with the normal case where we just pass in the global state guard from Python.
Differential Revision: [D77400988](https://our.internmc.facebook.com/intern/diff/D77400988/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157285
Approved by: https://github.com/jansel
Adding a per torch.compile() object CompilePackage which tracks dynamo artifact. CompilePackage is considered a low level component and should not be directly exposed to end users. It has the following interface:
1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states.
a. when `dynamo` argument is None, it will contruct a brand new CompilePackage object.
b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state.
2. `package.save()` which dumps the dynamo states into _DynamoCacheEntry.
3. `package.install(backends)` which will handle all the side-effectful global scope updates with compiled functions and resume functions.
This diff focus on making the low level mechanism for precompile. It will be left to upper level interface to use these API to build more user-facing frontend.
Differential Revision: [D75956538](https://our.internmc.facebook.com/intern/diff/D75956538/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155118
Approved by: https://github.com/jamesjwu
Co-authored-by: James Wu <jjwu@meta.com>
vLLM profiler sets with_stack=True that shows the dict_getitem on the profiler, both inflating the numbers and confusing compile users. This PR keeps BINARY_SUBSCR for regular dicts, while using `dict.__getitem__` only for dict subclasses.
Using binary_subscr is little bit faster, but not enough to make any major latency improvements.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155727
Approved by: https://github.com/zou3519, https://github.com/StrongerXi, https://github.com/jansel
We observed that guard overhead at runtime using profiler traces was
higher than reported in this profiling function at the compile time.
After investigation, we found that f_locals are already in cache and
that was causing the guard overhead to be way smaller while profiling
during the compilation. To be more realistic, we flush the cache here.
Profiling the guard overhead during compilation (in addition to at
runtime) allows faster iteration time, and logging in tlparse and
internal databases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154764
Approved by: https://github.com/zou3519, https://github.com/jansel, https://github.com/StrongerXi
We observed that guard overhead at runtime using profiler traces was
higher than reported in this profiling function at the compile time.
After investigation, we found that f_locals are already in cache and
that was causing the guard overhead to be way smaller while profiling
during the compilation. To be more realistic, we flush the cache here.
Profiling the guard overhead during compilation (in addition to at
runtime) allows faster iteration time, and logging in tlparse and
internal databases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154764
Approved by: https://github.com/zou3519, https://github.com/jansel, https://github.com/StrongerXi
ghstack dependencies: #154769
Summary:
Today when guard serialization fails, dynamo will raise an internal error like:
```
torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: CLOSURE_MATCH guard cannot be serialized.
```
Adding a dedicated PackageError type to surface the error more clearly.
Test Plan: CI
Differential Revision: D75452124
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154430
Approved by: https://github.com/jamesjwu, https://github.com/jansel
Summary: Prune unused local objects from serialized local scope if they are not used in guard reconstruction. This is helpful when a user program takes things like local callable functions or the function call is recursive.
Test Plan:
test/dynamo/test_guard_serialization.py -k test_function_locals
Before pruning locals:
```
state = GuardsState(output_graph=OutputGraphGuardsState(local_scope={'x': tensor([ 0.0461, 0.4024, -1.0115]), 'g': <function ...aints=None, _guards=<torch._guards.GuardsSet object at 0x7fbccc7e9fc0>, _aotautograd_guards=[]), shape_code_parts=None)
def pickle_guards_state(state: GuardsState) -> bytes:
buf = io.BytesIO()
pickler = GuardsStatePickler(buf)
try:
pickler.dump(state)
except AttributeError as e:
> raise torch._dynamo.exc.PackageError(str(e)) from e
E torch._dynamo.exc.PackageError: Can't pickle local object 'TestGuardSerialization.test_function_locals.<locals>.foo'
```
After the diff
```
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```
Differential Revision: D75452123
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154431
Approved by: https://github.com/jansel
This PR adds support for SimpleFSDP's composability with Tensor Parallel + torch.compile.
`_StridedShard` is used in SimpleFSDP/FSDP2 to support correct distributed checkpointing when FSDP+TP is applied. Previously, `_StridedShard` is not guarded by torch.compile. This PR adds `_StridedShard` as an additional placement type to be guarded by torch.compile.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152286
Approved by: https://github.com/bdhirsh
Fixes#153777
CSE is an optimization and shouldn't block a compile if it hits recursion depth limits. Unfortunately we can't write this iteratively due to a dependency on `ast.unparse` which necessarily needs to do recursion. This PR catches opts out of CSE when we hit recursion depth errors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154039
Approved by: https://github.com/Microve
https://github.com/pytorch/pytorch/issues/148222
Goal:
At the moment autograd saved tensors hooks are run in eager after compiled forward.
They are executed at the same time for all saved tensors.
Hooks can be used to reduce amout of memory used for saved tensors, doing quantization or offloading to cpu.
This is suboptimal for optimization of peak memory.
Better solution will be to put the hooks in the graph, as close as possible to the last usage of the tensor.
To get user specified autograd saved tensors hooks in the graph.
Logic:
UX:
If user specifies with torch.autograd.graph.saved_tensors_hooks(pack_gm, unpack_gm).
Where pack_gm and unpack_gm are torch.fx.GraphModule.
Then AotAutograd will retrace those graph modules, doing decompositions and functionalization in aot_autograd, inlining the result graphs in forward epilogue and backward prologue.
User may want to use control logic in the hooks, for example applying quantization only for specific dtypes and sizes.
This is also possible, user can put it into torch.fx.wrap function and use symbolic trace to make a GraphModule.
In that case AotAutograd cahing will work only in case when user explicitly set to the torch.fx.wrap call_function node "user_cache_hash" metadata.
If this metadata set - then aot_autograd cache can use saved cache artifact.
If metadata is not set - then cache is bypassed.
Dynamo:
Dynamo traces pack and unpack hooks and installs them as subgraph and explicitly adds to the output_graph. (As those subgraphs are not used and will not be copied in the result by default).
The complexity here is that at this moment we do not have example of inputs for the hooks.
We trace pack_hook with some Tensor from the inputs.
The result subgraphs are added to the hashing of AotAutograd Cache.
In AotAutograd we retrace the graph with the true saved tensors coming from partitioner.
Backwards Compatibility:
As current hooks are executed in eager mode and not all of them will be traceable - we only try to put in the graph hooks, explicitly marked by user with annotation (@_inlineable_saved_tensors_hooks).
For other hooks or if compiled autograd is enabled - keep the same logic.
Recompilations:
Hooks are guarded with lambda guard matching function id to cause recompilation if user reruns compiled function.
Aot_autograd:
After partitioner prepared forward and backward module - we trace prepared at Dynamo graphs for pack and unpack hooks and inline them in epilogue of forward and prologue of backward. Forward outputs and backward inputs are changed, transparently for user.
We do not try to put it close the last usage etc., relying on inductor to do this optimization.
```
INFO: TRACED GRAPH
===== Forward graph pre saved_tensors_hooks inlining 3 =====
/data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1); primals_3 = None
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
return (view, add, primals_1, primals_2)
INFO: TRACED GRAPH
===== Backward graph pre saved_tensors_hooks inlining 3 =====
/data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1); primals_3 = None
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
return (view, add, primals_1, primals_2)
INFO: TRACED GRAPH
===== saved_tensors_pack_hook add 3 =====
/data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class pack_float8(torch.nn.Module):
def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
# No stacktrace found for following nodes
_to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn); x_1 = None
return (torch.float32, _to_copy)
INFO: TRACED GRAPH
===== saved_tensors_unpack_hook add 3 =====
<eval_with_key>.22 from /data/users/ivankobzarev/a/pytorch/torch/fx/experimental/proxy_tensor.py:1225 in wrapped class pack_float8(torch.nn.Module):
def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
# No stacktrace found for following nodes
_to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn); x_1 = None
return (torch.float32, _to_copy)
INFO: TRACED GRAPH
===== Forward graph 3 =====
/data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1); primals_3 = None
# No stacktrace found for following nodes
_to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add, dtype = torch.float8_e4m3fn)
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2]); add = None
return (view, _to_copy, primals_1, primals_2)
INFO: TRACED GRAPH
===== Backward graph 3 =====
<eval_with_key>.21 class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", add_packed_2: "f8e4m3fn[s0, s1][s1, 1]cuda:0", tangents_1: "f32[s0, s1][s1, 1]cuda:0"):
# No stacktrace found for following nodes
_to_copy: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add_packed_2, dtype = torch.float32); add_packed_2 = None
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
add_7: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(tangents_1, _to_copy); tangents_1 = _to_copy = None
return (None, None, add_7)
```
Differential Revision: [D72187044](https://our.internmc.facebook.com/intern/diff/D72187044)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150032
Approved by: https://github.com/bdhirsh
Basically adds native _IntWrapper support to dynamo. Here's my process of trying to make symint input support work on dynamo, and how I ended up with this approach [(doc)](https://docs.google.com/document/d/1GvNRQd8BnxlMay_hrEVgEta6VUeUW_hcFeRuB7q1nDY/edit?tab=t.0).
What I did was, before passing inputs to dynamo.export, I first wrap them with a class, `_IntWrapper`. When processing dynamic shapes, I will then add the corresponding dynamic shape specification to the `dynamism` field stored on the `_IntWrapper`. If there is no dynamism specified, then this will get unwrapped back to an integer. When dynamo tracing, when we encounter an `_IntWrapper`, we will convert this to a symint if the dynamism was specified as `Dim.DYNAMIC/AUTO`. Dynamo will then trace a graph that contains symint inputs, which will get passed to AOTAutograd and so on.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152677
Approved by: https://github.com/pianpwk
This PR updates `GuardsStatePickler.reducer_override()` in `torch/_dynamo/guards.py` to handle reconstruction of traceable wrapper subclasses. It's intended to work recursively and handle any level of subclass instance nesting (e.g. subclass instances that contain subclass instances, etc.)
This PR tests the guard on several traceable wrapper tensor subclasses:
* `LocalSubclass`: used to ensure the correct error message is thrown when the subclass is not defined globally
* `torch.testing._internal.two_tensor.TwoTensor`: defines None for its extra metadata
* `SubclassWithMeta`: stores non-trivial extra metadata
* `SubclassWithCustomMetadataGuard`: stores non-trivial extra metadata and defines a custom `__metadata_guard__` classmethod
* `SubclassWithSubclassInnerTensors`: used to test recursiveness; this subclass contains subclass inner tensor components
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152626
Approved by: https://github.com/jansel
Make Functorch interpreters serializable most of the time, so that we can save the guards on functorch states.
## Test Cases:
0. torch.compile() without functorch layers present. Guard should fail with any layer being pushed.
1. torch.compile() nested in vmap.
2. torch.compile() nested in grad.
3. torch.compile() nested in jvp + vmap
4. torch.compile() nested functionalize
5. torch.compile() nested in vmap + grad
Differential Revision: [D74008787](https://our.internmc.facebook.com/intern/diff/D74008787/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152616
Approved by: https://github.com/zou3519
ghstack dependencies: #152615
This is a proof-of-concept of how we could serialize a guard and deserialize it back from the bytes.
The main behavioral change introduced in this diff is on CheckFunctionManager:
```
check_fn_manager = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")
guards_state: bytes = check_fn_manager.guards_state
```
Once `guards_serialization_mode` is set to `save`, CheckFunctionManager will return an addtional `bytes` object called `guards_state` which should contain all the information needed for deserializing guards later.
When we load back guards state, we will set `guards_serialization_mode` is set to `load`:
```
output_graph_state = pickle.loads(guards_state)
check_fn_manager = CheckFunctionManager(code, output_graph_state, guards_serialization_mode="load")
```
# TENSOR_MATCH
Since we have many types of guards to support, we will break the work into small diffs instead of a single diff to support every guards.
We kick off the work from TENSOR_MATCH from this diff.
# Testing
For each type of guard we will test it like the following:
1. Use guard_filter_fn to select 1 type of guard each time.
2. Call InstructionTranslator directly on an example function to get OutputGraph and CheckFunctionManager (reference guard manager)
3. Serialize->deserialize the output graph state and re-build the guards with a new CheckFunctionManager (loaded guard manager)
4. Throw a set of example inputs to both reference and loaded guard manager to see if their behavior match.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151318
Approved by: https://github.com/jansel, https://github.com/anijain2305
This is a proof-of-concept of how we could serialize a guard and deserialize it back from the bytes.
The main behavioral change introduced in this diff is on CheckFunctionManager:
```
check_fn_manager = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")
guards_state: bytes = check_fn_manager.guards_state
```
Once `guards_serialization_mode` is set to `save`, CheckFunctionManager will return an addtional `bytes` object called `guards_state` which should contain all the information needed for deserializing guards later.
When we load back guards state, we will set `guards_serialization_mode` is set to `load`:
```
output_graph_state = pickle.loads(guards_state)
check_fn_manager = CheckFunctionManager(code, output_graph_state, guards_serialization_mode="load")
```
# TENSOR_MATCH
Since we have many types of guards to support, we will break the work into small diffs instead of a single diff to support every guards.
We kick off the work from TENSOR_MATCH from this diff.
# Testing
For each type of guard we will test it like the following:
1. Use guard_filter_fn to select 1 type of guard each time.
2. Call InstructionTranslator directly on an example function to get OutputGraph and CheckFunctionManager (reference guard manager)
3. Serialize->deserialize the output graph state and re-build the guards with a new CheckFunctionManager (loaded guard manager)
4. Throw a set of example inputs to both reference and loaded guard manager to see if their behavior match.
Differential Revision: [D72987485](https://our.internmc.facebook.com/intern/diff/D72987485/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151318
Approved by: https://github.com/jansel, https://github.com/anijain2305