[ca] make torch.compile API respect ambient disable contexts (#155473)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155473
Approved by: https://github.com/jansel
parent be124a61a4
commit 87b002b6fb
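Summary: with `torch._dynamo.config.compiled_autograd = True`, the `torch.compile` and `torch._dynamo.optimize` entry points previously re-enabled compiled autograd even when the call happened inside `torch._dynamo.compiled_autograd._disable()`. This change makes those entry points defer to the ambient disable context. Below is a minimal sketch of the intended behavior, distilled from the test added in this commit; the model, shapes, and backend are illustrative only.

```python
import torch
from torch._dynamo import compiled_autograd, config
from torch._dynamo.utils import counters

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Sigmoid())

def step(inp):
    model(inp).sum().backward()
    grads = [model[0].weight.grad, model[0].bias.grad]
    model.zero_grad()
    return grads

with config.patch(compiled_autograd=True):
    compiled_step = torch.compile(step, backend="eager")
    with compiled_autograd._disable():
        # The ambient disable now wins: no compiled autograd graphs are captured.
        compiled_step(torch.randn(2, 4))

# No captures were recorded while the disable context was active.
assert "compiled_autograd" not in counters
```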
@@ -99,6 +99,7 @@ def reset():
     torch._logging.set_logs(compiled_autograd_verbose=False)
     config.compiled_autograd = False
     compiled_autograd.reset()
+    torch._dynamo.utils.counters.clear()


 class TestCompiledAutograd(TestCase):
@@ -706,6 +707,44 @@ main()
         self.assertEqual(expected, actual)
         self.assertEqual(counters["compiled_autograd"]["captures"], 2)

+    @parametrize("api", ("compile", "optimize"))
+    @parametrize("backend", ("eager", "aot_eager", "inductor"))
+    def test_compile_api_disable(self, api, backend):
+        def wrap(fn, backend):
+            if api == "compile":
+                return torch.compile(fn, backend=backend)
+            elif api == "optimize":
+                return torch._dynamo.optimize(backend)(fn)
+
+        def fn(model, inputs):
+            res = []
+            for inp in inputs:
+                result = model(inp).sum()
+                result.backward()
+                res.append(model[0].weight.grad)
+                res.append(model[0].bias.grad)
+                model.zero_grad()
+            return res
+
+        torch.manual_seed(123)
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.Sigmoid(),
+        )
+        inputs = [
+            torch.randn([1, 4]),
+            torch.randn([2, 4]),
+            torch.randn([3, 4]),
+        ]
+
+        expected = fn(model, inputs)
+        with config.patch(compiled_autograd=True):
+            compiled_fn = wrap(fn, backend)
+            with torch._dynamo.compiled_autograd._disable():
+                actual = compiled_fn(model, inputs)
+        self.assertEqual(expected, actual)
+        self.assertTrue("compiled_autograd" not in counters)
+
     @parametrize("backend", ("eager", "aot_eager", "inductor"))
     def test_optimize_assert(self, backend):
         # can be merged into the test above once we support
@@ -1373,9 +1373,11 @@ compiled_autograd_enabled_force_eager = False
 # global flag to check if we are processing graphs produced from a compiled autograd graph
 in_compiled_autograd_region = False

+active_disable_ctx = False
+

 @contextlib.contextmanager
-def _enable(compiler_fn, dynamic: bool = True):
+def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
     # The entrypoint to enable CA.
     # It is recommended to enable via `torch._dynamo.config.compiled_autograd = True` rather
     # than using this context manager directly. If you are torch.compiling the corresponding
@@ -1396,6 +1398,9 @@ def _enable(compiler_fn, dynamic: bool = True):
     # - dynamic: Whether compiled autograd will treat tensors in the autograd graph (params, activations) as dynamic.
     #   This doesn't affect the dynamic configuration of the compilation wrapper.

+    if not ignore_active_disable_ctx and active_disable_ctx:
+        yield
+    else:
     if dynamic:
         assert type(dynamic) is bool
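The new `ignore_active_disable_ctx` flag only matters when an ambient `_disable()` is active: in that case `_enable` simply yields without installing a compiled autograd compiler. A rough sketch of the intended semantics follows; `compiler_fn` is an arbitrary placeholder, and the global names are taken from the hunks above, so treat this as an illustration rather than a guaranteed contract.

```python
import torch
from torch._dynamo import compiled_autograd

def compiler_fn(gm):
    # Compile the captured backward graph with any backend; "eager" keeps it simple.
    return torch.compile(gm, backend="eager")

with compiled_autograd._disable():
    # Default (ignore_active_disable_ctx=True): an explicit enable still overrides
    # the ambient disable, as it did before this change.
    with compiled_autograd._enable(compiler_fn):
        assert compiled_autograd.compiled_autograd_enabled

    # What OptimizeContext now passes: the ambient disable wins and _enable is a no-op.
    with compiled_autograd._enable(compiler_fn, ignore_active_disable_ctx=False):
        assert not compiled_autograd.compiled_autograd_enabled
```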
@@ -1444,11 +1449,15 @@ def _disable():
     ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False)
     global compiled_autograd_enabled
     compiled_autograd_enabled = False
+    global active_disable_ctx
+    if not active_disable_ctx:
+        active_disable_ctx = True
     try:
         yield
     finally:
         if prior_compiler:
             compiled_autograd_enabled = True
+        active_disable_ctx = False
         torch._C._dynamo.compiled_autograd.set_autograd_compiler(
             prior_compiler, prior_dynamic
         )
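Correspondingly, `_disable()` now records that an ambient disable is active via the module-level `active_disable_ctx` flag and clears it on exit; this is the flag the `ignore_active_disable_ctx=False` path above consults. A tiny sketch of the flag's lifecycle, assuming a fresh process with no other compiled autograd contexts in flight:

```python
from torch._dynamo import compiled_autograd

assert not compiled_autograd.active_disable_ctx
with compiled_autograd._disable():
    # While the context is active, enable paths that respect it become no-ops.
    assert compiled_autograd.active_disable_ctx
assert not compiled_autograd.active_disable_ctx
```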
@@ -816,7 +816,7 @@ class OptimizeContext(_TorchDynamoContext):
             assert rebuild_ctx is not None
             compiler_fn = rebuild_ctx()
             ctx = torch._dynamo.compiled_autograd._enable(
-                compiler_fn, dynamic=_dynamic
+                compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False
             )
             ctx.__enter__()
             return functools.partial(ctx.__exit__, None, None, None)