The original motivation for MYPYINDUCTOR was a faster type-checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this. Perhaps erroneously, when I teed up this PR I elected to delete the `follow_imports = skip` designations in the mypy-inductor.ini. This led to a number of extra type error suppressions that I manually edited. You will need to review.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
# mypy: ignore-errors

import dataclasses
import inspect
from typing import Callable, Dict, List, Optional

import torch._C
from torch._guards import Guard

from .. import variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..device_interface import get_interface_for_device
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from .base import VariableTracker
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
    WrappedUserFunctionVariable,
    WrappedUserMethodVariable,
)


@dataclasses.dataclass
class ContextManagerState:
    """
    Mutating `self` on a VariableTracker is not allowed because
    VariableTrackers are copied. This is a mutable container pointed to by
    context managers that won't get copied, so it is safe to mutate.
    """

    cleanup_fn: Optional[Callable] = None
    proxy: Optional[torch.fx.Proxy] = None

    def cleanup(self):
        if self.cleanup_fn is not None:
            self.cleanup_fn()
            self.cleanup_fn = None

    def cleanup_assert(self):
        assert self.cleanup_fn, "multiple exits?"
        self.cleanup()
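

# Subclasses of ContextWrappingVariable implement a small protocol: enter()
# applies the target values and registers a cleanup hook, exit() restores the
# initial values, and module_name()/fn_name() tell reconstruction how to
# rebuild the context manager in generated bytecode. A rough sketch of the
# shape a subclass takes (MyModeVariable and torch.set_my_mode are
# hypothetical, for illustration only):
#
#     class MyModeVariable(ContextWrappingVariable):
#         def _call_func(self, tx, values):
#             ...  # emit an FX node and mutate the corresponding global state
#
#         def module_name(self):
#             return "torch"
#
#         def fn_name(self):
#             return "set_my_mode"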
class ContextWrappingVariable(VariableTracker):
    _nonvar_fields = {
        "cm_obj",
        "target_values",
        "initial_values",
        "state",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, target_values, initial_values=None, *, state=None, **kwargs):
        super().__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        self.state = ContextManagerState() if state is None else state

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        self.set_cleanup_hook(tx)
        return variables.ConstantVariable.create(None)

    def set_cleanup_hook(self, tx, fn=None):
        if fn is None:

            def fn():
                self._call_func(tx, self.initial_values)

        self.state.cleanup_fn = fn
        tx.output.add_cleanup_hook(self.state.cleanup)

    def exit(self, tx, *args):
        self.state.cleanup_assert()
        return variables.ConstantVariable.create(None)

    def reconstruct(self, codegen):
        attr_source = AttrSource(
            codegen.tx.import_source(self.module_name()), self.fn_name()
        )
        return attr_source.reconstruct(codegen)

    def module_name(self):
        raise NotImplementedError("module_name called on base")

    def fn_name(self):
        raise NotImplementedError("fn_name called on base")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert len(args) == 1
        if isinstance(args[0], NestedUserFunctionVariable):
            args[0] = UserFunctionVariable(args[0].get_function())
        assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))

        if isinstance(args[0], UserMethodVariable):
            return WrappedUserMethodVariable(args[0], self)

        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)
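

# GenericContextWrappingVariable traces a user-defined context manager by
# inlining its __enter__/__exit__ as ordinary user methods. A sketch of user
# code this would cover (MyCtx is hypothetical):
#
#     class MyCtx:
#         def __enter__(self):
#             ...
#
#         def __exit__(self, exc_type, exc_val, exc_tb):
#             ...
#
#     @torch.compile
#     def f(x):
#         with MyCtx():
#             return x + 1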
class GenericContextWrappingVariable(ContextWrappingVariable):
    def __init__(self, target_values, initial_values=None, *, cm_obj=None, **kwargs):
        assert cm_obj is not None
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.cm_obj = cm_obj

    def enter(self, tx):
        source = None if self.source is None else AttrSource(self.source, "__enter__")
        try:
            return variables.UserMethodVariable(
                self.cm_obj.__enter__.__func__,
                variables.UserDefinedObjectVariable(self.cm_obj),
                source=source,
            ).call_function(tx, [], {})
        except Unsupported as e:
            raise unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __enter__ function"
            ) from e

    def exit(self, tx, *args):
        source = None if self.source is None else AttrSource(self.source, "__exit__")
        try:
            x = variables.UserMethodVariable(
                self.cm_obj.__exit__.__func__,
                variables.UserDefinedObjectVariable(self.cm_obj),
                source=source,
            ).call_function(
                tx,
                [
                    variables.ConstantVariable.create(None),
                    variables.ConstantVariable.create(None),
                    variables.ConstantVariable.create(None),
                ],
                {},
            )
        except Unsupported as e:
            raise unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __exit__ function"
            ) from e

        tx.generic_context_manager_depth -= 1
        return x
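

# VmapIncrementNestingCtxManagerVariable models the internal functorch
# nesting bump performed by torch.func.vmap. A sketch of user code that can
# reach it (assuming vmap-under-compile is supported by this torch version):
#
#     @torch.compile
#     def f(x):
#         return torch.func.vmap(torch.sin)(x)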
class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch vmap increment/decrement nesting"""

    # A guard is needed as the vmap level is baked into the torch FX graph
    # generated. This is fine if vmap is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a vmap
    # call from eager that calls the compiled function, as the vmap levels
    # may be different.
    _guards_singleton = Guard(
        GlobalStateSource(), GuardBuilder.FUNCTORCH_CURRENT_LEVEL_MATCH
    )

    @staticmethod
    def create(tx, target_values, **kwargs):
        var = VmapIncrementNestingCtxManagerVariable(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        batch_size, randomness = self.target_values
        vmap_level = torch._C._functorch._vmap_increment_nesting(batch_size, randomness)
        self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._vmap_increment_nesting,
            (batch_size, randomness),
            {},
        )
        return variables.ConstantVariable.create(vmap_level)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)
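

# GradModeVariable covers the grad-mode context managers. A sketch of user
# code it handles:
#
#     @torch.compile
#     def f(x):
#         with torch.no_grad():
#             return x * 2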
class GradModeVariable(ContextWrappingVariable):
    """represents torch.{no_grad,enable_grad,set_grad_enabled}()"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)

    @staticmethod
    def create(tx, target_value, initialized=False, **kwargs):
        var = GradModeVariable(
            target_values=[target_value],
            initial_values=[torch.is_grad_enabled()],
            **kwargs,
        )
        if initialized:
            var._call_func(tx, var.target_values)
        return var

    def __init__(self, target_values, initial_values=None, initialized=True, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ):
        self._call_func(tx, self.initial_values)  # undo eager initialization
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        # Coalesce grad mode mutations
        if torch.is_grad_enabled() != value:
            tx.output.create_node(
                "call_function", torch._C._set_grad_enabled, (value,), {}
            )
            torch._C._set_grad_enabled(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "set_grad_enabled"
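

# InferenceModeVariable models torch.inference_mode. A sketch of user code it
# handles:
#
#     @torch.compile
#     def f(x):
#         with torch.inference_mode():
#             return x + 1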
class InferenceModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(tx, target_values, **kwargs):
        var = InferenceModeVariable(
            target_values, initial_values=torch.is_inference_mode_enabled(), **kwargs
        )
        return var

    def __init__(
        self,
        target_values,
        initial_values=None,
        **kwargs,
    ):
        if initial_values is None:
            # This must be called here since function defaults are evaluated at import time
            initial_values = torch.is_inference_mode_enabled()
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx, *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._exit_inference_mode,
            (self.state.proxy,),
            {},
        )

    def enter(self, tx):
        ctx = torch.autograd.grad_mode._enter_inference_mode(self.target_values)
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._enter_inference_mode,
            (self.target_values,),
            {},
        )

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "inference_mode"
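

# TorchFunctionDisableVariable models turning off the __torch_function__
# machinery while tracing. A sketch of user code that plausibly maps to it
# (exact entry points depend on the torch version):
#
#     @torch.compile
#     def f(x):
#         with torch._C.DisableTorchFunctionSubclass():
#             return x + 1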
class TorchFunctionDisableVariable(ContextWrappingVariable):
    """represents whether torch function overrides are enabled or not"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)

    @staticmethod
    def create(tx, **kwargs):
        var = TorchFunctionDisableVariable(
            target_values=[False],
            initial_values=[tx.output.torch_function_enabled],
            **kwargs,
        )
        # mlazos: I think this is here to make sure we don't reinvoke on clone()
        var._call_func(tx, [False])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        assert len(values) == 1
        tx.output.set_torch_function_state(values[0])
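

# DeterministicAlgorithmsVariable models toggling determinism inside compiled
# code. A sketch of user code it handles:
#
#     @torch.compile
#     def f(x):
#         torch.use_deterministic_algorithms(True)
#         return x + 1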
class DeterministicAlgorithmsVariable(ContextWrappingVariable):
    """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

    _guards_singleton = Guard(
        GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
    )

    @staticmethod
    def create(tx, target_value, **kwargs):
        var = DeterministicAlgorithmsVariable(
            target_values=[target_value],
            initial_values=[torch.are_deterministic_algorithms_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        tx.output.create_node(
            "call_function", torch._C._set_deterministic_algorithms, (value,), {}
        )
        torch._C._set_deterministic_algorithms(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "use_deterministic_algorithms"
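

# DisabledSavedTensorsHooksVariable models
# torch.autograd.graph.disable_saved_tensors_hooks. A sketch of user code it
# handles:
#
#     @torch.compile
#     def f(x):
#         with torch.autograd.graph.disable_saved_tensors_hooks("not allowed"):
#             return x.sin()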
class DisabledSavedTensorsHooksVariable(ContextWrappingVariable):
    """represents torch.autograd.graph.disable_saved_tensors_hooks."""

    @staticmethod
    def create(tx, target_value, **kwargs):
        var = DisabledSavedTensorsHooksVariable(
            target_values=[target_value],
            initial_values=[
                torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
            ],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        if value is not None:
            # Disable `saved_tensors_hooks` with message (`value`)
            # OR
            # we are exiting this context and restoring the previous message.
            tx.output.create_node(
                "call_function",
                torch._C._autograd._saved_tensors_hooks_disable,
                (value,),
                {},
            )
            torch._C._autograd._saved_tensors_hooks_disable(value)
        else:
            # We are exiting this context and if prev_message was None, we re-enable `saved_tensors_hooks`.
            tx.output.create_node(
                "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {}
            )
            torch._C._autograd._saved_tensors_hooks_enable()

    def module_name(self):
        return "torch.autograd.graph"

    def fn_name(self):
        return "disable_saved_tensors_hooks"
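

# AutocastModeVariable normalizes the several autocast entry points
# (torch.amp.autocast_mode.autocast, torch.cuda.amp.autocast,
# torch.cpu.amp.autocast) into one representation. A sketch of user code it
# handles:
#
#     @torch.compile
#     def f(x):
#         with torch.autocast("cpu", dtype=torch.bfloat16):
#             return x @ x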
class AutocastModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(func, args, kwargs):
        assert func in [
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ]
        # device_type : str,
        # dtype : Optional[_dtype] = None,
        # enabled : bool = True,
        # cache_enabled : Optional[bool] = None
        bound_args = inspect.signature(func).bind(*args, **kwargs)
        bound_args.apply_defaults()
        target_values = []
        kwargs.clear()

        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
            if key == "device_type" and func in [
                torch.cuda.amp.autocast,
                torch.cpu.amp.autocast,
            ]:
                arg = "cuda" if func is torch.cuda.amp.autocast else "cpu"
            else:
                arg = bound_args.arguments[key]
            if isinstance(arg, VariableTracker):
                target_values.append(arg.as_python_constant())
            else:
                target_values.append(arg)

        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx, *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
        )

    def enter(self, tx):
        ctx = torch.amp._enter_autocast(*self.target_values)
        self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
        self.state.proxy = tx.output.create_node(
            "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
        )

    def module_name(self):
        return "torch.amp.autocast_mode"

    def fn_name(self):
        return "autocast"


class NullContextVariable(ContextWrappingVariable):
    """
    This class represents Python contextlib.nullcontext.
    It's used as a placeholder for other context managers that Dynamo doesn't
    support yet, e.g., torch.autograd.profiler.record_function.
    """

    def __init__(self, target_values=None, **kwargs):
        super().__init__(target_values=target_values, **kwargs)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "contextlib"

    def fn_name(self):
        return "nullcontext"
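

# StreamContextVariable models entering a device stream context. A sketch of
# user code it handles (assuming a CUDA build):
#
#     s = torch.cuda.Stream()
#
#     @torch.compile
#     def f(x):
#         with torch.cuda.stream(s):
#             return x + 1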
class StreamContextVariable(ContextWrappingVariable):
    @staticmethod
    def create(tx, target_value, **kwargs):
        from .builder import wrap_fx_proxy_cls

        current_stream_method = get_interface_for_device(
            target_value.device
        ).current_stream
        current_stream = wrap_fx_proxy_cls(
            StreamVariable,
            tx,
            tx.output.create_proxy(
                "call_function",
                current_stream_method,
                (None,),
                {},
            ),
        )
        return StreamContextVariable(
            target_values=[target_value],
            initial_values=[current_stream],
            device=target_value.device,
            **kwargs,
        )

    def __init__(self, target_values, device, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.device = device
        self.set_stream = get_interface_for_device(self.device).set_stream
        self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id

    def enter(self, tx):
        # stream generated inside the traced function
        if self.target_values[0].as_proxy() is not None:
            tx.output.create_proxy(
                "call_function",
                self.set_stream,
                (self.target_values[0].as_proxy(),),
                {},
            )
        # stream passed from outside the traced function
        else:
            stream = self.target_values[0].value
            tx.output.create_proxy(
                "call_function",
                self.set_stream_id,
                (stream.stream_id, stream.device_index, stream.device_type),
                {},
            )
        self.set_stream(self.target_values[0].value)
        self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))

    def exit(self, tx, *args):
        tx.output.create_proxy(
            "call_function",
            self.set_stream,
            (self.initial_values[0].as_proxy(),),
            {},
        )
        self.state.cleanup_assert()


class StreamVariable(VariableTracker):
    def __init__(self, proxy, value, device, **kwargs):
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        assert (
            value.device.type == device.type
        ), "stream value is not equal to the passed device"
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.device = device

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert hasattr(self.value, name), f"no stream method found named {name}"
        assert name in [
            "wait_stream",
            "synchronize",
            "query",
            "record_event",
            "wait_event",
        ], f"unsupported stream method {name}"

        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            unimplemented(f"{self.device} stream method {name} unsupported")

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen):
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global.
        assert not self.source
        # Since we just proved that, reconstruction would be fine and sound for other
        # such structures, like lists and dicts, according to Dynamo's principles for
        # treating collections. However, streams are special in that we want to
        # preserve the identity of the stream as the same as in the graph.
        # Normally, we would do this via codegen for the proxy mapping to an output -
        # we cannot do this yet, as we do not yet have a plan for how we want to handle
        # the case where the stream is used as an input or an output. Pending design,
        # to unblock current work, we lift the stream into a global and then codegen
        # bytecode to load it from there.
        prefix = f"_stream_{self.device}"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)

        return [codegen.create_load_global(name, push_null=False, add=True)]


class EventVariable(VariableTracker):
    def __init__(self, proxy, value, **kwargs):
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record", "synchronize"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            unimplemented(f"event method {name} unsupported")

    def as_proxy(self):
        return self.proxy


class WithExitFunctionVariable(VariableTracker):
    def __init__(self, ctx: ContextWrappingVariable, target, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(ctx, ContextWrappingVariable)
        self.ctx = ctx
        self.target = target

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert not kwargs
        return self.ctx.exit(tx, *args)

    def reconstruct(self, codegen):
        # Note here we reconstruct the context manager rather than the
        # exit function. The handler generated by BlockStackEntry
        # will re-enter the context in the resume function.
        output = AttrSource(
            codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name()
        ).reconstruct(codegen)

        if codegen.tx.output.partial_convert:
            loads = [codegen.create_load_const(val) for val in self.ctx.target_values]
            output.extend(loads)
            output.extend(
                [
                    *create_call_function(len(loads), True),
                    create_instruction("SETUP_WITH", target=self.target),
                    create_instruction("POP_TOP"),
                ]
            )
        return output