mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes #103613.

A requirement for HigherOrderOperators is that after Dynamo capture, the body function should be functional (i.e. it has no observable side effects). If the body function mutates a variable that is not local to the body, that should induce a graph break. This PR distinguishes between MutableLocals created inside and outside the body and adds the relevant checks. (Design originally proposed by voznesenskym.)

- We tag each mutable_local with an id that corresponds to where it came from. The mutable_local may represent an existing object that gets tracked by Dynamo or an object that is created while Dynamo is introspecting.
- This id changes when we are introspecting the body of a HigherOrderOperator.
- If Dynamo wants to perform a side effect using a mutable_local, we check its .scope field against the current scope id and raise Unsupported in the desired case (non-local mutation inside a HigherOrderOperator body).
- The id is a global thread_local variable. I can make this not a global variable, but it takes some engineering time to thread a number through each of the various ways Dynamo can construct a mutable_local.

Test Plan:
- Add a bunch of new tests. They cover combinations of {global, nonlocal} x {number, Tensor, list, object, nn.Module} and assert that HigherOrderOp falls back in those cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104077
Approved by: https://github.com/voznesenskym, https://github.com/jansel
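The check described in the bullets above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not Dynamo's actual code: `MutableLocalSketch`, `UnsupportedSketch`, and `check_side_effect_allowed` are hypothetical names introduced here, and only the two scope-id helpers follow the file shown below.

```python
import contextlib
import threading
from dataclasses import dataclass, field

# Thread-local holder for the current scope id (same idea as the file below).
_current_scope_id = threading.local()


def current_scope_id():
    if not hasattr(_current_scope_id, "value"):
        _current_scope_id.value = 1
    return _current_scope_id.value


@contextlib.contextmanager
def enter_new_scope():
    try:
        _current_scope_id.value = current_scope_id() + 1
        yield
    finally:
        _current_scope_id.value = current_scope_id() - 1


class UnsupportedSketch(RuntimeError):
    """Hypothetical stand-in for the exception that leads to a graph break."""


@dataclass
class MutableLocalSketch:
    """Hypothetical stand-in for a mutable_local tagged with its origin scope."""

    # The tag is taken from the scope that is active at construction time.
    scope: int = field(default_factory=current_scope_id)


def check_side_effect_allowed(local):
    # Assumed rule: a mutation is allowed only when the object was created
    # in the scope that is currently being traced.
    if local.scope != current_scope_id():
        raise UnsupportedSketch(
            "non-local mutation inside HigherOrderOperator body"
        )


outer = MutableLocalSketch()      # created outside any body
check_side_effect_allowed(outer)  # same scope, so this passes

with enter_new_scope():           # pretend we are tracing a body
    inner = MutableLocalSketch()  # created inside the body
    check_side_effect_allowed(inner)
    try:
        check_side_effect_allowed(outer)
        blocked = False
    except UnsupportedSketch:
        blocked = True
```

In this sketch, mutating `inner` inside the body is accepted, while mutating `outer` from inside the body raises, which corresponds to the "non-local mutation" case the PR wants to reject.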
24 lines
610 B
Python
import contextlib
import threading

# Global variable to identify which SubgraphTracer we are in.
# It is sometimes difficult to find an InstructionTranslator to use.
_current_scope_id = threading.local()


def current_scope_id():
    global _current_scope_id
    if not hasattr(_current_scope_id, "value"):
        _current_scope_id.value = 1
    return _current_scope_id.value


@contextlib.contextmanager
def enter_new_scope():
    global _current_scope_id
    try:
        _current_scope_id.value = current_scope_id() + 1
        yield
    finally:
        _current_scope_id.value = current_scope_id() - 1
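As a usage illustration, the snippet below shows how the id behaves under nesting and across threads. It restates the two helpers so that it runs on its own; the `seen` and `other` lists are hypothetical names used only for this demonstration.

```python
import contextlib
import threading

_current_scope_id = threading.local()


def current_scope_id():
    if not hasattr(_current_scope_id, "value"):
        _current_scope_id.value = 1
    return _current_scope_id.value


@contextlib.contextmanager
def enter_new_scope():
    try:
        _current_scope_id.value = current_scope_id() + 1
        yield
    finally:
        _current_scope_id.value = current_scope_id() - 1


seen = []   # ids observed on the main thread
other = []  # ids observed on a second thread

seen.append(current_scope_id())          # outermost scope
with enter_new_scope():
    seen.append(current_scope_id())      # one level deep
    with enter_new_scope():
        seen.append(current_scope_id())  # two levels deep

    # threading.local gives every thread its own "value" attribute, so a
    # fresh thread starts from the default regardless of the main thread.
    worker = threading.Thread(target=lambda: other.append(current_scope_id()))
    worker.start()
    worker.join()
seen.append(current_scope_id())          # restored after leaving the scopes
```

Note that the value acts as a nesting depth: two sibling scopes entered one after another at the same depth observe the same id.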