import collections
import dataclasses
import functools
import inspect
import itertools
import operator
import sys
import types
from typing import Dict, List

import torch._C
import torch._numpy as tnp
from .. import config, polyfill, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import (
    check_constant_args,
    identity,
    is_tensor_base_attr_getter,
    proxy_args_kwargs,
)
from .base import MutableLocal, VariableTracker
from .dicts import DefaultDictVariable
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
from .user_defined import UserDefinedObjectVariable


class SuperVariable(VariableTracker):
    """Tracks a ``super(typevar, objvar)`` expression."""

    def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
        super().__init__(**kwargs)
        self.typevar = typevar
        self.objvar = objvar
        self.specialized = specialized  # directly get attr from self.typevar if true

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds the ``super(...)`` call (1 or 2 args)."""
        codegen(variables.BuiltinVariable(super))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            return create_call_function(2, True)
        else:
            return create_call_function(1, True)

    def _resolved_getattr_and_source(self, tx, name):
        """Resolve attribute ``name`` through this super object.

        Returns a ``(value, source)`` pair. ``source`` is ``None`` when the
        lookup is specialized or when ``objvar`` carries no source.
        """
        assert self.objvar, "1-arg super not implemented"
        if self.specialized:
            # Callers always unpack two values, so return a pair here as well.
            return getattr(self.typevar.as_python_constant(), name), None
        search_type = self.typevar.as_python_constant()

        # We default to the python type of the object. However, if this is
        # a `type` or subclass of `type`, then the original object represents
        # the user defined type.
        type_to_use = self.objvar.python_type()
        type_to_use_source = (
            TypeSource(self.objvar.source) if self.objvar.source else None
        )
        if issubclass(type_to_use, type):
            type_to_use = self.objvar.value
            type_to_use_source = self.objvar.source

        source = None
        if self.objvar.source is not None:
            # Walk the mro tuple to find out the actual class where the
            # attribute resides.
            search_mro = type_to_use.__mro__
            start_index = search_mro.index(search_type) + 1
            for index in range(start_index, len(search_mro)):
                if hasattr(search_mro[index], name):
                    # Equivalent of something like type(L['self']).__mro__[1].attr_name
                    source = AttrSource(
                        GetItemSource(AttrSource(type_to_use_source, "__mro__"), index),
                        name,
                    )
                    break

        # TODO(jansel): there is a small chance this could trigger user code, prevent that
        return getattr(super(search_type, type_to_use), name), source

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        # Check if getattr is a constant. If not, delay the actual work by
        # wrapping the result in GetAttrVariable. Mostly super is called with a
        # method, so most of the work is delayed to call_function.
        #
        # We could have just implemented a const_getattr. However, super is
        # special when it comes to finding sources. Compared to other VTs, super
        # requires the attr name to walk the mro and find the actual source (and
        # not just AttrSource).
        value, source = self._resolved_getattr_and_source(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            return GetAttrVariable(self, name)
        if source:
            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
            return variables.ConstantVariable.create(value, source=source)
        return variables.ConstantVariable.create(value)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Dispatch ``super().name(*args, **kwargs)`` based on the resolved attribute."""
        inner_fn, source = self._resolved_getattr_and_source(tx, name)

        if inner_fn is object.__init__:
            return LambdaVariable(identity)
        elif inner_fn is torch.nn.Module.__init__:
            objvar = self.objvar
            from ..side_effects import AttributeMutationNew

            if (
                isinstance(objvar, variables.UserDefinedObjectVariable)
                and isinstance(objvar.mutable_local, AttributeMutationNew)
                and not (args or kwargs)
            ):
                tx.output.side_effects.store_attr(
                    objvar,
                    "__call_nn_module_init",
                    variables.ConstantVariable.create(True),
                )
                return variables.ConstantVariable.create(None)
            else:
                unimplemented("super() nn.Module.__init__")
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(
                inner_fn, source=source
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source
            ).call_function(tx, args, kwargs)
        elif (
            inner_fn is collections.OrderedDict.__getitem__
            and isinstance(self.objvar, variables.UserDefinedObjectVariable)
            and self.objvar.source
            and len(args) == 1
            and len(kwargs) == 0
            and args[0].is_python_constant()
        ):
            from .builder import VariableBuilder

            key = args[0].as_python_constant()
            return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
                collections.OrderedDict.__getitem__(self.objvar.value, key)
            )
        elif (
            inner_fn in (collections.OrderedDict.__setitem__, object.__setattr__)
            and isinstance(self.objvar, variables.CustomizedDictVariable)
            and args
            and variables.ConstDictVariable.is_valid_key(args[0])
            and self.objvar.mutable_local
        ):
            assert not kwargs and len(args) == 2
            k = variables.ConstDictVariable.get_key(args[0])
            tx.output.side_effects.mutation(self)
            self.objvar.items[k] = args[1]
            return variables.ConstantVariable.create(None)
        else:
            unimplemented(f"non-function or method super: {inner_fn}")


class UnknownVariable(VariableTracker):
    """
    It could be anything!
    """


class DelayGraphBreakVariable(UnknownVariable):
    """
    Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION.
    """


class ComptimeVariable(VariableTracker):
    """
    This variable is special, it lets you execute arbitrary code at
    Dynamo compile time
    """

    def reconstruct(self, codegen):
        raise NotImplementedError("comptime is special form")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        from ..comptime import comptime

        # To support the comptime.print_graph convenience accessors
        from .functions import UserFunctionVariable

        return UserFunctionVariable(
            getattr(comptime, name), source=AttrSource(self.source, name)
        )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Run the single function argument immediately with a ``ComptimeContext``."""
        from ..comptime import ComptimeContext

        # TODO: support an expression form as well
        assert not kwargs
        assert len(args) == 1
        fn = args[0]
        if isinstance(fn, UserFunctionVariable):
            fn.get_function()(ComptimeContext(tx))
        elif isinstance(fn, NestedUserFunctionVariable):
            # We have to manually bind the freevars ourselves
            code = fn.get_code()
            assert not fn.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {code.co_freevars}"
            )
            func = types.FunctionType(
                code,
                fn.f_globals,
                fn.fn_name.as_python_constant(),
                tuple(fn.defaults.items) if fn.defaults else None,
                # We could automatically promote free variables into
                # ComptimeVar but this is confusing if you access
                # a free variable that we actually DO have the runtime
                # value for
                # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
                tuple(),
            )
            func(ComptimeContext(tx))
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")

        return variables.ConstantVariable.create(None)


class ClosureVariable(UnknownVariable):
    """Named closure variable; reconstructed via a load-closure instruction."""

    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        return [codegen.create_load_closure(self.name)]


# closure variable created by an inlined function
class InlinedClosureVariable(UnknownVariable):
    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        return [codegen.create_load_closure(self.name)]


class NewCellVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class NewGlobalVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class InspectSignatureVariable(VariableTracker):
    """represents inspect.signature(...)"""

    @staticmethod
    def create(callable, **kwargs):
        if kwargs:
            unimplemented(f"inspect.signature with {kwargs}")
        return InspectSignatureVariable(callable)

    def __init__(self, inspected: VariableTracker, **kwargs):
        super().__init__(**kwargs)
        self.inspected = inspected

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        if name == "parameters":
            # One placeholder InspectParameterVariable per parameter name.
            return variables.ConstDictVariable(
                {
                    param_name: InspectParameterVariable()
                    for param_name in self.inspected.inspect_parameter_names()
                },
                user_cls=dict,
            )
        return super().var_getattr(tx, name)


class InspectParameterVariable(VariableTracker):
    """This is not implemented, if used will graph break."""

    pass


def produce_trampoline_autograd_fwd(fn_cls):
    """Build a plain function that forwards to ``fn_cls.forward``."""

    def trampoline_autograd_fwd(*args, **kwargs):
        return fn_cls.forward(*args, **kwargs)

    trampoline_autograd_fwd._origin = produce_trampoline_autograd_fwd
    return trampoline_autograd_fwd


def produce_trampoline_autograd_bwd(fn_cls):
    """Build a plain function that forwards to ``fn_cls.backward``."""

    def trampoline_autograd_bwd(*args, **kwargs):
        return fn_cls.backward(*args, **kwargs)

    trampoline_autograd_bwd._origin = produce_trampoline_autograd_bwd
    return trampoline_autograd_bwd


def produce_trampoline_autograd_apply(fn_cls):
    """Build a plain function that forwards to ``fn_cls.apply``."""

    def trampoline_autograd_apply(*args, **kwargs):
        return fn_cls.apply(*args, **kwargs)

    trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
    return trampoline_autograd_apply


class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""

    def __init__(self, fn_cls, **kwargs):
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def call_apply(self, tx, args, kwargs):
        """Trace ``fn_cls.apply(*args, **kwargs)``.

        When gradients may be required (and capture is enabled in config), the
        forward and backward are speculated through higher order operator
        variables and ``apply`` is then recorded; otherwise ``forward`` is
        called directly with a freshly created context prepended to ``args``.
        """
        requires_grad = False

        def visit(node):
            nonlocal requires_grad
            if isinstance(node, variables.TensorVariable):
                if node.requires_grad is not False:
                    requires_grad = True
            if isinstance(node, variables.NNModuleVariable):
                if node.is_training(tx):
                    requires_grad = True
            return node

        VariableTracker.apply(visit, (args, kwargs))

        ctx = AutogradFunctionContextVariable.create(tx)
        args = [ctx, *args]

        if (
            requires_grad
            and torch.is_grad_enabled()
            and config.capture_autograd_function
        ):
            # Note - this is the same check used in autograd/function.py, except inverted.
            # If we want to support functorch transforms here, we will need to enable this.
            if (
                self.fn_cls.setup_context
                != torch.autograd.function._SingleLevelFunction.setup_context
            ):
                unimplemented(
                    "NYI - autograd.Function with custom setup_context method"
                )

            vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
            if vjp_fn is not torch.autograd.Function.vjp:
                unimplemented("NYI - User defind vjp")

            jvp_fn = self.fn_cls.jvp  # type: ignore[attr-defined]
            if jvp_fn is not torch.autograd.Function.jvp:
                unimplemented("NYI - User defind jvp")

            from .higher_order_ops import TorchHigherOrderOperatorVariable

            trampoline_autograd_apply = produce_trampoline_autograd_apply(self.fn_cls)
            trampoline_autograd_fwd = produce_trampoline_autograd_fwd(self.fn_cls)
            trampoline_autograd_bwd = produce_trampoline_autograd_bwd(self.fn_cls)

            # NOTE [On Tracing autograd.Function w/ grad]
            # The complex system described here revolves around the soundness evaluation of an autograd.Function in
            # PyTorch. The system follows a well-defined strategy for tracing, which involves three key steps: tracing
            # forward, tracing backward, and if both are sound the potential recording of an "apply" operation into the
            # graph. We trace forward, and evaluate soundness. Soundness, in this context, refers to the absence of side
            # effects, the avoidance of lifting new arguments into the trace, the production of a single tensor output,
            # and a limited input scope confined to contexts, tensors, and constants. If the forward trace is sound,
            # we install any guards accumulated from tracing. If not, we graph break. We trace backward, and evaluate
            # for soundness, same as forward, except with more strictness. We enable a strict mode on the tx, and
            # reject certain ops when running under this strict mode. If both the forward and backward traces are sound,
            # we write the autograd function's apply into the graph.

            # For tracing forward and backward, we use UserFunctionVariable. Although it does not directly contribute
            # to soundness evaluation, it plus a GlobalSource makes sure we can produce valid guards,
            # and that we can inline properly here. Inlining is required in order to be able to ensure that the
            # soundness evaluation works as described above.

            module_source = AttrSource(
                tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
            )
            fwd_bwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
                tx.output,
                parent=tx.output.current_tracer,
                source_target="autograd.Function",
            )
            higher_order_autograd_fn = TorchHigherOrderOperatorVariable.make(
                trampoline_autograd_fwd,
                source=AttrSource(module_source, "forward"),
                fwd_bwd_tracer=fwd_bwd_tracer,
            )
            speculated_fwd_result = higher_order_autograd_fn.call_function(
                tx, args, kwargs
            )

            if isinstance(speculated_fwd_result, variables.TupleVariable):
                bwd_args = [ctx, *speculated_fwd_result.items]
            else:
                bwd_args = [ctx, speculated_fwd_result]

            TorchHigherOrderOperatorVariable.make(
                trampoline_autograd_bwd,
                source=AttrSource(module_source, "backward"),
                fwd_bwd_tracer=fwd_bwd_tracer,
            ).call_function(tx, bwd_args, {})

            # If fwd and backward are sound, we want apply in the graph.
            # We don't want backward because we are tracing forwards.
            args = args[1:]
            return TorchHigherOrderOperatorVariable.make(
                trampoline_autograd_apply,
                fwd_bwd_tracer=None,
            ).call_function(tx, args, kwargs)

        if self.source:
            source = AttrSource(AttrSource(self.source, "__class__"), "forward")
        else:
            source = None

        fn = self.fn_cls.forward
        if isinstance(fn, types.FunctionType):
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__,
                variables.UserDefinedClassVariable(self.fn_cls),
                source=source,
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(
                f"non-function or method in subclass of torch.autograd.Function: {fn}"
            )

    def call_function(self, tx, args, kwargs):
        return AutogradFunctionVariable(self.fn_cls)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        """Handle ``apply``, ``backward`` and ``forward``; anything else is unsupported."""
        from ..allowed_functions import is_user_defined_allowed

        from .builder import wrap_fx_proxy

        if name == "apply":
            if is_user_defined_allowed(self.fn_cls):
                trampoline_autograd_apply = produce_trampoline_autograd_apply(
                    self.fn_cls
                )
                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        trampoline_autograd_apply,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                return self.call_apply(tx, args, kwargs)

        elif name == "backward":
            with tx.strict_translation_mode():
                if isinstance(self.fn_cls.backward, types.FunctionType):
                    backward = UserFunctionVariable(self.fn_cls.backward)
                elif isinstance(self.fn_cls.backward, types.MethodType):
                    backward = UserMethodVariable(
                        self.fn_cls.backward.__func__,
                        variables.UserDefinedClassVariable(self.fn_cls),
                    )
                    args = [backward.obj] + args
                else:
                    unimplemented(
                        f"backward is a non-function or method: {self.fn_cls.backward}"
                    )

                return tx.inline_call(tx, backward, args, kwargs)

        elif name == "forward":
            if isinstance(self.fn_cls.forward, types.FunctionType):
                forward = UserFunctionVariable(self.fn_cls.forward)
            elif isinstance(self.fn_cls.forward, types.MethodType):
                forward = UserMethodVariable(
                    self.fn_cls.forward.__func__,
                    variables.UserDefinedClassVariable(self.fn_cls),
                )
                args = [forward.obj] + args
            else:
                unimplemented(
                    f"forward is a non-function or method: {self.fn_cls.forward}"
                )

            return tx.inline_call(tx, forward, args, kwargs)

        else:
            unimplemented(f"Unsupported method: {name}")


@dataclasses.dataclass
class SavedTensorBox:
    # Variables passed to save_for_backward, in call order.
    tensors: List[VariableTracker] = dataclasses.field(default_factory=list)


class AutogradFunctionContextVariable(UserDefinedObjectVariable):
    """
    Tracks an autograd.Function() context using mutation tracking in side_effects.py
    """

    _nonvar_fields = {
        "proxy",
        "inference",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        value_type=None,
        inference=False,
        proxy=None,
        saved_tensors=None,
        **kwargs,
    ):
        super().__init__(value=value, value_type=value_type, **kwargs)
        self.inference = inference
        self.proxy = proxy
        self.saved_tensors = saved_tensors

    @staticmethod
    def create(tx):
        """Create a new tracked ``FunctionCtx`` variable backed by a graph proxy."""
        proxy = tx.output.create_proxy(
            "call_function", torch.autograd.function.FunctionCtx, tuple(), {}
        )
        out = tx.output.side_effects.track_object_new(
            None,
            torch.autograd.function.FunctionCtx,
            functools.partial(
                AutogradFunctionContextVariable,
                inference=True,
                proxy=proxy,
                saved_tensors=SavedTensorBox(),
            ),
            {},
        )
        proxy.node.meta["example_value"] = out.value
        return out

    def as_proxy(self):
        if self.proxy is None:
            unimplemented("proxy not set")
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Only ``save_for_backward`` is supported; it records ``args`` in the box."""
        if name != "save_for_backward":
            unimplemented(f"autograd.Function context method: {name}")
        if self.saved_tensors is None:
            unimplemented(
                "save_for_backward only supported on a newly constructed FunctionCtx"
            )

        if not self.inference:
            assert self.source and not kwargs
            tx.output.side_effects.track_save_for_backward(self, args)

        self.saved_tensors.tensors.extend(args)
        return variables.ConstantVariable.create(None)

    def var_getattr(self, tx, name):
        if name == "save_for_backward":
            return LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        if name == "saved_tensors" and self.saved_tensors is not None:
            return variables.TupleVariable(list(self.saved_tensors.tensors))
        return super().var_getattr(tx, name)


class LambdaVariable(VariableTracker):
    """Wraps a Python callable that is invoked directly with the traced arguments."""

    def __init__(self, fn, **kwargs):
        super().__init__(**kwargs)
        self.fn = fn

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.fn(*args, **kwargs)


class GetAttrVariable(VariableTracker):
    """Deferred attribute access ``obj.name``; calling it forwards to ``obj.call_method``."""

    def __init__(self, obj, name, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name

    def __str__(self):
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    @staticmethod
    def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
        return getattr(base_proxy, attr)

    def as_proxy(self):
        return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

    def const_getattr(self, tx, name):
        """Statically read ``obj.<self.name>.<name>`` for NN module objects only.

        Raises ``NotImplementedError`` when either attribute is not present in
        the corresponding ``__dict__``.
        """
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError()
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError()
        step2 = inspect.getattr_static(step1, self.name)
        if name not in step2.__dict__:
            raise NotImplementedError()
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen):
        codegen(self.obj)
        return codegen.create_load_attrs(self.name)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.obj.call_method(tx, self.name, args, kwargs)


class MethodWrapperVariable(VariableTracker):
    """Wraps a method-wrapper object and exposes it as a python constant."""

    def __init__(self, method_wrapper, **kwargs):
        super().__init__(**kwargs)
        self.method_wrapper = method_wrapper

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # `args` is checked for emptiness before indexing so that a call with
        # no positional arguments falls through to the base implementation
        # instead of raising IndexError.
        if (
            is_tensor_base_attr_getter(self.method_wrapper)
            and args
            and isinstance(args[0], variables.TensorVariable)
        ):
            assert len(args) == 1 and len(kwargs) == 0

            return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)

        return super().call_function(tx, args, kwargs)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.method_wrapper


class GetSetDescriptorVariable(VariableTracker):
    """Wraps a descriptor object and exposes it as a python constant."""

    def __init__(self, desc, **kwargs):
        super().__init__(**kwargs)
        self.desc = desc

    def var_getattr(self, tx, name):
        if name == "__get__" and self.source:
            from .builder import VariableBuilder

            return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
                self.desc.__get__
            )
        else:
            return super().var_getattr(tx, name)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.desc


class PythonModuleVariable(VariableTracker):
    """Wraps a Python module object."""

    def __init__(self, value: types.ModuleType, **kwargs):
        super().__init__(**kwargs)
        self.value = value
        self.is_torch = self.value is torch or self.value.__name__.startswith("torch.")

    def python_type(self):
        return types.ModuleType

    def as_python_constant(self):
        return self.value

    def __repr__(self):
        return f"PythonModuleVariable({self.value})"

    def call_hasattr(self, tx, name):
        # hasattr() is only constant-folded for torch and its submodules.
        if self.is_torch:
            result = hasattr(self.value, name)
            return variables.ConstantVariable.create(result)
        return super().call_hasattr(tx, name)


class SkipFilesVariable(VariableTracker):
    """Wraps a callable from a skipped file; a fixed set of them is still supported."""

    def __init__(self, value, reason=None, **kwargs):
        super().__init__(**kwargs)
        self.value = value
        self.reason = reason

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
        return cls(
            value,
            source=source,
        )

    @staticmethod
    @functools.lru_cache(None)
    def fold_through_function_to_wrapper():
        return {
            collections.namedtuple: variables.UserDefinedClassVariable,
        }

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Handle calls to the allowlisted callables; everything else is unimplemented."""
        from .builtin import BuiltinVariable

        if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
            unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
        # Allowlist a few popular classes(e.g, collections.OrderedDict) calls in skip files.
        elif self.value is collections.OrderedDict:
            return BuiltinVariable.call_custom_dict(
                tx, collections.OrderedDict, *args, **kwargs
            )
        elif (
            self.value is collections.defaultdict
            # Exactly one positional argument is required: `args[0]` is read
            # below, so a zero-argument call must not enter this branch.
            and len(args) == 1
            and DefaultDictVariable.is_supported_arg(args[0])
        ):
            return DefaultDictVariable(
                {},
                collections.defaultdict,
                args[0],
                mutable_local=MutableLocal(),
            )
        # Fold through the functions(e.g, collections.namedtuple)
        # that inputs & outputs are all python constants
        elif (
            self.value in self.fold_through_function_to_wrapper().keys()
            and check_constant_args(args, kwargs)
        ):
            value = self.value(
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            return self.fold_through_function_to_wrapper().get(self.value)(
                value, mutable_local=MutableLocal()
            )
        elif (
            self.value is itertools.product
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = [
                variables.TupleVariable(list(item))
                for item in itertools.product(*seqs)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif (
            self.value is itertools.chain
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = list(itertools.chain(*seqs))
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.accumulate:
            if any(key not in ["initial", "func"] for key in kwargs.keys()):
                unimplemented(
                    "Unsupported kwargs for itertools.accumulate: "
                    f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}"
                )

            acc = kwargs.get("initial")

            if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)

                if "func" in kwargs and len(args) == 1:
                    func = kwargs["func"].call_function
                elif len(args) == 2:
                    func = args[1].call_function
                elif len(args) == 1:
                    # Default to operator.add
                    func = BuiltinVariable(operator.add).call_function
                else:
                    unimplemented(
                        "itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
                    )
            else:
                unimplemented("Unsupported arguments for itertools.accumulate")

            items = []
            if acc is not None:
                items.append(acc)
            for item in seq:
                if acc is None:
                    acc = item
                else:
                    try:
                        acc = func(tx, [acc, item], {})
                    except Exception:
                        raise unimplemented(  # noqa: TRY200
                            f"Unexpected failure in invoking function during accumulate. Failed running func {func}({item}{acc})"
                        )
                items.append(acc)

            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif (
            self.value is itertools.combinations
            and not kwargs
            and len(args) == 2
            and args[0].has_unpack_var_sequence(tx)
            and args[1].is_python_constant()
        ):
            iterable = args[0].unpack_var_sequence(tx)
            r = args[1].as_python_constant()

            items = [
                variables.TupleVariable(list(item))
                for item in itertools.combinations(iterable, r)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.groupby:
            if any(kw != "key" for kw in kwargs.keys()):
                unimplemented(
                    "Unsupported kwargs for itertools.groupby: "
                    f"{','.join(set(kwargs.keys()) - {'key'})}"
                )

            def retrieve_const_key(key):
                if isinstance(key, variables.SymNodeVariable):
                    return key.evaluate_expr()
                elif isinstance(key, variables.ConstantVariable):
                    return key.as_python_constant()
                else:
                    raise unimplemented(
                        "Unsupported key type for itertools.groupby: " + str(type(key))
                    )

            if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)
                keyfunc = (
                    (
                        lambda x: (
                            retrieve_const_key(
                                kwargs.get("key").call_function(tx, [x], {})
                            )
                        )
                    )
                    if "key" in kwargs
                    else None
                )
            else:
                unimplemented("Unsupported arguments for itertools.groupby")

            result = []
            try:
                for k, v in itertools.groupby(seq, key=keyfunc):
                    result.append(
                        variables.TupleVariable(
                            [
                                variables.ConstantVariable.create(k)
                                if variables.ConstantVariable.is_literal(k)
                                else k,
                                variables.ListIteratorVariable(
                                    list(v), mutable_local=MutableLocal()
                                ),
                            ],
                            mutable_local=MutableLocal(),
                        )
                    )
            except Exception:
                raise unimplemented(  # noqa: TRY200
                    "Unexpected failure when calling itertools.groupby"
                )
            return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
        elif (
            self.value is functools.wraps
            and not kwargs
            and len(args) == 1
            and (
                args[0].source is not None or args[0].can_reconstruct(tx.output.root_tx)
            )
        ):

            def wraps(fn):
                if isinstance(fn, variables.NestedUserFunctionVariable):
                    if args[0].source:
                        reconstructible = args[0].source
                    else:
                        reconstructible = args[0]
                    return fn.clone(wrapped_reconstructible=reconstructible)
                unimplemented(f"functools.wraps({fn})")

            return variables.LambdaVariable(wraps)
        elif self.value is collections.deque and not kwargs:
            if len(args) == 0:
                items = []
            elif len(args) == 1 and args[0].has_unpack_var_sequence(tx):
                items = args[0].unpack_var_sequence(tx)
            else:
                unimplemented("deque() with more than 1 arg not supported")
            return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
        elif self.value is functools.partial:
            if not args:
                unimplemented("functools.partial malformed")
            # The first arg, a callable (the ctor below will assert on types)
            fn = args[0]
            rest_args = args[1:]
            # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the
            # args and keywords
            return variables.functions.FunctoolsPartialVariable(
                fn, args=rest_args, keywords=kwargs
            )
        elif self.value is itertools.repeat:
            if len(args) < 2:
                return variables.RepeatIteratorVariable(
                    *args, mutable_local=MutableLocal()
                )

            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder()(tx, polyfill.repeat), args, kwargs
            )
        elif self.value is itertools.count:
            return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
        elif self.value is itertools.cycle:
            return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
        else:
            try:
                path = inspect.getfile(self.value)
            except TypeError:
                path = f"Builtin {self.value.__name__}"
            msg = f"'skip function {self.value.__qualname__} in file {path}'"
            msg += f"', {self.reason}'" if self.reason else ""
            unimplemented(msg)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if (
            self.value in {collections.OrderedDict, collections.defaultdict}
            and name == "fromkeys"
        ):
            from .builtin import BuiltinVariable

            return BuiltinVariable.call_custom_dict_fromkeys(
                tx, self.value, *args, **kwargs
            )
        return super().call_method(tx, name, args, kwargs)


class TypingVariable(VariableTracker):
    """Wraps a value that supports constant subscripting via ``__getitem__``."""

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__" and len(args) == 1:
            return variables.ConstantVariable.create(
                self.value[args[0].as_python_constant()],
            )
        unimplemented("typing")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value


@functools.lru_cache(maxsize=1)
def get_np_to_tnp_map():
    """Map numpy callables to the same-named callables in ``NP_TO_TNP_MODULE``."""
    from ..utils import NP_TO_TNP_MODULE

    np_fn_to_tnp_fn = {}

    for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
        for fn_name, tnp_fn in tnp_mod.__dict__.items():
            if callable(tnp_fn):
                # some internal details do leak from tnp
                # which are not part of numpy API.
                if np_fn := getattr(np_mod, fn_name, None):
                    np_fn_to_tnp_fn[np_fn] = tnp_fn

    return np_fn_to_tnp_fn


class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if not config.trace_numpy:
            unimplemented(f"numpy.{self.value}()")

        from ..utils import numpy_to_tensor_wrapper

        from .tensor import NumpyNdarrayVariable

        # lookup method name in tnp. Things like np.dtype(float) are not supported yet.
        if self.value.__name__ == "dtype":
            unimplemented(
                f"numpy dtype function is not supported yet. Got type {type(self.value)}."
            )
        else:  # We are dealing with a callable.
            func = get_np_to_tnp_map().get(self.value)
            if func is None:
                unimplemented(
                    f"Can't find numpy function {self.value} in torch._numpy. "
                    " Please file an issue to request support for this function."
                )

            if (
                func.__module__ == "torch._numpy.random"
                and config.use_numpy_random_stream
            ):
                msg = f"delegate '{func.__qualname__}' to NumPy itself via "
                msg += f"confg.use_numpy_random_stream={config.use_numpy_random_stream}"
                unimplemented(msg)

            # TODO(larryliu0820): currently assuming all numpy.* functions are returning a ndarray that can be
            #  wrapped by NumpyNdarrayVariable which is wrong!
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_to_tensor_wrapper(func),
                *proxy_args_kwargs(args, kwargs),
            )
            return NumpyNdarrayVariable.create(tx, proxy)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("numpy")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def as_proxy(self):
        # this handles numpy dtype attribute such as np.float32. TODO(larryliu0820): we should split NumpyVariable
        #  into NumpyVariable for instances/objects and NumpyVariable for types.
        if config.trace_numpy and isinstance(self.value, type):
            # retrieve attribute str. E.g., "float32" if given np.float32

            attr = self.value.__name__
            # get tnp equivalent
            tnp_dtype = tnp.dtype(attr)
            # returning a string here because we are assuming all `dtype` kwargs for numpy
            # functions can take an equivalent string and the behavior of the function would
            # be the same as taking a numpy dtype.
            return tnp_dtype.name

        return super().as_proxy()


# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
class NullVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __str__(self):
        return "NullVariable"

    def reconstruct(self, codegen):
        if sys.version_info < (3, 11):
            unimplemented("cannot reconstruct NullVariable in < Python 3.11")
        return [create_instruction("PUSH_NULL")]


class DeletedVariable(VariableTracker):
    """Marker used to implement delattr()"""


class StringFormatVariable(VariableTracker):
    """
    Represents a call to str.format(), we delay calling format until after the graph.
    """

    _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields}

    @classmethod
    def create(cls, format_string, sym_args, sym_kwargs):
        """Fold to a constant when every argument is constant, else defer the call."""
        if all(
            x.is_python_constant()
            for x in itertools.chain(sym_args, sym_kwargs.values())
        ):
            return variables.ConstantVariable.create(
                format_string.format(
                    *[v.as_python_constant() for v in sym_args],
                    **{k: v.as_python_constant() for k, v in sym_kwargs.items()},
                )
            )
        return cls(format_string, list(sym_args), dict(sym_kwargs))

    def __init__(self, format_string, sym_args, sym_kwargs, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(format_string, str)
        self.format_string = format_string
        self.sym_args = sym_args
        self.sym_kwargs = sym_kwargs

    def __repr__(self):
        return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"

    def reconstruct(self, codegen):
        """Emit bytecode for ``format_string.format(*sym_args, **sym_kwargs)``."""
        if sys.version_info >= (3, 11):
            codegen.append_output(create_instruction("PUSH_NULL"))
        codegen.append_output(codegen.create_load_const(self.format_string))
        codegen.append_output(codegen.create_load_attr("format"))
        codegen.extend_output(
            variables.TupleVariable(self.sym_args).reconstruct(codegen)
        )
        codegen.extend_output(
            variables.ConstDictVariable(self.sym_kwargs, user_cls=dict).reconstruct(
                codegen
            )
        )
        codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1))
        return []