mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable some sensible flake8-simplify rules. Mainly wanted to enable the SIM101 and `yield from` SIM103 checks. @kit1980 since you wanted to be tagged on this CI check. Enabling this check also helped flag one logical bug, so it's definitely beneficial (also fixed in this PR).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97984
Approved by: https://github.com/ezyang
813 lines
27 KiB
Python
import collections
import functools
import inspect
import sys
import types
from typing import Dict, List

import torch._C
from torch._guards import Guard, GuardSource

from .. import variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource
from ..utils import check_constant_args, identity, proxy_args_kwargs
from .base import MutableLocal, VariableTracker
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
    WrappedUserFunctionVariable,
    WrappedUserMethodVariable,
)


class SuperVariable(VariableTracker):
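    """Represents a ``super(typevar, objvar)`` call encountered during tracing;
    attribute lookups are resolved against the method resolution order of the
    traced class (see const_getattr below)."""
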
    def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
        super().__init__(**kwargs)
        self.typevar = typevar
        self.objvar = objvar
        self.specialized = specialized  # directly get attr from self.typevar if true

    def reconstruct(self, codegen):
        codegen(variables.BuiltinVariable(super))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            return create_call_function(2, True)
        else:
            return create_call_function(1, True)

    def const_getattr(self, tx, name):
        assert self.objvar, "1-arg super not implemented"
        if self.specialized:
            return getattr(self.typevar.as_python_constant(), name)
        search_type = self.typevar.as_python_constant()

        # We default to the python type of the object. However, if this is
        # a `type` or subclass of `type`, then the original object represents
        # the user defined type.
        type_to_use = self.objvar.python_type()
        if issubclass(type_to_use, type):
            type_to_use = self.objvar.value

        # TODO(jansel): there is a small chance this could trigger user code, prevent that
        return getattr(super(search_type, type_to_use), name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        options = VariableTracker.propagate(
            self, args, kwargs.values(), self.objvar, self.typevar
        )
        inner_fn = self.const_getattr(tx, name)
        source = None if self.source is None else AttrSource(self.source, name)
        if inner_fn is object.__init__:
            return LambdaVariable(identity, **options)
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(
                inner_fn, source=source, **options
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source, **options
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(f"non-function or method super: {inner_fn}")


class UnknownVariable(VariableTracker):
    """
    It could be anything!
    """


class ComptimeVariable(VariableTracker):
    """
    This variable is special: it lets you execute arbitrary code at
    Dynamo compile time
    """
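
    # A sketch of the user-facing form this supports, assuming the public
    # ``torch._dynamo.comptime`` entry point (illustrative, not from this file):
    #
    #     from torch._dynamo.comptime import comptime
    #
    #     def f(x):
    #         comptime(lambda ctx: ctx.print_graph())
    #         return x + 1
    #
    # call_function below invokes the passed function with a ComptimeContext.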

    def reconstruct(self, codegen):
        raise NotImplementedError("comptime is special form")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        from ..comptime import comptime

        # To support the comptime.print_graph convenience accessors
        from .functions import UserFunctionVariable

        return UserFunctionVariable(
            getattr(comptime, name), source=AttrSource(self.source, name)
        )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from ..comptime import ComptimeContext

        # TODO: support an expression form as well

        assert not kwargs
        assert len(args) == 1
        fn = args[0]
        if isinstance(fn, UserFunctionVariable):
            fn.get_function()(ComptimeContext(tx))
        elif isinstance(fn, NestedUserFunctionVariable):
            # We have to manually bind the freevars ourselves
            code = fn.get_code()
            assert not fn.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {code.co_freevars}"
            )
            func = types.FunctionType(
                code,
                fn.f_globals,
                fn.fn_name.as_python_constant(),
                tuple(fn.defaults.items) if fn.defaults else None,
                # We could automatically promote free variables into
                # ComptimeVar but this is confusing if you access
                # a free variable that we actually DO have the runtime
                # value for
                # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
                tuple(),
            )
            func(ComptimeContext(tx))
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")

        return variables.ConstantVariable(None)


class ClosureVariable(UnknownVariable):
    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        return [codegen.create_load_closure(self.name)]


class NewCellVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class NewGlobalVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class ContextWrappingVariable(VariableTracker):
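    """Base class for traceable context managers: ``target_values`` are applied
    via ``_call_func`` on enter and ``initial_values`` are restored on exit."""
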
    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        # This var doesn't contain any child vars and doesn't support clone()
        # properly, so don't populate recursively_contains automatically.
        self.recursively_contains = set()

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def exit(self, tx, *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def reconstruct(self, codegen):
        attr_source = AttrSource(
            codegen.tx.import_source(self.module_name()), self.fn_name()
        )
        return attr_source.reconstruct(codegen)

    def module_name(self):
        raise NotImplementedError("module_name called on base")

    def fn_name(self):
        raise NotImplementedError("fn_name called on base")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert len(args) == 1
        if isinstance(args[0], NestedUserFunctionVariable):
            args[0] = UserFunctionVariable(args[0].get_function())
        assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))

        if isinstance(args[0], UserMethodVariable):
            return WrappedUserMethodVariable(args[0], self)

        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)


class GradModeVariable(ContextWrappingVariable):
    """represents torch.{no_grad,enable_grad,set_grad_enabled}()"""
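
    # e.g. user code such as (illustrative only):
    #
    #     with torch.no_grad():
    #         out = model(x)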

    _guards_singleton = {Guard("", GuardSource.GLOBAL, GuardBuilder.GRAD_MODE)}

    @staticmethod
    def create(tx, target_value, **kwargs):
        var = GradModeVariable(
            target_values=[target_value],
            initial_values=[torch.is_grad_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.guards = self.guards | self._guards_singleton

    def enter(self, tx):
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        tx.output.create_node(
            "call_function", torch._C._set_grad_enabled, (value,), {}
        )
        torch._C._set_grad_enabled(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "set_grad_enabled"


class AutocastModeVariable(ContextWrappingVariable):
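    # Models ``torch.autocast`` and its device-specific variants; e.g. user
    # code such as (illustrative only):
    #
    #     with torch.autocast(device_type="cuda", dtype=torch.float16):
    #         out = model(x)
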
    @staticmethod
    def create(target_values, kwargs):
        # Reference signature of torch.autocast:
        #   device_type: str,
        #   dtype: Optional[_dtype] = None,
        #   enabled: bool = True,
        #   cache_enabled: Optional[bool] = None
        bound_args = inspect.signature(torch.autocast).bind(*target_values, **kwargs)
        bound_args.apply_defaults()
        target_values = []
        kwargs.clear()

        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
            arg = bound_args.arguments[key]
            if isinstance(arg, VariableTracker):
                target_values.append(bound_args.arguments[key].as_python_constant())
            else:
                target_values.append(bound_args.arguments[key])

        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        mode = kwargs.pop("mode", None)
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values
        self.mode = mode

    def exit(self, tx, *args):
        self.mode = (
            exit_functional_autocast(self.mode[0]),
            tx.output.create_node(
                "call_function", exit_functional_autocast, (self.mode[1],), {}
            ),
        )

    def enter(self, tx):
        self.mode = (
            enter_functional_autocast(*self.target_values),
            tx.output.create_node(
                "call_function", enter_functional_autocast, (*self.target_values,), {}
            ),
        )

    def module_name(self):
        return "torch.amp.autocast_mode"

    def fn_name(self):
        return "autocast"


def enter_functional_autocast(*vals):
    mode = torch.amp.autocast(*vals)
    mode.__enter__()
    return mode


def exit_functional_autocast(mode):
    mode.__exit__(None, None, None)


class NullContextVariable(ContextWrappingVariable):
    """
    This class represents Python contextlib.nullcontext.
    It's used as a placeholder for other context managers that Dynamo doesn't
    support yet, e.g., torch.autograd.profiler.record_function.
    """

    def __init__(self, target_values=None, **kwargs):
        super().__init__(target_values=target_values, **kwargs)

    def enter(self, tx):
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def exit(self, tx, *args):
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def module_name(self):
        return "contextlib"

    def fn_name(self):
        return "nullcontext"


class CUDAStreamContextVariable(ContextWrappingVariable):
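    # Models ``torch.cuda.stream``; e.g. user code such as (illustrative only):
    #
    #     s = torch.cuda.Stream()
    #     with torch.cuda.stream(s):
    #         out = model(x)
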
    @staticmethod
    def create(tx, target_value, **kwargs):
        from .builder import wrap_fx_proxy_cls

        current_stream = wrap_fx_proxy_cls(
            CUDAStreamVariable,
            tx,
            tx.output.create_proxy(
                "call_function",
                torch.cuda.current_stream,
                (None,),
                {},
            ),
        )
        return CUDAStreamContextVariable(
            target_values=[target_value],
            initial_values=[current_stream],
            **kwargs,
        )

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx):
        # CUDA stream generated inside of traced function
        if self.target_values[0].as_proxy() is not None:
            tx.output.create_proxy(
                "call_function",
                torch.cuda.set_stream,
                (self.target_values[0].as_proxy(),),
                {},
            )
        # CUDA stream passed from outside of traced function
        else:
            stream = self.target_values[0].value
            tx.output.create_proxy(
                "call_function",
                torch._C._cuda_setStream,
                (stream.stream_id, stream.device_index, stream.device_type),
                {},
            )
        torch.cuda.set_stream(self.target_values[0].value)

    def exit(self, tx, *args):
        tx.output.create_proxy(
            "call_function",
            torch.cuda.set_stream,
            (self.initial_values[0].as_proxy(),),
            {},
        )
        torch.cuda.set_stream(self.initial_values[0].value)

    def module_name(self):
        return "torch.cuda"

    def fn_name(self):
        return "stream"


class CUDAStreamVariable(VariableTracker):
    def __init__(self, proxy, value, **kwargs):
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("cuda stream")

    def as_proxy(self):
        return self.proxy


class WithExitFunctionVariable(VariableTracker):
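    """Represents the ``__exit__`` function of a traced context manager as
    pushed on the stack by SETUP_WITH; calling it delegates to ``ctx.exit``."""
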
    def __init__(self, ctx: ContextWrappingVariable, target, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(ctx, ContextWrappingVariable)
        self.ctx = ctx
        self.target = target

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert not kwargs
        return self.ctx.exit(tx, *args)

    def reconstruct(self, codegen):
        # Note here we reconstruct the context manager rather than the
        # exit function. The handler generated by BlockStackEntry
        # will re-enter the context in the resume function.
        output = AttrSource(
            codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name()
        ).reconstruct(codegen)

        if codegen.tx.output.partial_convert:
            loads = [codegen.create_load_const(val) for val in self.ctx.target_values]
            output.extend(loads)
            output.extend(
                [
                    *create_call_function(len(loads), True),
                    create_instruction("SETUP_WITH", target=self.target),
                    create_instruction("POP_TOP"),
                ]
            )
        return output


class InspectSignatureVariable(VariableTracker):
    """represents inspect.signature(...)"""
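
    # e.g. user code along these lines (illustrative sketch, not from this file):
    #
    #     sig = inspect.signature(fn)
    #     n = len(sig.parameters)  # handled via GetAttrVariable.call_method below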

    @staticmethod
    def create(callable, **kwargs):
        if kwargs:
            unimplemented(f"inspect.signature with {kwargs}")
        return InspectSignatureVariable(callable)

    def __init__(self, inspected, **kwargs):
        super().__init__(**kwargs)
        self.inspected = inspected


class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""
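
    # e.g. user code like (standard torch.autograd.Function usage, shown for
    # orientation only):
    #
    #     class MyFn(torch.autograd.Function):
    #         @staticmethod
    #         def forward(ctx, x):
    #             ctx.save_for_backward(x)
    #             return x * 2
    #
    #     y = MyFn.apply(x)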

    def __init__(self, fn_cls, **kwargs):
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def call_apply(self, tx, args, kwargs):
        requires_grad = False

        def visit(node):
            nonlocal requires_grad
            if isinstance(node, variables.TensorVariable):
                if node.requires_grad is not False:
                    requires_grad = True
            if isinstance(node, variables.NNModuleVariable):
                if node.is_training(tx):
                    requires_grad = True
            return node

        VariableTracker.apply(visit, (args, kwargs))

        if requires_grad and torch.is_grad_enabled():
            # TODO(jansel): handle this in training mode
            unimplemented("autograd.Function with requires_grad")

        args = [BlackHoleVariable()] + list(args)
        options = VariableTracker.propagate(self, args, kwargs.values())
        options["source"] = AttrSource(AttrSource(self.source, "__class__"), "forward")
        fn = self.fn_cls.forward
        if isinstance(fn, types.FunctionType):
            return variables.UserFunctionVariable(fn, **options).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__, variables.UserDefinedClassVariable(self.fn_cls), **options
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(
                f"non-function or method in subclass of torch.autograd.Function: {fn}"
            )

    def call_function(self, tx, args, kwargs):
        options = VariableTracker.propagate(self, args, kwargs.values())
        return AutogradFunctionVariable(self.fn_cls, source=self.source, **options)


class BlackHoleVariable(VariableTracker):
    """An autograd.function context that just ignores everything (for forward extraction)"""

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert name in ("__setattr__", "save_for_backward"), name
        return variables.ConstantVariable(
            None, **VariableTracker.propagate(self, args, kwargs.values())
        )


class AutogradFunctionContextVariable(VariableTracker):
    """
    An autograd.function context used after a graph break in forward.
    Any method call on this context object will graph break.
    This differs from BlackHoleVariable, which is only used in inference mode.
    """

    pass


class LambdaVariable(VariableTracker):
    def __init__(self, fn, **kwargs):
        super().__init__(**kwargs)
        self.fn = fn

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.fn(*args, **kwargs).add_options(self)


class GetAttrVariable(VariableTracker):
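    """Represents the result of ``getattr(obj, name)`` on a traced variable:
    a deferred attribute access that may later be called or reconstructed."""
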
    def __init__(self, obj, name, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name

    def __str__(self):
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    @staticmethod
    def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
        return getattr(base_proxy, attr)

    def as_proxy(self):
        return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

    def const_getattr(self, tx, name):
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError()
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError()
        step2 = inspect.getattr_static(step1, self.name)
        if name not in step2.__dict__:
            raise NotImplementedError()
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen):
        codegen(self.obj)
        return codegen.create_load_attrs(self.name)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        # This variable is True when it corresponds to user code such as
        #
        #    super().__torch_function__(...)
        #
        # and the super().__torch_function__ attribute resolves
        # to torch.Tensor.__torch_function__.
        is_original_tensor_torch_function = (
            self.name == "__torch_function__"
            and isinstance(self.obj, SuperVariable)
            # for now, only support one level of inheritance
            and len(self.obj.objvar.value.__mro__) > 1
            and self.obj.objvar.value.__mro__[1] == torch.Tensor
        )
        if is_original_tensor_torch_function:
            # Instead of tracing inside torch.Tensor.__torch_function__,
            # record the `call_function` or `call_method` call into the graph.
            from . import TorchVariable

            original_torch_or_getattr_variable = args[0]
            new_args = args[2].items
            new_kwargs = args[3].items
            options = VariableTracker.propagate(self, new_args, new_kwargs.values())
            # Disable __torch_function__ here to prevent the clone of the
            # example tensor from going into the override.
            with torch._C.DisableTorchFunctionSubclass():
                if isinstance(args[0], TorchVariable):
                    return wrap_fx_proxy(
                        tx=tx,
                        proxy=tx.output.create_proxy(
                            "call_function",
                            original_torch_or_getattr_variable.value,
                            *proxy_args_kwargs(new_args, new_kwargs),
                        ),
                        **options,
                    )
                elif isinstance(args[0], GetAttrVariable):
                    return wrap_fx_proxy(
                        tx=tx,
                        proxy=tx.output.create_proxy(
                            "call_method",
                            original_torch_or_getattr_variable.name,
                            *proxy_args_kwargs(new_args, new_kwargs),
                        ),
                        **options,
                    )
                else:
                    unimplemented(
                        f"GetAttrVariable.call_function original __torch_function__ {args}"
                    )

        if isinstance(self.obj, AutogradFunctionVariable) and self.name == "apply":
            return self.obj.call_apply(tx, args, kwargs).add_options(self)
        # calling the parent class's non-classmethod from a child class
        # https://github.com/pytorch/pytorch/issues/90558
        elif (
            isinstance(self.obj, variables.UserDefinedClassVariable)
            and len(args) > 0
            and issubclass(args[0].python_type(), self.obj.value)
        ):
            return SuperVariable(self.obj, args[0], True).call_method(
                tx, self.name, args[1:], kwargs
            )
        return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if (
            name == "__len__"
            and isinstance(self.obj, InspectSignatureVariable)
            and self.name == "parameters"
        ):
            return variables.ConstantVariable(
                self.obj.inspected.num_parameters(),
                **VariableTracker.propagate(self, self.obj, self.obj.inspected),
            )
        return super().call_method(tx, name, args, kwargs)


class PythonModuleVariable(VariableTracker):
    def __init__(self, value: types.ModuleType, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def python_type(self):
        return types.ModuleType


class SkipFilesVariable(VariableTracker):
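    """Wraps a callable defined in a file on Dynamo's skip list; calling it
    graph-breaks except for a few allowlisted cases handled in call_function."""
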
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @staticmethod
    @functools.lru_cache(None)
    def fold_through_function_to_wrapper():
        return {
            collections.namedtuple: variables.UserDefinedClassVariable,
        }

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builtin import BuiltinVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
            unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
        # Allowlist calls to a few popular classes (e.g., collections.OrderedDict)
        # defined in skip files.
        elif self.value is collections.OrderedDict and (
            len(args) == 0
            or len(args) == 1
            and BuiltinVariable.is_supported_call_dict_arg(tx, args[0])
        ):
            return BuiltinVariable.call_dict_helper(
                tx,
                collections.OrderedDict,
                None if len(args) == 0 else args[0],
                **options,
            )
        # Fold through functions (e.g., collections.namedtuple) whose inputs
        # and outputs are all python constants
        elif (
            self.value in self.fold_through_function_to_wrapper().keys()
            and check_constant_args(args, kwargs)
        ):
            value = self.value(
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            return self.fold_through_function_to_wrapper().get(self.value)(
                value, mutable_local=MutableLocal(), **options
            )
        else:
            try:
                path = inspect.getfile(self.value)
            except TypeError:
                path = f"Builtin {self.value.__name__}"
            unimplemented(
                f"call_function {self.value.__qualname__} in skip_files {path}"
            )


class TypingVariable(VariableTracker):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__" and len(args) == 1:
            return variables.ConstantVariable(
                self.value[args[0].as_python_constant()],
                **VariableTracker.propagate(self, args),
            )
        unimplemented("typing")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value


class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*` for better error messages.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        unimplemented("numpy")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("numpy")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value


# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
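# (CPython 3.11's calling convention pushes a NULL before most callables; the
# PUSH_NULL instruction emitted in reconstruct() below mirrors that.)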
class NullVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __str__(self):
        return "NullVariable"

    def reconstruct(self, codegen):
        if sys.version_info < (3, 11):
            unimplemented("cannot reconstruct NullVariable in < Python 3.11")
        return [create_instruction("PUSH_NULL")]