copy_graphstate is called a ton, so this change makes copy_graphstate a lot faster and helps with https://github.com/pytorch/torchdynamo/issues/1803. Tag each graph node with a timestamp; when checkpointing, store the timestamp; when restoring, remove nodes older than the timestamp stored in the state. This essentially has the same behavior as the original implementation, it just doesn't copy the whole graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89232
Approved by: https://github.com/jansel

import inspect
import sys
import types
from typing import Dict, List

import torch._C

from .. import config, variables
from ..bytecode_transformation import create_instruction
from ..exc import unimplemented
from ..guards import Guard, GuardBuilder, GuardSource
from ..source import AttrSource
from ..utils import identity, proxy_args_kwargs
from .base import VariableTracker
from .functions import (
    UserFunctionVariable,
    UserMethodVariable,
    WrappedUserFunctionVariable,
    WrappedUserMethodVariable,
)


class SuperVariable(VariableTracker):
    def __init__(self, typevar, objvar=None, **kwargs):
        super(SuperVariable, self).__init__(**kwargs)
        self.typevar = typevar
        self.objvar = objvar

    def reconstruct(self, codegen):
        codegen(variables.BuiltinVariable(super))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            return [create_instruction("CALL_FUNCTION", 2)]
        else:
            return [create_instruction("CALL_FUNCTION", 1)]

    def const_getattr(self, tx, name):
        assert self.objvar, "1-arg super not implemented"
        search_type = self.typevar.as_python_constant()

        # We default to the python type of the object. However,
        # 1. If this is a `type`, then the original object represents the user
        #    defined type.
        # 2. If this is `torch._C._TensorMeta`, the original object is the user
        #    defined type of a custom tensor subclass.
        # TODO(future PR): figure out how to do this in a less hacky way
        type_to_use = self.objvar.python_type()
        if type_to_use is type or type_to_use is torch._C._TensorMeta:
            type_to_use = self.objvar.value

        # TODO(jansel): there is a small chance this could trigger user code, prevent that
        return getattr(super(search_type, type_to_use), name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        options = VariableTracker.propagate(
            self, args, kwargs.values(), self.objvar, self.typevar
        )
        inner_fn = self.const_getattr(self, name)
        if inner_fn is object.__init__:
            return LambdaVariable(identity, **options)
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(inner_fn, **options).call_function(
                tx, [self.objvar] + args, kwargs
            )
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, **options
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(f"non-function or method super: {inner_fn}")
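
# Hedged illustration (not part of the original file): SuperVariable is the tracker
# Dynamo builds when traced user code calls super(). For example, in user code like
#
#     class MyLinear(torch.nn.Linear):
#         def forward(self, x):
#             return super().forward(x).relu()
#
# the `super().forward` lookup goes through SuperVariable.call_method, which resolves
# the attribute on the MRO and then calls it with self.objvar (the traced `self`)
# prepended to the arguments.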


class UnknownVariable(VariableTracker):
    """
    It could be anything!
    """


class ClosureVariable(UnknownVariable):
    def __init__(self, name, **kwargs):
        super(ClosureVariable, self).__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        return [codegen.create_load_closure(self.name)]


class NewCellVariable(VariableTracker):
    def __init__(self, **kwargs):
        super(NewCellVariable, self).__init__(**kwargs)


class NewGlobalVariable(VariableTracker):
    def __init__(self, **kwargs):
        super(NewGlobalVariable, self).__init__(**kwargs)


class ContextWrappingVariable(VariableTracker):
    def __init__(self, target_values, initial_values=None, **kwargs):
        super(ContextWrappingVariable, self).__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        # This var doesn't contain any child vars and doesn't support clone()
        # properly, so don't populate this automatically
        self.recursively_contains = set()

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def exit(self, tx, *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def module_name(self):
        return "torch"

    def reconstruct(self, codegen, target_inst=None):
        """
        Generate the following Python bytecode, with a `torch._C._set_grad_enabled`
        call wrapped around the user instruction.

        Python 3.8
             0 LOAD_GLOBAL              0 (torch)
             2 LOAD_ATTR                1 (_C)
             4 LOAD_METHOD              2 (_set_grad_enabled)
             6 LOAD_CONST               1 (False)
             8 CALL_METHOD              1
            10 POP_TOP

            12 SETUP_FINALLY           10 (to 24)

            14 LOAD_GLOBAL              3 (user_inst)
            16 CALL_FUNCTION            0
            18 POP_TOP
            20 POP_BLOCK
            22 BEGIN_FINALLY

            24 LOAD_GLOBAL              0 (torch)
            26 LOAD_ATTR                1 (_C)
            28 LOAD_METHOD              2 (_set_grad_enabled)
            30 LOAD_CONST               2 (True)
            32 CALL_METHOD              1
            34 POP_TOP
            36 END_FINALLY
            38 LOAD_CONST               0 (None)
            40 RETURN_VALUE

        Instructions 0-10 and 24-34 call torch._C._set_grad_enabled(False/True).

        Python 3.9, 3.10
             0 LOAD_GLOBAL              0 (torch)
             2 LOAD_ATTR                1 (_C)
             4 LOAD_METHOD              2 (_set_grad_enabled)
             6 LOAD_CONST               1 (False)
             8 CALL_METHOD              1
            10 POP_TOP

            12 SETUP_FINALLY           22 (to 36)

            14 LOAD_GLOBAL              3 (user_inst)
            16 CALL_FUNCTION            0
            18 POP_TOP
            20 POP_BLOCK

            22 LOAD_GLOBAL              0 (torch)
            24 LOAD_ATTR                1 (_C)
            26 LOAD_METHOD              2 (_set_grad_enabled)
            28 LOAD_CONST               2 (True)
            30 CALL_METHOD              1
            32 POP_TOP

            34 JUMP_FORWARD            14 (to 50)

            36 LOAD_GLOBAL              0 (torch)
            38 LOAD_ATTR                1 (_C)
            40 LOAD_METHOD              2 (_set_grad_enabled)
            42 LOAD_CONST               2 (True)
            44 CALL_METHOD              1
            46 POP_TOP
            48 RERAISE

            50 LOAD_CONST               0 (None)
            52 RETURN_VALUE
        """
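        # Illustrative sketch (not part of the original file): the bytecode above is
        # roughly what resuming user code such as
        #
        #     with torch.no_grad():
        #         user_inst()
        #
        # requires: the prologue applies the target grad mode before the user
        # instruction, and the finally/except blocks restore the initial grad mode on
        # both the normal and the exception paths.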
        if self.target_values == self.initial_values:
            return ([], [])

        def set_context_insts(values):
            global_torch_source = codegen.tx.import_source("torch")
            attr_source = AttrSource(global_torch_source, self._func_name())
            load_set_context_enabling_insts = attr_source.reconstruct(codegen)

            loads = [codegen.create_load_const(val) for val in values]

            return [
                *load_set_context_enabling_insts,
                *loads,
                create_instruction("CALL_FUNCTION", len(values)),
                create_instruction("POP_TOP"),
            ]

        init_block = set_context_insts(self.target_values)
        finally_block = set_context_insts(self.initial_values)
        setup_final_inst = create_instruction("SETUP_FINALLY", target=finally_block[0])

        # Generate the prologue, which ends with SETUP_FINALLY
        prologue = init_block + [setup_final_inst]

        # Generate the epilogue - starts with 20 POP_BLOCK and ends at 34 POP_TOP
        if sys.version_info < (3, 9):
            epilogue = [
                create_instruction("POP_BLOCK"),
                codegen.create_begin_finally(),
                *finally_block,
                create_instruction("END_FINALLY"),
            ]
        else:
            except_block = set_context_insts(self.initial_values)
            epilogue = [
                create_instruction("POP_BLOCK"),
                *except_block,
                create_instruction("JUMP_FORWARD", target=target_inst),
                *finally_block,
                create_instruction("RERAISE"),
            ]

        return (prologue, epilogue)

    def _call_func(self, tx, initial_values):
        raise NotImplementedError("_call_func called on base")

    def _func_name(self):
        raise NotImplementedError("_func_name called on base")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert len(args) == 1
        assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))

        if isinstance(args[0], UserMethodVariable):
            return WrappedUserMethodVariable(args[0], self)

        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)
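
# Hedged note (not in the original file): ContextWrappingVariable.call_function covers
# the decorator-style use of these context managers. For example, traced user code like
#
#     @torch.no_grad()
#     def helper(x):
#         return x * 2
#
# calls the context-manager variable with the traced function as its single argument,
# which is then wrapped in WrappedUserFunctionVariable (or WrappedUserMethodVariable)
# so the context is applied around the function body when it is inlined.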


class GradModeVariable(ContextWrappingVariable):
    """represents torch.{no_grad,enable_grad,set_grad_enabled}()"""

    _guards_singleton = {Guard("", GuardSource.GLOBAL, GuardBuilder.GRAD_MODE)}

    @staticmethod
    def create(tx, target_value, **kwargs):
        var = GradModeVariable(
            target_values=[target_value],
            initial_values=[torch.is_grad_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super(GradModeVariable, self).__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.guards = self.guards | self._guards_singleton

    def enter(self, tx):
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        tx.output.create_node(
            "call_function", torch._C._set_grad_enabled, (value,), {}
        )
        torch._C._set_grad_enabled(value)

    def _func_name(self):
        return "_C._set_grad_enabled"

    def fn_name(self):
        if self.target_values[0]:
            return "enable_grad"
        else:
            return "no_grad"
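
# Hedged illustration (not part of the original file): GradModeVariable is created when
# traced code toggles grad mode, e.g.
#
#     with torch.no_grad():
#         y = x + 1
#
# create() records a torch._C._set_grad_enabled(False) node in the graph and also flips
# grad mode eagerly so the rest of tracing observes the new state, while the GRAD_MODE
# guard ties the compiled code to the ambient grad-mode setting.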


class AutocastModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(target_values, kwargs):
        values = target_values
        # Normalize the torch.autocast keyword arguments into positional values:
        #   device_type: str,
        #   dtype: Optional[_dtype] = None,
        #   enabled: bool = True,
        #   cache_enabled: Optional[bool] = None
        assert "device_type" in kwargs
        values.append(kwargs["device_type"])
        del kwargs["device_type"]

        if "dtype" in kwargs:
            values.append(kwargs["dtype"])
            del kwargs["dtype"]
        else:
            values.append(variables.ConstantVariable(None))

        if "enabled" in kwargs:
            values.append(kwargs["enabled"])
            del kwargs["enabled"]
        else:
            values.append(variables.ConstantVariable(True))

        if "cache_enabled" in kwargs:
            values.append(kwargs["cache_enabled"])
            del kwargs["cache_enabled"]
        else:
            values.append(variables.ConstantVariable(None))

        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super(AutocastModeVariable, self).__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = [val.as_python_constant() for val in target_values]
        self.mode = None

    def exit(self, tx, *args):
        tx.output.create_node(
            "call_function", exit_functional_autocast, (self.mode,), {}
        )

    def enter(self, tx):
        self.mode = tx.output.create_node(
            "call_function", enter_functional_autocast, (*self.target_values,), {}
        )

    def _func_name(self):
        return "torch.amp.autocast_mode.autocast"

    def fn_name(self):
        return "torch.amp.autocast_mode.autocast"


def enter_functional_autocast(*vals):
    mode = torch.amp.autocast(*vals)
    mode.__enter__()
    return mode


def exit_functional_autocast(mode):
    mode.__exit__(None, None, None)
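
# Hedged sketch (not in the original file): the two helpers above let the captured graph
# enter and exit autocast as plain function calls. Traced user code such as
#
#     with torch.autocast(device_type="cuda", dtype=torch.float16):
#         ...
#
# has its keyword arguments normalized by AutocastModeVariable.create into the positional
# order (device_type, dtype, enabled, cache_enabled), so the region above is recorded
# roughly as enter_functional_autocast("cuda", torch.float16, True, None) followed later
# by exit_functional_autocast(mode).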


class NullContextVariable(ContextWrappingVariable):
    """
    This class represents Python contextlib.nullcontext.
    It's used as a placeholder for other context managers that Dynamo doesn't
    support yet, e.g., torch.autograd.profiler.record_function.
    """

    def __init__(self, target_values=None, **kwargs):
        super(NullContextVariable, self).__init__(
            target_values=target_values, **kwargs
        )

    def enter(self, tx):
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def exit(self, tx, *args):
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def module_name(self):
        return "contextlib"

    def fn_name(self):
        return "nullcontext"


class WithExitFunctionVariable(VariableTracker):
    def __init__(self, ctx: VariableTracker, target, **kwargs):
        super(WithExitFunctionVariable, self).__init__(**kwargs)
        self.ctx = ctx
        self.target = target

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert not kwargs
        return self.ctx.exit(tx, *args)

    def reconstruct(self, codegen):
        # Note here we reconstruct the context manager rather than the
        # exit function. The handler generated by BlockStackEntry
        # will re-enter the context in the resume function.
        output = AttrSource(
            codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name()
        ).reconstruct(codegen)

        if codegen.tx.output.partial_convert:
            output.extend(
                [
                    create_instruction("CALL_FUNCTION", 0),
                    create_instruction("SETUP_WITH", target=self.target),
                    create_instruction("POP_TOP"),
                ]
            )
        return output
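
# Hedged note (not in the original file): WithExitFunctionVariable stands in for the
# bound __exit__ of a `with` statement entered inside the compiled region. When codegen
# only partially converts the frame, reconstruct() reloads the context manager (e.g.
# contextlib.nullcontext), calls it, and emits SETUP_WITH so the resume function
# re-enters an equivalent context before running the remainder of the `with` body.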


class InspectSignatureVariable(VariableTracker):
    """represents inspect.signature(...)"""

    @staticmethod
    def create(callable, **kwargs):
        if kwargs:
            unimplemented(f"inspect.signature with {kwargs}")
        return InspectSignatureVariable(callable)

    def __init__(self, inspected, **kwargs):
        super(InspectSignatureVariable, self).__init__(**kwargs)
        self.inspected = inspected
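
# Hedged illustration (not part of the original file): traced user code such as
#
#     sig = inspect.signature(fn)
#     if len(sig.parameters) == 1:
#         ...
#
# produces an InspectSignatureVariable; the `sig.parameters` attribute access becomes a
# GetAttrVariable, and its __len__ is answered as a constant in
# GetAttrVariable.call_method further below.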


class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""

    def __init__(self, fn_cls, **kwargs):
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def call_apply(self, tx, args, kwargs):
        requires_grad = False

        def visit(node):
            nonlocal requires_grad
            if isinstance(node, variables.TensorVariable):
                if node.requires_grad is not False:
                    requires_grad = True
            if isinstance(node, variables.NNModuleVariable):
                if node.is_training(tx):
                    requires_grad = True
            return node

        VariableTracker.apply(visit, (args, kwargs))

        if requires_grad and torch.is_grad_enabled():
            # TODO(jansel): handle this in training mode
            unimplemented("autograd.Function with requires_grad")

        args = [BlackHoleVariable()] + list(args)
        options = VariableTracker.propagate(self, args, kwargs.values())
        fn = self.fn_cls.forward
        if isinstance(fn, types.FunctionType):
            return variables.UserFunctionVariable(fn, **options).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__, variables.UserDefinedClassVariable(self.fn_cls), **options
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(
                f"non-function or method in subclass of torch.autograd.Function: {fn}"
            )

    def call_function(self, tx, args, kwargs):
        options = VariableTracker.propagate(self, args, kwargs.values())
        return AutogradFunctionVariable(self.fn_cls, **options)
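
# Hedged illustration (not part of the original file): for a user-defined subclass like
#
#     class MulTwo(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x):
#             return x * 2
#
# a traced MulTwo.apply(x) call is routed to call_apply above, which inlines only the
# forward() and substitutes a BlackHoleVariable for ctx, so ctx.save_for_backward and
# attribute writes are silently ignored during forward extraction.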


class BlackHoleVariable(VariableTracker):
    """An autograd.Function context that just ignores everything (for forward extraction)"""

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert name in ("__setattr__", "save_for_backward"), name
        return variables.ConstantVariable(
            None, **VariableTracker.propagate(self, args, kwargs.values())
        )


class LambdaVariable(VariableTracker):
    def __init__(self, fn, **kwargs):
        super(LambdaVariable, self).__init__(**kwargs)
        self.fn = fn

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.fn(*args, **kwargs).add_options(self)


class GetAttrVariable(VariableTracker):
    def __init__(self, obj, name, **kwargs):
        super(GetAttrVariable, self).__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name

    def __str__(self):
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    def as_proxy(self):
        return getattr(self.obj.as_proxy(), self.name)

    def const_getattr(self, tx, name):
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError()
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError()
        step2 = inspect.getattr_static(step1, self.name)
        if name not in step2.__dict__:
            raise NotImplementedError()
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen):
        codegen(self.obj)
        return codegen.create_load_attrs(self.name)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        # This variable is True when it corresponds to user code such as
        #
        #    super().__torch_function__(...)
        #
        # and the super().__torch_function__ attribute resolves
        # to torch.Tensor.__torch_function__.
        is_original_tensor_torch_function = (
            self.name == "__torch_function__"
            and isinstance(self.obj, SuperVariable)
            # for now, only support one level of inheritance
            and len(self.obj.objvar.value.__mro__) > 1
            and self.obj.objvar.value.__mro__[1] == torch.Tensor
        )
        if is_original_tensor_torch_function:
            # Instead of tracing inside torch.Tensor.__torch_function__,
            # record the `call_function` or `call_method` call into the graph.
            from . import TorchVariable

            original_torch_or_getattr_variable = args[0]
            new_args = args[2].items
            new_kwargs = args[3].items
            options = VariableTracker.propagate(self, new_args, new_kwargs.values())
            # Disable __torch_function__ here to prevent the clone of the
            # example tensor from going into the override.
            with torch._C.DisableTorchFunction():
                if isinstance(args[0], TorchVariable):
                    return wrap_fx_proxy(
                        tx=tx,
                        proxy=tx.output.create_proxy(
                            "call_function",
                            original_torch_or_getattr_variable.value,
                            *proxy_args_kwargs(new_args, new_kwargs),
                            current_tx=tx,
                        ),
                        **options,
                    )
                elif isinstance(args[0], GetAttrVariable):
                    return wrap_fx_proxy(
                        tx=tx,
                        proxy=tx.output.create_proxy(
                            "call_method",
                            original_torch_or_getattr_variable.name,
                            *proxy_args_kwargs(new_args, new_kwargs),
                            current_tx=tx,
                        ),
                        **options,
                    )
                else:
                    unimplemented(
                        f"GetAttrVariable.call_function original __torch_function__ {args}"
                    )

        if isinstance(self.obj, AutogradFunctionVariable) and self.name == "apply":
            return self.obj.call_apply(tx, args, kwargs).add_options(self)
        return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if (
            name == "__len__"
            and isinstance(self.obj, InspectSignatureVariable)
            and self.name == "parameters"
        ):
            return variables.ConstantVariable(
                self.obj.inspected.num_parameters(),
                **VariableTracker.propagate(self, self.obj, self.obj.inspected),
            )
        return super(GetAttrVariable, self).call_method(tx, name, args, kwargs)
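
# Hedged illustration (not part of the original file): the __torch_function__ branch in
# GetAttrVariable.call_function targets tensor subclasses whose override defers to the
# default implementation, e.g.
#
#     class MyTensor(torch.Tensor):
#         @classmethod
#         def __torch_function__(cls, func, types, args=(), kwargs=None):
#             return super().__torch_function__(func, types, args, kwargs)
#
# Rather than tracing into torch.Tensor.__torch_function__, Dynamo records func directly
# as a call_function (or call_method) node in the graph.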


class PythonModuleVariable(VariableTracker):
    def __init__(self, value: types.ModuleType, **kwargs):
        super(PythonModuleVariable, self).__init__(**kwargs)
        self.value = value

    def python_type(self):
        return types.ModuleType


class SkipFilesVariable(VariableTracker):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
            unimplemented(
                f"call {config.dynamo_import}.disable() wrapped function {self.value}"
            )
        else:
            try:
                path = inspect.getfile(self.value)
            except TypeError:
                path = f"Builtin {self.value.__name__}"
            unimplemented("call_function in skip_files " + path)


class TypingVariable(VariableTracker):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__" and len(args) == 1:
            return variables.ConstantVariable(
                self.value[args[0].as_python_constant()],
                **VariableTracker.propagate(self, args),
            )
        unimplemented("typing")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value
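
# Hedged illustration (not part of the original file): TypingVariable lets subscription
# of typing constructs fold to a constant at trace time. For example, user code like
#
#     hint = typing.List[int]
#
# is handled by call_method("__getitem__", ...) above and yields a ConstantVariable
# holding typing.List[int] instead of causing a graph break.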


class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*` for better error messages.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        unimplemented("numpy")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("numpy")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value