[Dynamo] Test intermediate tf mode construction (#133134)

Ensures that constructing a torch function mode in the middle of a function is supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133134
Approved by: https://github.com/williamwen42
ghstack dependencies: #133130, #133729, #133131, #133132, #133133
This commit is contained in:
Michael Lazos 2024-08-19 15:10:57 -07:00 committed by PyTorch MergeBot
parent 626acaeb16
commit d97ca968cd

View File

@ -297,6 +297,19 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
self.assertEqual(res, torch.ones(2, 2) + 1)
self.assertEqual(_len_torch_function_stack(), 1)
def test_intermedate_torch_function_mode_construction_mutation(self):
class TestMode(BaseTorchFunctionMode):
def __init__(self, x):
self.x = x
@torch.compile(fullgraph=True)
def fn(x):
z = TestMode(2)
z.y = 2
return x + 1, z
fn(torch.ones(2, 2))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests