mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][copmile-time] Handle builtins first in LOAD_GLOBAL (#153458)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153458 Approved by: https://github.com/jansel
This commit is contained in:
parent
33a5179269
commit
c797f1285c
|
|
@ -116,7 +116,7 @@ from .utils import (
|
|||
proxy_args_kwargs,
|
||||
)
|
||||
from .variables.base import typestr, ValueMutationNew, VariableTracker
|
||||
from .variables.builder import FrameStateSizeEntry, wrap_fx_proxy
|
||||
from .variables.builder import FrameStateSizeEntry, VariableBuilder, wrap_fx_proxy
|
||||
from .variables.builtin import BuiltinVariable
|
||||
from .variables.constant import ConstantVariable
|
||||
from .variables.ctx_manager import (
|
||||
|
|
@ -1482,16 +1482,15 @@ class InstructionTranslatorBase(
|
|||
assert name in self.f_builtins
|
||||
self.exec_recorder.builtins[name] = self.f_builtins[name]
|
||||
|
||||
if name not in self.f_globals:
|
||||
return self.load_builtin(inst)
|
||||
|
||||
if name in self.symbolic_globals:
|
||||
variable = self.output.side_effects[self.symbolic_globals[name]]
|
||||
self.push(self.output.side_effects.load_global(variable, name))
|
||||
return
|
||||
|
||||
try:
|
||||
value = self.f_globals[name]
|
||||
except KeyError:
|
||||
return self.load_builtin(inst)
|
||||
|
||||
value = self.f_globals[name]
|
||||
self.push(VariableTracker.build(self, value, GlobalSource(name)))
|
||||
|
||||
@functools.cached_property
|
||||
|
|
@ -4031,9 +4030,8 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
) # type: ignore[assignment]
|
||||
else:
|
||||
fglobals_value = _import_module(module_name)
|
||||
fglobals_vt = VariableTracker.build(self, fglobals_value, module_source)
|
||||
# realize the VT because we are going to send this to side effects
|
||||
fglobals_vt = fglobals_vt.realize()
|
||||
# Dont use lazy vt because we will do a setattr afterwards
|
||||
fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
|
||||
global_source = AttrSource(module_source, name)
|
||||
else:
|
||||
globals_name = self.output.install_global_by_id(
|
||||
|
|
@ -4041,29 +4039,26 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
)
|
||||
globals_source = GlobalSource(globals_name)
|
||||
fglobals_value = self.f_globals # type: ignore[assignment]
|
||||
fglobals_vt = VariableTracker.build(self, fglobals_value, globals_source)
|
||||
# realize the VT because we are going to send this to side effects
|
||||
fglobals_vt = fglobals_vt.realize()
|
||||
# Dont use lazy vt because we will do a setattr afterwards
|
||||
fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
|
||||
global_source = DictGetItemSource(globals_source, name) # type: ignore[assignment]
|
||||
return fglobals_value, fglobals_vt, global_source
|
||||
|
||||
def _load_global(self, inst):
|
||||
name = inst.argval
|
||||
if name not in self.f_globals:
|
||||
return self.load_builtin(inst)
|
||||
|
||||
if self.output.global_scope is self.f_globals:
|
||||
# If the global scope matches that of the root frame, use handler in
|
||||
# root frame instruction translator, to enforce consistency.
|
||||
super()._load_global(inst)
|
||||
else:
|
||||
name = inst.argval
|
||||
|
||||
_, fglobals_vt, global_source = self.get_globals_source_and_value(name)
|
||||
if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name):
|
||||
self.push(self.output.side_effects.load_attr(fglobals_vt, name))
|
||||
else:
|
||||
try:
|
||||
value = self.f_globals[name]
|
||||
except KeyError:
|
||||
return self.load_builtin(inst)
|
||||
|
||||
value = self.f_globals[name]
|
||||
self.push(VariableTracker.build(self, value, global_source))
|
||||
|
||||
def STORE_GLOBAL(self, inst):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user