This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode)
Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call
resume fn structure:
1. enter context
2. jump
...
3. exit context
The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack).
So for torch function modes the structure of our output code is this:
1. graph call
2. mutate tf mode stack to replay mutations
4. unsupported code
5. on exception restore stack
6. resume function
Then our resume fn looks like this:
1. no-op enter torch function mode
2. jump
3. exit tf mode
To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).
Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444
This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode)
Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call
resume fn structure:
1. enter context
2. jump
...
3. exit context
The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack).
So for torch function modes the structure of our output code is this:
1. graph call
2. mutate tf mode stack to replay mutations
4. unsupported code
5. on exception restore stack
6. resume function
Then our resume fn looks like this:
1. no-op enter torch function mode
2. jump
3. exit tf mode
To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).
Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444
This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode)
Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call
resume fn structure:
1. enter context
2. jump
...
3. exit context
The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack).
So for torch function modes the structure of our output code is this:
1. graph call
2. mutate tf mode stack to replay mutations
4. unsupported code
5. on exception restore stack
6. resume function
Then our resume fn looks like this:
1. no-op enter torch function mode
2. jump
3. exit tf mode
To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).
Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444
This PR adds support `torch._C._push_on_torch_function_stack()` by updating `torch.py` to push onto the symbolic torch function mode stack when a push is encountered. The same side effects infra used in the previous PR is used to track the mutation of the torch function mode stack and add bytecode to update it if it is mutated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133132
Approved by: https://github.com/williamwen42
ghstack dependencies: #133130, #133729, #133131
This PR adds support for tracing `torch._C._pop_torch_function_stack()` without graph breaking and in order to verify the state change also adds replay of mutations to the torch function mode stack via side_effects appending supplemental bytecode as we do for other python mutable objects.
Details:
To represent the torch function mode stack symbolically a deque field is added to the instruction translator. When the InstructionTranslator is initialized, all modes are read from the current torch function mode stack, and stashed in a global weak ref for later access (using existing sources) without needing to push/pop the python/cpp torch function mode stack.
During tracing, when `_pop_torch_function_stack` is encountered a value is popped from this deque and the variable tracker representing the mode is returned. To ensure the true torch function mode stack matches this state, `TorchFunctionModeStackVariable`, a singleton, is marked as mutated, this adds it to side effects, where during final codegen, side effects will codegen a call to a python helper which will update the python torch function mode stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133131
Approved by: https://github.com/jansel
ghstack dependencies: #133130, #133729
Need to revert due to internal hangs: S437700
This reverts commit b6c1490cc0.
Revert "[dynamo] implement IteratorVariable and polyfill fallbacks for enumerate (#131725)"
This reverts commit 2576dbbc35.
Revert "[dynamo] add itertools repeat/count bytecode reconstruction (#131716)"
This reverts commit 35b4de32fa.
Revert "[dynamo] add lazy IteratorVariable implementations for map and zip (#131413)"
This reverts commit 7d282d8755.
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132528
Approved by: https://github.com/ZainRizvi
Need to revert due to internal hangs: S437700
This reverts commit b6c1490cc0.
Revert "[dynamo] implement IteratorVariable and polyfill fallbacks for enumerate (#131725)"
This reverts commit 2576dbbc35.
Revert "[dynamo] add itertools repeat/count bytecode reconstruction (#131716)"
This reverts commit 35b4de32fa.
Revert "[dynamo] add lazy IteratorVariable implementations for map and zip (#131413)"
This reverts commit 7d282d8755.
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132528
Approved by: https://github.com/ZainRizvi
There are some substantive changes. Instead of recording the *next* instruction in the speculation log, I record the *current* instruction. I think this is more intuitive, we always call speculation at the beginning of executing an instruction, so logically, the entry is associated with the current instruction. (Note that self.instruction_pointer is next instruction, as conventionally we increment IP before calling speculate).
The cosmetic change is to also pass in the Instruction corresponding to the IP and print it, and beef up the error message, including notes about the previous instruction that was run before it failed (this is typically the critical instruction).
At time of submission, this test case triggered the error:
```
diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index 5ade17856e1..60ef89be346 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -844,6 +844,39 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
+ @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+ @config.patch(enable_compiler_collectives=True)
+ def test_compiler_collectives_automatic_dynamic_speculation_divergence(self):
+ with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+ torch._dynamo.utils.clear_compilation_metrics()
+
+ # TODO: This should be possible to do inside the function, but
+ device = f"cuda:{self.rank}"
+
+ @torch.compile()
+ def f(x, y):
+ zx = x.shape
+ zy = y.shape
+ return x.sum() + y.sum()
+
+ if self.rank == 0:
+ dataloader = [4, 4]
+ else:
+ dataloader = [3, 4]
+
+ for data in dataloader:
+ f(
+ torch.randn(data, device=self.rank),
+ torch.randn(data, device=self.rank),
+ )
+
+ metrics = torch._dynamo.utils.get_compilation_metrics()
+ # Number of compiles same on all nodes
+ res = [None] * self.world_size
+ torch.distributed.all_gather_object(res, len(metrics))
+ for r in res[1:]:
+ self.assertEqual(res[0], r)
+
@requires_nccl()
```
although I plan to fix this soon.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131982
Approved by: https://github.com/anijain2305, https://github.com/mlazos, https://github.com/jansel