pytorch/torch/_dynamo/variables/misc.py
Xuehai Pan 2e8ac5ea93 [dynamo] support dict.fromkeys() / OrderedDict.fromkeys() / defaultdict.fromkeys() (#115010)
Add support for `dict.fromkeys`, `OrderedDict.fromkeys`, and `defaultdict.fromkeys`.
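
As a rough illustration (not part of the original PR text), code along these lines can now be traced without a graph break:

    import collections
    import torch

    @torch.compile(backend="eager")
    def fn(x):
        d = dict.fromkeys(["a", "b"], 0)
        od = collections.OrderedDict.fromkeys(("x", "y"))
        return x + len(d) + len(od)

    fn(torch.ones(3))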

Fixes #114963

- #114963

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115010
Approved by: https://github.com/jansel
2023-12-04 01:49:59 +00:00

import collections
import dataclasses
import functools
import inspect
import itertools
import operator
import sys
import types
from typing import Dict, List
import torch._C
import torch._numpy as tnp
from .. import config, polyfill, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import (
check_constant_args,
identity,
is_tensor_base_attr_getter,
proxy_args_kwargs,
)
from .base import MutableLocal, VariableTracker
from .dicts import DefaultDictVariable
from .functions import (
NestedUserFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
)
from .user_defined import UserDefinedObjectVariable
class SuperVariable(VariableTracker):
def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
super().__init__(**kwargs)
self.typevar = typevar
self.objvar = objvar
self.specialized = specialized # directly get attr from self.typevar if true
def reconstruct(self, codegen):
codegen(variables.BuiltinVariable(super))
codegen(self.typevar)
if self.objvar is not None:
codegen(self.objvar)
return create_call_function(2, True)
else:
return create_call_function(1, True)
def _resolved_getattr_and_source(self, tx, name):
assert self.objvar, "1-arg super not implemented"
if self.specialized:
return getattr(self.typevar.as_python_constant(), name)
search_type = self.typevar.as_python_constant()
# We default to the python type of the object. However, if this is
# a `type` or subclass of `type`, then the original object represents
# the user defined type.
type_to_use = self.objvar.python_type()
type_to_use_source = (
TypeSource(self.objvar.source) if self.objvar.source else None
)
if issubclass(type_to_use, type):
type_to_use = self.objvar.value
type_to_use_source = self.objvar.source
source = None
if self.objvar.source is not None:
# Walk the mro tuple to find out the actual class where the
# attribute resides.
search_mro = type_to_use.__mro__
start_index = search_mro.index(search_type) + 1
for index in range(start_index, len(search_mro)):
if hasattr(search_mro[index], name):
# Equivalent of something like type(L['self']).__mro__[1].attr_name
source = AttrSource(
GetItemSource(AttrSource(type_to_use_source, "__mro__"), index),
name,
)
break
# TODO(jansel): there is a small chance this could trigger user code, prevent that
return getattr(super(search_type, type_to_use), name), source
def var_getattr(self, tx, name: str) -> "VariableTracker":
# Check if getattr is a constant. If not, delay the actual work by
# wrapping the result in GetAttrVariable. Mostly super is called with a
# method, so most of the work is delayed to call_function.
#
# We could have just implemented a const_getattr. However, super is
# special when it comes to finding sources. Compared to other VTs, super
# requires the attr name to walk the mro and find the actual source (and
# not just AttrSource).
value, source = self._resolved_getattr_and_source(self, name)
if not variables.ConstantVariable.is_literal(value):
return GetAttrVariable(self, name)
if source:
install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
return variables.ConstantVariable.create(value, source=source)
return variables.ConstantVariable.create(value)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
inner_fn, source = self._resolved_getattr_and_source(self, name)
if inner_fn is object.__init__:
return LambdaVariable(identity)
elif inner_fn is torch.nn.Module.__init__:
objvar = self.objvar
from ..side_effects import AttributeMutationNew
if (
isinstance(objvar, variables.UserDefinedObjectVariable)
and isinstance(objvar.mutable_local, AttributeMutationNew)
and not (args or kwargs)
):
tx.output.side_effects.store_attr(
objvar,
"__call_nn_module_init",
variables.ConstantVariable.create(True),
)
return variables.ConstantVariable.create(None)
else:
unimplemented("super() nn.Module.__init__")
elif isinstance(inner_fn, types.FunctionType):
return variables.UserFunctionVariable(
inner_fn, source=source
).call_function(tx, [self.objvar] + args, kwargs)
elif isinstance(inner_fn, types.MethodType):
return variables.UserMethodVariable(
inner_fn.__func__, self.objvar, source=source
).call_function(tx, args, kwargs)
elif (
inner_fn is collections.OrderedDict.__getitem__
and isinstance(self.objvar, variables.UserDefinedObjectVariable)
and self.objvar.source
and len(args) == 1
and len(kwargs) == 0
and args[0].is_python_constant()
):
from .builder import VariableBuilder
key = args[0].as_python_constant()
return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
collections.OrderedDict.__getitem__(self.objvar.value, key)
)
elif (
inner_fn in (collections.OrderedDict.__setitem__, object.__setattr__)
and isinstance(self.objvar, variables.CustomizedDictVariable)
and args
and variables.ConstDictVariable.is_valid_key(args[0])
and self.objvar.mutable_local
):
assert not kwargs and len(args) == 2
k = variables.ConstDictVariable.get_key(args[0])
newval = dict(self.objvar.items)
newval[k] = args[1]
return tx.replace_all(
self.objvar,
self.objvar.modifed(newval),
)
else:
unimplemented(f"non-function or method super: {inner_fn}")
class UnknownVariable(VariableTracker):
"""
It could be anything!
"""
class DelayGraphBreakVariable(UnknownVariable):
"""
    Used to insert a dummy variable on the stack so that the graph break happens at CALL_FUNCTION.
"""
class ComptimeVariable(VariableTracker):
"""
    This variable is special: it lets you execute arbitrary code at
    Dynamo compile time.
"""
def reconstruct(self, codegen):
raise NotImplementedError("comptime is special form")
def var_getattr(self, tx, name: str) -> "VariableTracker":
from ..comptime import comptime
# To support the comptime.print_graph convenience accessors
from .functions import UserFunctionVariable
return UserFunctionVariable(
getattr(comptime, name), source=AttrSource(self.source, name)
)
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from ..comptime import ComptimeContext
# TODO: support an expression form as well
assert not kwargs
assert len(args) == 1
fn = args[0]
if isinstance(fn, UserFunctionVariable):
fn.get_function()(ComptimeContext(tx))
elif isinstance(fn, NestedUserFunctionVariable):
# We have to manually bind the freevars ourselves
code = fn.get_code()
assert not fn.closure, (
"comptime function must not have free variables, "
f"but these variables were free: {code.co_freevars}"
)
func = types.FunctionType(
code,
fn.f_globals,
fn.fn_name.as_python_constant(),
tuple(fn.defaults.items) if fn.defaults else None,
# We could automatically promote free variables into
# ComptimeVar but this is confusing if you access
# a free variable that we actually DO have the runtime
# value for
# tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
tuple(),
)
func(ComptimeContext(tx))
else:
raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")
return variables.ConstantVariable.create(None)
class ClosureVariable(UnknownVariable):
def __init__(self, name, **kwargs):
super().__init__(**kwargs)
self.name = name
def reconstruct(self, codegen):
return [codegen.create_load_closure(self.name)]
# closure variable created by an inlined function
class InlinedClosureVariable(UnknownVariable):
def __init__(self, name, **kwargs):
super().__init__(**kwargs)
self.name = name
def reconstruct(self, codegen):
return [codegen.create_load_closure(self.name)]
class NewCellVariable(VariableTracker):
def __init__(self, **kwargs):
super().__init__(**kwargs)
class NewGlobalVariable(VariableTracker):
def __init__(self, **kwargs):
super().__init__(**kwargs)
class InspectSignatureVariable(VariableTracker):
"""represents inspect.signature(...)"""
@staticmethod
def create(callable, **kwargs):
if kwargs:
unimplemented(f"inspect.signature with {kwargs}")
return InspectSignatureVariable(callable)
def __init__(self, inspected, **kwargs):
super().__init__(**kwargs)
self.inspected = inspected
def produce_trampoline_autograd_fwd(fn_cls):
def trampoline_autograd_fwd(*args, **kwargs):
return fn_cls.forward(*args, **kwargs)
trampoline_autograd_fwd._origin = produce_trampoline_autograd_fwd
return trampoline_autograd_fwd
def produce_trampoline_autograd_bwd(fn_cls):
def trampoline_autograd_bwd(*args, **kwargs):
return fn_cls.backward(*args, **kwargs)
trampoline_autograd_bwd._origin = produce_trampoline_autograd_bwd
return trampoline_autograd_bwd
def produce_trampoline_autograd_apply(fn_cls):
def trampoline_autograd_apply(*args, **kwargs):
return fn_cls.apply(*args, **kwargs)
trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
return trampoline_autograd_apply
class AutogradFunctionVariable(VariableTracker):
"""represents a torch.autograd.Function subclass"""
def __init__(self, fn_cls, **kwargs):
super().__init__(**kwargs)
self.fn_cls = fn_cls
def call_apply(self, tx, args, kwargs):
requires_grad = False
def visit(node):
nonlocal requires_grad
if isinstance(node, variables.TensorVariable):
if node.requires_grad is not False:
requires_grad = True
if isinstance(node, variables.NNModuleVariable):
if node.is_training(tx):
requires_grad = True
return node
VariableTracker.apply(visit, (args, kwargs))
ctx = AutogradFunctionContextVariable.create(tx)
args = [ctx, *args]
if (
requires_grad
and torch.is_grad_enabled()
and config.capture_autograd_function
):
# Note - this is the same check used in autograd/function.py, except inverted.
# If we want to support functorch transforms here, we will need to enable this.
if (
self.fn_cls.setup_context
!= torch.autograd.function._SingleLevelFunction.setup_context
):
unimplemented(
"NYI - autograd.Function with custom setup_context method"
)
vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined]
if vjp_fn is not torch.autograd.Function.vjp:
unimplemented("NYI - User defind vjp")
jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined]
if jvp_fn is not torch.autograd.Function.jvp:
unimplemented("NYI - User defind jvp")
from .higher_order_ops import (
safe_or_raise_always_restore,
TorchHigherOrderOperatorVariable,
)
trampoline_autograd_apply = produce_trampoline_autograd_apply(self.fn_cls)
trampoline_autograd_fwd = produce_trampoline_autograd_fwd(self.fn_cls)
trampoline_autograd_bwd = produce_trampoline_autograd_bwd(self.fn_cls)
            # NOTE [On Tracing autograd.Function w/ grad]
            # The machinery here revolves around evaluating the soundness of an autograd.Function
            # in PyTorch. Tracing follows a well-defined, three-step strategy: trace forward, trace
            # backward, and, if both are sound, record an "apply" operation into the graph.
            # We trace forward and evaluate soundness. Soundness, in this context, means the absence
            # of side effects, no lifting of new arguments into the trace, a single tensor output,
            # and an input scope limited to contexts, tensors, and constants. If the forward trace
            # is sound, we install any guards accumulated from tracing; if not, we graph break.
            # We then trace backward and evaluate it for soundness in the same way, but with more
            # strictness: we enable a strict mode on the tx and reject certain ops while running
            # under it. If the backward trace is sound, we discard the trace by restoring the
            # checkpoint; otherwise, we raise.
            # If both the forward and backward traces are sound, we record the autograd function's
            # apply into the graph.
            # For tracing forward and backward we use UserFunctionVariable. While it does not
            # directly contribute to the soundness evaluation, it (together with a GlobalSource)
            # ensures we can produce valid guards and inline properly here. Inlining is required
            # for the soundness evaluation described above to work.
graph_checkpoint, checkpoint = tx.output.graph, tx.copy_graphstate()
module_source = AttrSource(
tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
)
fwd_bwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
tx.output,
parent=tx.output.current_tracer,
source_target="autograd.Function",
)
higher_order_autograd_fn = TorchHigherOrderOperatorVariable.make(
trampoline_autograd_fwd,
source=AttrSource(module_source, "forward"),
fwd_bwd_tracer=fwd_bwd_tracer,
)
speculated_fwd_result = higher_order_autograd_fn.call_function(
tx, args, kwargs
)
if isinstance(speculated_fwd_result, variables.TupleVariable):
bwd_args = [ctx, *speculated_fwd_result.items]
else:
bwd_args = [ctx, speculated_fwd_result]
safe_or_raise_always_restore(
tx,
graph_checkpoint,
checkpoint,
TorchHigherOrderOperatorVariable.make(
trampoline_autograd_bwd,
source=AttrSource(module_source, "backward"),
fwd_bwd_tracer=fwd_bwd_tracer,
),
bwd_args,
)
# If fwd and backward are sound, we want apply in the graph.
# And we don't want backwards for the obvious reasons.
args = args[1:]
return TorchHigherOrderOperatorVariable.make(
trampoline_autograd_apply,
fwd_bwd_tracer=None,
).call_function(tx, args, kwargs)
if self.source:
source = AttrSource(AttrSource(self.source, "__class__"), "forward")
else:
source = None
fn = self.fn_cls.forward
if isinstance(fn, types.FunctionType):
return variables.UserFunctionVariable(fn, source=source).call_function(
tx, args, kwargs
)
elif isinstance(fn, types.MethodType):
return variables.UserMethodVariable(
fn.__func__,
variables.UserDefinedClassVariable(self.fn_cls),
source=source,
).call_function(tx, args, kwargs)
else:
unimplemented(
f"non-function or method in subclass of torch.autograd.Function: {fn}"
)
def call_function(self, tx, args, kwargs):
return AutogradFunctionVariable(self.fn_cls)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
):
from ..allowed_functions import is_user_defined_allowed
from .builder import wrap_fx_proxy
if name == "apply":
if is_user_defined_allowed(self.fn_cls):
trampoline_autograd_apply = produce_trampoline_autograd_apply(
self.fn_cls
)
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
trampoline_autograd_apply,
*proxy_args_kwargs(args, kwargs),
),
)
else:
return self.call_apply(tx, args, kwargs)
elif name == "backward":
with tx.strict_translation_mode():
if isinstance(self.fn_cls.backward, types.FunctionType):
backward = UserFunctionVariable(self.fn_cls.backward)
elif isinstance(self.fn_cls.backward, types.MethodType):
backward = UserMethodVariable(
self.fn_cls.backward.__func__,
variables.UserDefinedClassVariable(self.fn_cls),
)
args = [backward.obj] + args
else:
unimplemented(
f"backward is a non-function or method: {self.fn_cls.backward}"
)
return tx.inline_call(tx, backward, args, kwargs)
elif name == "forward":
if isinstance(self.fn_cls.forward, types.FunctionType):
forward = UserFunctionVariable(self.fn_cls.forward)
elif isinstance(self.fn_cls.forward, types.MethodType):
forward = UserMethodVariable(
self.fn_cls.forward.__func__,
variables.UserDefinedClassVariable(self.fn_cls),
)
args = [forward.obj] + args
else:
unimplemented(
f"forward is a non-function or method: {self.fn_cls.forward}"
)
return tx.inline_call(tx, forward, args, kwargs)
else:
unimplemented(f"Unsupported method: {name}")
@dataclasses.dataclass
class SavedTensorBox:
tensors: List[VariableTracker] = dataclasses.field(default_factory=list)
class AutogradFunctionContextVariable(UserDefinedObjectVariable):
"""
Tracks an autograd.Function() context using mutation tracking in side_effects.py
"""
_nonvar_fields = {
"proxy",
"inference",
*UserDefinedObjectVariable._nonvar_fields,
}
def __init__(
self,
value,
value_type=None,
inference=False,
proxy=None,
saved_tensors=None,
**kwargs,
):
super().__init__(value=value, value_type=value_type, **kwargs)
self.inference = inference
self.proxy = proxy
self.saved_tensors = saved_tensors
@staticmethod
def create(tx):
proxy = tx.output.create_proxy(
"call_function", torch.autograd.function.FunctionCtx, tuple(), {}
)
out = tx.output.side_effects.track_object_new(
None,
torch.autograd.function.FunctionCtx,
functools.partial(
AutogradFunctionContextVariable,
inference=True,
proxy=proxy,
saved_tensors=SavedTensorBox(),
),
{},
)
proxy.node.meta["example_value"] = out.value
return out
def as_proxy(self):
if self.proxy is None:
unimplemented("proxy not set")
return self.proxy
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name != "save_for_backward":
unimplemented(f"autograd.Function context method: {name}")
if self.saved_tensors is None:
unimplemented(
"save_for_backward only supported on a newly constructed FunctionCtx"
)
if not self.inference:
assert self.source and not kwargs
tx.output.side_effects.track_save_for_backward(self, args)
for arg in args:
self.saved_tensors.tensors.append(arg)
return variables.ConstantVariable.create(None)
def var_getattr(self, tx, name):
if name == "save_for_backward":
return LambdaVariable(
lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
)
if name == "saved_tensors" and self.saved_tensors is not None:
return variables.TupleVariable(list(self.saved_tensors.tensors))
return super().var_getattr(tx, name)
class LambdaVariable(VariableTracker):
def __init__(self, fn, **kwargs):
super().__init__(**kwargs)
self.fn = fn
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
return self.fn(*args, **kwargs)
class GetAttrVariable(VariableTracker):
def __init__(self, obj, name, **kwargs):
super().__init__(**kwargs)
assert isinstance(obj, VariableTracker)
assert isinstance(name, str)
self.obj = obj
self.name = name
def __str__(self):
return f"{self.__class__.__name__}({self.obj}, {self.name})"
@staticmethod
def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
return getattr(base_proxy, attr)
def as_proxy(self):
return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)
def const_getattr(self, tx, name):
if not isinstance(self.obj, variables.NNModuleVariable):
raise NotImplementedError()
step1 = tx.output.get_submodule(self.obj.module_key)
if self.name not in step1.__dict__:
raise NotImplementedError()
step2 = inspect.getattr_static(step1, self.name)
if name not in step2.__dict__:
raise NotImplementedError()
return inspect.getattr_static(step2, name)
def reconstruct(self, codegen):
codegen(self.obj)
return codegen.create_load_attrs(self.name)
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
return self.obj.call_method(tx, self.name, args, kwargs)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if (
name == "__len__"
and isinstance(self.obj, InspectSignatureVariable)
and self.name == "parameters"
):
return variables.ConstantVariable.create(
self.obj.inspected.num_parameters(),
)
return super().call_method(tx, name, args, kwargs)
class MethodWrapperVariable(VariableTracker):
def __init__(self, method_wrapper, **kwargs):
super().__init__(**kwargs)
self.method_wrapper = method_wrapper
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
args[0], variables.TensorVariable
):
assert len(args) == 1 and len(kwargs) == 0
return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)
super().call_function(tx, args, kwargs)
def is_python_constant(self):
return True
def as_python_constant(self):
return self.method_wrapper
class GetSetDescriptorVariable(VariableTracker):
def __init__(self, desc, **kwargs):
super().__init__(**kwargs)
self.desc = desc
def var_getattr(self, tx, name):
if name == "__get__" and self.source:
from .builder import VariableBuilder
return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
self.desc.__get__
)
else:
return super().var_getattr(tx, name)
def is_python_constant(self):
return True
def as_python_constant(self):
return self.desc
class PythonModuleVariable(VariableTracker):
def __init__(self, value: types.ModuleType, **kwargs):
super().__init__(**kwargs)
self.value = value
def python_type(self):
return types.ModuleType
class SkipFilesVariable(VariableTracker):
def __init__(self, value, reason=None, **kwargs):
super().__init__(**kwargs)
self.value = value
self.reason = reason
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value
@classmethod
def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
return cls(
value,
source=source,
)
@staticmethod
@functools.lru_cache(None)
def fold_through_function_to_wrapper():
return {
collections.namedtuple: variables.UserDefinedClassVariable,
}
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .builtin import BuiltinVariable
if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
        # Allowlist calls to a few popular classes (e.g., collections.OrderedDict) in skip files.
elif self.value is collections.OrderedDict:
return BuiltinVariable.call_custom_dict(
tx, collections.OrderedDict, *args, **kwargs
)
elif (
self.value is collections.defaultdict
            and len(args) == 1
and DefaultDictVariable.is_supported_arg(args[0])
):
return DefaultDictVariable(
{},
collections.defaultdict,
args[0],
mutable_local=MutableLocal(),
)
        # Fold through functions (e.g., collections.namedtuple) whose
        # inputs and outputs are all Python constants.
elif (
self.value in self.fold_through_function_to_wrapper().keys()
and check_constant_args(args, kwargs)
):
value = self.value(
*[x.as_python_constant() for x in args],
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
return self.fold_through_function_to_wrapper().get(self.value)(
value, mutable_local=MutableLocal()
)
elif (
self.value is itertools.product
and not kwargs
and all(arg.has_unpack_var_sequence(tx) for arg in args)
):
seqs = [arg.unpack_var_sequence(tx) for arg in args]
items = []
for item in itertools.product(*seqs):
items.append(variables.TupleVariable(list(item)))
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
elif (
self.value is itertools.chain
and not kwargs
and all(arg.has_unpack_var_sequence(tx) for arg in args)
):
seqs = [arg.unpack_var_sequence(tx) for arg in args]
items = []
for item in itertools.chain(*seqs):
items.append(item)
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
elif self.value is itertools.accumulate:
from .builtin import BuiltinVariable
if any(key not in ["initial", "func"] for key in kwargs.keys()):
unimplemented(
"Unsupported kwargs for itertools.accumulate: "
f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}"
)
acc = kwargs.get("initial")
if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx):
seq = args[0].unpack_var_sequence(tx)
if "func" in kwargs and len(args) == 1:
func = kwargs["func"].call_function
elif len(args) == 2:
func = args[1].call_function
elif len(args) == 1:
# Default to operator.add
func = BuiltinVariable(operator.add).call_function
else:
unimplemented(
"itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
)
else:
unimplemented("Unsupported arguments for itertools.accumulate")
items = []
if acc is not None:
items.append(acc)
for item in seq:
if acc is None:
acc = item
else:
try:
acc = func(tx, [acc, item], {})
except Exception:
raise unimplemented( # noqa: TRY200
f"Unexpected failure in invoking function during accumulate. Failed running func {func}({item}{acc})"
)
items.append(acc)
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
elif (
self.value is itertools.combinations
and not kwargs
and len(args) == 2
and args[0].has_unpack_var_sequence(tx)
and args[1].is_python_constant()
):
iterable = args[0].unpack_var_sequence(tx)
r = args[1].as_python_constant()
items = []
for item in itertools.combinations(iterable, r):
items.append(variables.TupleVariable(list(item)))
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
elif self.value is itertools.groupby:
if any(kw != "key" for kw in kwargs.keys()):
unimplemented(
"Unsupported kwargs for itertools.groupby: "
f"{','.join(set(kwargs.keys()) - {'key'})}"
)
def retrieve_const_key(key):
if isinstance(key, variables.SymNodeVariable):
return key.evaluate_expr()
elif isinstance(key, variables.ConstantVariable):
return key.as_python_constant()
else:
raise unimplemented(
"Unsupported key type for itertools.groupby: " + str(type(key))
)
if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
seq = args[0].unpack_var_sequence(tx)
keyfunc = (
(
lambda x: (
retrieve_const_key(
kwargs.get("key").call_function(tx, [x], {})
)
)
)
if "key" in kwargs
else None
)
else:
unimplemented("Unsupported arguments for itertools.groupby")
result = []
try:
for k, v in itertools.groupby(seq, key=keyfunc):
result.append(
variables.TupleVariable(
[
variables.ConstantVariable.create(k)
if variables.ConstantVariable.is_literal(k)
else k,
variables.ListIteratorVariable(
list(v), mutable_local=MutableLocal()
),
],
mutable_local=MutableLocal(),
)
)
except Exception:
raise unimplemented( # noqa: TRY200
"Unexpected failure when calling itertools.groupby"
)
return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
elif (
self.value is functools.wraps
and not kwargs
and len(args) == 1
and (
args[0].source is not None or args[0].can_reconstruct(tx.output.root_tx)
)
):
def wraps(fn):
if isinstance(fn, variables.NestedUserFunctionVariable):
if args[0].source:
reconstructible = args[0].source
else:
reconstructible = args[0]
return fn.clone(wrapped_reconstructible=reconstructible)
unimplemented(f"functools.wraps({fn})")
return variables.LambdaVariable(wraps)
elif self.value is collections.deque and not kwargs:
if len(args) == 0:
items = []
elif len(args) == 1 and args[0].has_unpack_var_sequence(tx):
items = args[0].unpack_var_sequence(tx)
else:
unimplemented("deque() with more than 1 arg not supported")
return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
elif self.value is functools.partial:
if not args:
unimplemented("functools.partial malformed")
# The first arg, a callable (the ctor below will assert on types)
fn = args[0]
rest_args = args[1:]
# guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the
# args and keywords
return variables.functions.FunctoolsPartialVariable(
fn, args=rest_args, keywords=kwargs
)
elif self.value is itertools.repeat:
if len(args) < 2:
return variables.RepeatIteratorVariable(
*args, mutable_local=MutableLocal()
)
from .builder import SourcelessBuilder
return tx.inline_user_function_return(
SourcelessBuilder()(tx, polyfill.repeat), args, kwargs
)
elif self.value is itertools.count:
return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
elif self.value is itertools.cycle:
return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
else:
try:
path = inspect.getfile(self.value)
except TypeError:
path = f"Builtin {self.value.__name__}"
msg = f"'skip function {self.value.__qualname__} in file {path}'"
msg += f"', {self.reason}'" if self.reason else ""
unimplemented(msg)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if (
self.value in {collections.OrderedDict, collections.defaultdict}
and name == "fromkeys"
):
from .builtin import BuiltinVariable
return BuiltinVariable.call_custom_dict_fromkeys(
tx, self.value, *args, **kwargs
)
return super().call_method(tx, name, args, kwargs)
class TypingVariable(VariableTracker):
def __init__(self, value, **kwargs):
super().__init__(**kwargs)
self.value = value
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__getitem__" and len(args) == 1:
return variables.ConstantVariable.create(
self.value[args[0].as_python_constant()],
)
unimplemented("typing")
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value
@functools.lru_cache(maxsize=1)
def get_np_to_tnp_map():
from ..utils import NP_TO_TNP_MODULE
np_fn_to_tnp_fn = {}
for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
for fn_name, tnp_fn in tnp_mod.__dict__.items():
if callable(tnp_fn):
                # Some internal details leak from tnp
                # which are not part of the numpy API.
if np_fn := getattr(np_mod, fn_name, None):
np_fn_to_tnp_fn[np_fn] = tnp_fn
return np_fn_to_tnp_fn
class NumpyVariable(VariableTracker):
"""
    Wrapper around `numpy.*`. Currently able to trace a small subset of numpy functions as well as numpy dtypes.
"""
def __init__(self, value, **kwargs):
super().__init__(**kwargs)
self.value = value
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if not config.trace_numpy:
unimplemented(f"numpy.{self.value}()")
from ..utils import numpy_to_tensor_wrapper
from .tensor import NumpyNdarrayVariable
        # Look up the function name in tnp. Things like np.dtype(float) are not supported yet.
if self.value.__name__ == "dtype":
unimplemented(
f"numpy dtype function is not supported yet. Got type {type(self.value)}."
)
else: # We are dealing with a callable.
func = get_np_to_tnp_map().get(self.value)
if func is None:
unimplemented(
f"Can't find numpy function {self.value} in torch._numpy. "
" Please file an issue to request support for this function."
)
if (
func.__module__ == "torch._numpy.random"
and config.use_numpy_random_stream
):
msg = f"delegate '{func.__qualname__}' to NumPy itself via "
msg += f"confg.use_numpy_random_stream={config.use_numpy_random_stream}"
unimplemented(msg)
            # TODO(larryliu0820): currently assuming all numpy.* functions return an ndarray that can be
            # wrapped by NumpyNdarrayVariable, which is wrong!
proxy = tx.output.create_proxy(
"call_function",
numpy_to_tensor_wrapper(func),
*proxy_args_kwargs(args, kwargs),
)
return NumpyNdarrayVariable.create(tx, proxy)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
unimplemented("numpy")
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value
def as_proxy(self):
        # This handles numpy dtype attributes such as np.float32. TODO(larryliu0820): we should split NumpyVariable
# into NumpyVariable for instances/objects and NumpyVariable for types.
if config.trace_numpy and isinstance(self.value, type):
# retrieve attribute str. E.g., "float32" if given np.float32
attr = self.value.__name__
# get tnp equivalent
tnp_dtype = tnp.dtype(attr)
# returning a string here because we are assuming all `dtype` kwargs for numpy
# functions can take an equivalent string and the behavior of the function would
# be the same as taking a numpy dtype.
return tnp_dtype.name
return super().as_proxy()
# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
class NullVariable(VariableTracker):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __str__(self):
return "NullVariable"
def reconstruct(self, codegen):
if sys.version_info < (3, 11):
unimplemented("cannot reconstruct NullVariable in < Python 3.11")
return [create_instruction("PUSH_NULL")]
class DeletedVariable(VariableTracker):
"""Marker used to implement delattr()"""