mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This PR adds support for `torch._C._push_on_torch_function_stack()` by updating `torch.py` to push onto the symbolic torch function mode stack when a push is encountered. The same side-effects infrastructure used in the previous PR tracks mutation of the torch function mode stack and adds bytecode to update it if it is mutated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133132 Approved by: https://github.com/williamwen42 ghstack dependencies: #133130, #133729, #133131
292 lines
8.7 KiB
Python
292 lines
8.7 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
from torch._C import (
|
|
_len_torch_function_stack,
|
|
_pop_torch_function_stack,
|
|
_push_on_torch_function_stack,
|
|
)
|
|
from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
|
|
from torch.utils._device import DeviceContext
|
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
|
|
|
|
class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super().setUpClass()
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
super().tearDownClass()
|
|
|
|
def test_skip_torch_dispatch_modes(self):
|
|
class RewriteAddToMul(TorchDispatchMode):
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
if func is torch.ops.aten.add.Tensor:
|
|
func = torch.ops.aten.mul.Tensor
|
|
return func(*args, **kwargs)
|
|
|
|
def fn(x):
|
|
return x + x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
x = torch.tensor([3.0])
|
|
with RewriteAddToMul():
|
|
eager_res = fn(x)
|
|
compiled_res = torch._dynamo.optimize(cnt)(fn)(x)
|
|
|
|
self.assertEqual(eager_res, compiled_res)
|
|
self.assertEqual(cnt.frame_count, 0)
|
|
|
|
|
|
class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
cls.default_device_old = torch.get_default_device()
|
|
super().setUpClass()
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
torch.set_default_device(cls.default_device_old)
|
|
super().tearDownClass()
|
|
|
|
def setUp(self):
|
|
torch.set_default_device(None)
|
|
|
|
def tearDown(self):
|
|
torch.set_default_device(None)
|
|
|
|
def _run_torch_function_mode_guard_test(self):
|
|
class TestMode1(BaseTorchFunctionMode):
|
|
pass
|
|
|
|
class TestMode2(BaseTorchFunctionMode):
|
|
pass
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch.compile(backend=cnt.__call__)
|
|
def fn(x):
|
|
return x + 1
|
|
|
|
inp = torch.ones(2, 2)
|
|
fn(inp)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
with TestMode1():
|
|
fn(inp)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
with TestMode1(), TestMode2():
|
|
fn(inp)
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
with TestMode2(), TestMode1():
|
|
fn(inp)
|
|
self.assertEqual(cnt.frame_count, 4)
|
|
|
|
with TestMode1():
|
|
fn(inp)
|
|
self.assertEqual(cnt.frame_count, 4)
|
|
|
|
def _run_ignored_mode_types_test(self):
|
|
class IgnoredMode(BaseTorchFunctionMode):
|
|
pass
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch.compile(backend=cnt.__call__, fullgraph=True)
|
|
def fn(x):
|
|
return x + 1
|
|
|
|
inp = torch.ones(2, 2)
|
|
|
|
with patch(
|
|
"torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
|
|
):
|
|
# initial compile
|
|
fn(inp)
|
|
|
|
# no recompile, mode ignored
|
|
# note: the ref stack is length 0, and the stack we are checking against has length 2
|
|
# we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
|
|
with IgnoredMode(), IgnoredMode():
|
|
fn(inp)
|
|
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
# recompile due to new mode on the stack
|
|
with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
|
|
fn(inp)
|
|
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
# recompile
|
|
# tests both ref stack len > runtime stack len for the above guard check
|
|
# and ref stack len < runtime stack len for the initial zero mode case
|
|
with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
|
|
fn(inp)
|
|
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
# no recompile
|
|
with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
|
|
fn(inp)
|
|
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
# This is tricky, basically the ignored modes are baked into the guard
|
|
# IgnoredMode will be ignored forever by that guard.
|
|
# This is okay since we don't expect to be modifying IGNORED_MODES
|
|
# in the middle of execution except for the purposes of testing.
|
|
torch._dynamo.reset()
|
|
|
|
with IgnoredMode():
|
|
fn(inp)
|
|
|
|
self.assertEqual(cnt.frame_count, 4)
|
|
|
|
@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
|
|
def test_torch_function_mode_guards_ignored_types_py(self):
|
|
self._run_ignored_mode_types_test()
|
|
|
|
def test_torch_function_mode_guards_ignored_types_cpp(self):
|
|
self._run_ignored_mode_types_test()
|
|
|
|
@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
|
|
def test_torch_function_mode_guards_py(self):
|
|
self._run_torch_function_mode_guard_test()
|
|
|
|
def test_torch_function_mode_guards_cpp(self):
|
|
self._run_torch_function_mode_guard_test()
|
|
|
|
def test_stack_state_mutation_default_device(self):
|
|
m = BaseTorchFunctionMode()
|
|
m1 = BaseTorchFunctionMode()
|
|
with m, m1:
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x):
|
|
torch.set_default_device("cpu")
|
|
_pop_torch_function_stack()
|
|
|
|
fn(torch.ones(2, 2))
|
|
_push_on_torch_function_stack(m1)
|
|
|
|
stack = _get_current_function_mode_stack()
|
|
self.assertIsInstance(stack[0], DeviceContext)
|
|
self.assertEqual(stack[0].device, torch.device("cpu"))
|
|
self.assertIs(stack[1], m)
|
|
self.assertIs(stack[2], m1)
|
|
|
|
def test_stack_state_clear_default_device(self):
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x):
|
|
torch.set_default_device(None)
|
|
return x + 1
|
|
|
|
fn(torch.ones(2, 2))
|
|
stack = _get_current_function_mode_stack()
|
|
self.assertEqual(len(stack), 0)
|
|
|
|
m = BaseTorchFunctionMode()
|
|
m1 = BaseTorchFunctionMode()
|
|
|
|
# Stack populated, add device
|
|
with m, m1:
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x):
|
|
torch.set_default_device("cpu")
|
|
torch.set_default_device(None)
|
|
torch.set_default_device("cpu")
|
|
return x + 1
|
|
|
|
fn(torch.ones(2, 2))
|
|
stack = _get_current_function_mode_stack()
|
|
self.assertEqual(stack[0].device, torch.device("cpu"))
|
|
self.assertIs(stack[1], m)
|
|
self.assertIs(stack[2], m1)
|
|
|
|
# Stack populated, remove device
|
|
torch.set_default_device("cpu")
|
|
with m, m1:
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x):
|
|
torch.set_default_device(None)
|
|
return x + 1
|
|
|
|
fn(torch.ones(2, 2))
|
|
stack = _get_current_function_mode_stack()
|
|
self.assertIs(stack[0], m)
|
|
self.assertIs(stack[1], m1)
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x):
|
|
torch.set_default_device("cpu")
|
|
torch.set_default_device("cpu")
|
|
return x + 1
|
|
|
|
fn(torch.ones(2, 2))
|
|
stack = _get_current_function_mode_stack()
|
|
self.assertEqual(stack[0].device, torch.device("cpu"))
|
|
torch.set_default_device(None)
|
|
|
|
def test_pop_torch_function_mode(self):
|
|
m = BaseTorchFunctionMode()
|
|
with m:
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x):
|
|
_pop_torch_function_stack()
|
|
return x + 1
|
|
|
|
fn(torch.ones(2, 2))
|
|
|
|
self.assertEqual(_len_torch_function_stack(), 0)
|
|
# reset stack so __exit__ doesn't crash
|
|
_push_on_torch_function_stack(m)
|
|
|
|
self.assertEqual(_len_torch_function_stack(), 0)
|
|
|
|
def test_error_empty_stack_pop_torch_function_mode(self):
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x):
|
|
_pop_torch_function_stack()
|
|
return x + 1
|
|
|
|
self.assertRaisesRegex(
|
|
torch._dynamo.exc.Unsupported,
|
|
"Popping from an empty torch function mode stack",
|
|
lambda: fn(torch.ones(2, 2)),
|
|
)
|
|
|
|
def test_push_torch_function_mode(self):
|
|
m = BaseTorchFunctionMode()
|
|
with m:
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x, m):
|
|
_push_on_torch_function_stack(m)
|
|
return x + 1
|
|
|
|
fn(torch.ones(2, 2), m)
|
|
|
|
self.assertEqual(_len_torch_function_stack(), 2)
|
|
# reset stack state
|
|
_pop_torch_function_stack()
|
|
|
|
self.assertEqual(_len_torch_function_stack(), 0)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._dynamo.test_case import run_tests
|
|
|
|
run_tests()
|