import functools
import inspect
import itertools
import logging
import math
import operator
import types
from typing import Dict, List

import numpy as np

import torch
from torch import sym_float, sym_int

from .. import config, variables
from ..allowed_functions import is_allowed
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder
from ..replay_record import DummyModule
from ..source import AttrSource, is_constant_source, TypeSource
from ..utils import (
    check_constant_args,
    check_unspec_python_args,
    istype,
    proxy_args_kwargs,
    specialize_args_kwargs,
)
from .base import MutableLocal, VariableTracker
from .dicts import ConstDictVariable
from .tensor import DynamicShapeVariable, FakeItemVariable, UnspecializedPythonVariable

log = logging.getLogger(__name__)


class BuiltinVariable(VariableTracker):
    @staticmethod
    @functools.lru_cache(None)
    def _constant_fold_functions():
        fns = {
            abs,
            all,
            any,
            bool,
            callable,
            chr,
            dict,
            divmod,
            float,
            int,
            len,
            list,
            max,
            min,
            ord,
            pow,
            repr,
            round,
            set,
            str,
            str.format,
            sum,
            tuple,
            type,
            operator.pos,
            operator.neg,
            operator.not_,
            operator.invert,
            operator.pow,
            operator.mul,
            operator.matmul,
            operator.floordiv,
            operator.truediv,
            operator.mod,
            operator.add,
            operator.sub,
            operator.getitem,
            operator.lshift,
            operator.rshift,
            operator.and_,
            operator.or_,
            operator.xor,
            operator.ipow,
            operator.imul,
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.imod,
            operator.iadd,
            operator.isub,
            operator.ilshift,
            operator.irshift,
            operator.iand,
            operator.ixor,
            operator.ior,
            operator.index,
        }
        fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
        return fns

    def can_constant_fold_through(self):
        return self.fn in self._constant_fold_functions()

    @staticmethod
    @functools.lru_cache(None)
    def _fx_graph_functions():
        fns = {
            operator.pos,
            operator.neg,
            operator.not_,
            operator.invert,
            operator.pow,
            operator.mul,
            operator.matmul,
            operator.floordiv,
            operator.truediv,
            operator.mod,
            operator.add,
            operator.sub,
            operator.getitem,
            operator.lshift,
            operator.rshift,
            operator.and_,
            operator.or_,
            operator.xor,
            operator.ipow,
            operator.imul,
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.imod,
            operator.iadd,
            operator.isub,
            operator.ilshift,
            operator.irshift,
            operator.iand,
            operator.ixor,
            operator.ior,
        }
        return fns

    def can_insert_in_graph(self):
        return self.fn in self._fx_graph_functions()

    def __init__(self, fn, **kwargs):
        super(BuiltinVariable, self).__init__(**kwargs)
        self.fn = fn

    def __str__(self):
        if self.fn is None:
            name = "None"
        else:
            name = self.fn.__name__

        return f"{self.__class__.__name__}({name})"

    def python_type(self):
        return type(self.fn)

    def as_python_constant(self):
        return self.fn

    def reconstruct(self, codegen):
        name = self.fn.__name__
        assert self.fn.__module__ == "builtins"
        assert name not in codegen.tx.f_globals, "shadowed global"
        return [codegen.create_load_global(name, add=True)]

    def constant_args(self, *args, **kwargs):
        return check_constant_args(args, kwargs)

    def tensor_args(self, *args, **kwargs):
        return any(
            isinstance(i, variables.TensorVariable)
            for i in itertools.chain(args, kwargs.values())
        ) and not any(
            isinstance(i, variables.GetAttrVariable)
            for i in itertools.chain(args, kwargs.values())
        )

    def unspec_numpy_args(self, *args, **kwargs):
        return all(
            isinstance(
                i,
                (
                    variables.UnspecializedNumpyVariable,
                    variables.UnspecializedPythonVariable,
                    variables.ConstantVariable,
                ),
            )
            for i in itertools.chain(args, kwargs.values())
        ) and any(
            isinstance(x, variables.UnspecializedNumpyVariable)
            for x in itertools.chain(args, kwargs.values())
        )

    def unspec_python_args(self, *args, **kwargs):
        return check_unspec_python_args(args, kwargs)

    @staticmethod
    def unwrap_unspec_args_kwargs(args, kwargs):
        unwrapped_args = []
        unwrapped_kwargs = {}
        for x in args:
            if isinstance(
                x,
                (
                    variables.UnspecializedNumpyVariable,
                    variables.UnspecializedPythonVariable,
                ),
            ):
                unwrapped_args.append(x.raw_value)
            else:
                unwrapped_args.append(x.as_python_constant())
        for k, v in kwargs.items():
            if isinstance(
                v,
                (
                    variables.UnspecializedNumpyVariable,
                    variables.UnspecializedPythonVariable,
                ),
            ):
                unwrapped_kwargs[k] = v.raw_value
            else:
                unwrapped_kwargs[k] = v.as_python_constant()
        return unwrapped_args, unwrapped_kwargs
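
    # Dispatch order in call_function below, summarized for readability:
    #   1. If the builtin is a graph-safe operator (see _fx_graph_functions) and at
    #      least one argument is a tensor, insert a call_function node into the FX
    #      graph, with special handling for fake-item and unspecialized inputs.
    #   2. Otherwise, look up a handler named call_<fn> on this class and let it
    #      produce a VariableTracker.
    #   3. If the handler declines (or raises Unsupported) and the arguments are
    #      constants, constant-fold by calling the builtin directly.
    #   4. Fall back to the superclass implementation.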

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy, wrap_fx_proxy_cls

        constant_args = check_constant_args(args, kwargs)
        tensor_args = self.tensor_args(*args, **kwargs)
        unspec_python_args = self.unspec_python_args(*args, **kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())
        has_constant_handler = self.can_constant_fold_through() and (
            constant_args or unspec_python_args
        )
        assert isinstance(args, (list, tuple))
        assert isinstance(kwargs, dict)

        if (
            self.fn is operator.getitem
            and len(args) == 2
            and isinstance(args[1], variables.TensorVariable)
            and args[1].dtype == torch.bool
            and not config.dynamic_shapes
        ):
            unimplemented("dynamic Tensor.__getitem__(bool[])")

        # args[0] is list and args[1] is unspec
        if self.fn is operator.getitem and not isinstance(
            args[0], variables.TensorVariable
        ):
            tensor_args = False
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)

        if (
            self.can_insert_in_graph()
            and tensor_args
            and not (
                self.fn is operator.getitem
                and isinstance(args[0], ConstDictVariable)
                and isinstance(args[1], variables.TensorVariable)
            )
        ):
            try:
                fn = self.fn
                if self.fn is operator.iadd and isinstance(
                    args[0], variables.ConstantVariable
                ):
                    # Work around weird bug in hf_T5
                    fn, args = operator.add, [args[1], args[0]]

                proxy = tx.output.create_proxy(
                    "call_function",
                    fn,
                    *proxy_args_kwargs(args, kwargs),
                )
                if any([isinstance(arg, FakeItemVariable) for arg in args]):
                    return wrap_fx_proxy_cls(
                        FakeItemVariable,
                        tx,
                        proxy,
                        **options,
                    )
                elif self.unspec_numpy_args(*args, **kwargs):
                    _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
                    raw_value = self.fn(*_args, **_kwargs)
                    return wrap_fx_proxy_cls(
                        variables.UnspecializedNumpyVariable,
                        tx,
                        proxy,
                        raw_value=raw_value,
                        **options,
                    )
                elif self.unspec_python_args(*args, **kwargs):
                    _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
                    raw_value = self.fn(*_args, **_kwargs)

                    need_unwrap = any(
                        x.need_unwrap
                        for x in itertools.chain(args, kwargs.values())
                        if isinstance(x, variables.UnspecializedPythonVariable)
                    )

                    return wrap_fx_proxy_cls(
                        UnspecializedPythonVariable,
                        tx,
                        proxy,
                        raw_value=raw_value,
                        need_unwrap=need_unwrap,
                        **options,
                    )
                else:
                    # Work around for vision_maskrcnn due to precision difference
                    # specialize the dividend when float divide by tensor
                    if self.fn is operator.truediv and isinstance(
                        args[0], variables.UnspecializedPythonVariable
                    ):
                        args[0] = args[0].convert_to_constant(tx)
                    return wrap_fx_proxy(tx, proxy, **options)

            except NotImplementedError:
                unimplemented(f"partial tensor op: {self} {args} {kwargs}")

        # Handle cases like int(torch.seed())
        # Also handle sym_float to sym_int cases
        if self.fn in (int, float) and isinstance(args[0], DynamicShapeVariable):
            fn_ = sym_int if self.fn is int else sym_float
            out = wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_,
                    (args[0].as_proxy(),),
                    {},
                ),
                **options,
            )
            return out

        handler = getattr(self, f"call_{self.fn.__name__}", None)
        if handler:
            try:
                inspect.signature(handler).bind(tx, *args, **kwargs)
            except TypeError as exc:
                if not has_constant_handler:
                    log.warning(
                        f"incorrect arg count {handler} {exc} and no constant handler"
                    )
                handler = None

        if handler:
            try:
                result = handler(tx, *args, **kwargs)
                if result is not None:
                    return result.add_options(options)
            except Unsupported as exc:
                if not has_constant_handler:
                    raise
                # Actually, we will handle this just fine
                exc.remove_from_stats()

        if has_constant_handler:
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
            # constant fold
            return variables.ConstantVariable(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
                **options,
            )

        return super().call_function(tx, args, kwargs)
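
    # min()/max() over a tensor operand is lowered to torch.clamp when the other
    # operand is a Python constant and to torch.maximum / torch.minimum otherwise;
    # when both operands are unspecialized or constant scalars, the raw Python
    # result is also computed and attached to the returned unspecialized variable.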

    def _call_min_max(self, tx, a, b):
        if self.tensor_args(a, b):
            if not isinstance(a, variables.TensorVariable):
                a, b = b, a
            assert isinstance(a, variables.TensorVariable)

            # result of an item call is a scalar convert to a tensor
            if isinstance(a, FakeItemVariable):
                a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {})

            # Dynamic input does not get resolved, rather, gets stored as call_function
            if isinstance(a, DynamicShapeVariable):
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        self.fn,
                        *proxy_args_kwargs([a, b], {}),
                    ),
                    **VariableTracker.propagate(self, [a, b]),
                )

            # convert min/max to torch ops
            if b.is_python_constant():
                kwargs = {"min": b} if (self.fn is max) else {"max": b}
                result = variables.TorchVariable(torch.clamp).call_function(
                    tx, [a], kwargs
                )
            else:
                fn = {max: torch.maximum, min: torch.minimum}[self.fn]
                result = variables.TorchVariable(fn).call_function(tx, [a, b], {})

            # return unspec if both a, b are unspec or const
            if all(
                isinstance(
                    i,
                    (
                        variables.UnspecializedNumpyVariable,
                        variables.UnspecializedPythonVariable,
                        variables.ConstantVariable,
                    ),
                )
                for i in [a, b]
            ):
                if any([isinstance(val, FakeItemVariable) for val in [a, b]]):
                    return variables.FakeItemVariable.from_tensor_variable(result)

                if b.is_python_constant():
                    raw_b = b.as_python_constant()
                else:
                    raw_b = b.raw_value
                if self.fn is max:
                    raw_res = max(a.raw_value, raw_b)
                else:
                    raw_res = min(a.raw_value, raw_b)

                if isinstance(raw_res, np.number):
                    return variables.UnspecializedNumpyVariable.from_tensor_variable(
                        result, raw_res
                    )
                else:
                    need_unwrap = any(
                        x.need_unwrap
                        for x in [a, b]
                        if isinstance(x, variables.UnspecializedPythonVariable)
                    )
                    return variables.UnspecializedPythonVariable.from_tensor_variable(
                        result, raw_res, need_unwrap
                    )
            # otherwise return tensor
            else:
                return result
        elif isinstance(a, variables.ConstantVariable) and isinstance(
            b, variables.ConstantVariable
        ):
            if self.fn is max:
                return variables.ConstantVariable(max(a.value, b.value))
            else:
                return variables.ConstantVariable(min(a.value, b.value))
        elif isinstance(a, DynamicShapeVariable) or isinstance(b, DynamicShapeVariable):
            proxy = tx.output.create_proxy(
                "call_function", self.fn, *proxy_args_kwargs([a, b], {})
            )
            return DynamicShapeVariable.create(tx, proxy, None)
        else:
            unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}")

    call_min = _call_min_max
    call_max = _call_min_max

    def call_range(self, tx, *args):
        if self.unspec_python_args(*args) or self.constant_args(*args):
            args, _ = specialize_args_kwargs(tx, args, {})
            return variables.RangeVariable(args)
        elif self._dynamic_args(*args):

            def guard_if_dyn(arg):
                if isinstance(arg, DynamicShapeVariable):
                    return arg.evaluate_expr(tx.output)
                return arg

            args = [variables.ConstantVariable(guard_if_dyn(arg)) for arg in args]
            return variables.RangeVariable(args)
        # None no-ops this handler and lets the driving function proceed
        return None

    def _dynamic_args(self, *args, **kwargs):
        return any([isinstance(x, DynamicShapeVariable) for x in args]) or any(
            [isinstance(x, DynamicShapeVariable) for x in kwargs.values()]
        )

    def call_slice(self, tx, *args):
        return variables.SliceVariable(args)

    def _dyn_proxy(self, tx, *args, **kwargs):
        assert self._dynamic_args(*args, **kwargs)
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self, args, kwargs.values())
        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_function", self.fn, *proxy_args_kwargs(args, kwargs)
            ),
            **options,
        )

    def call_mod(self, tx, *args, **kwargs):
        if self._dynamic_args(*args, **kwargs):
            return self._dyn_proxy(tx, *args, **kwargs)

    def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
        if self._dynamic_args(*args, **kwargs):
            return self._dyn_proxy(tx, *args, **kwargs)
        cls = variables.BaseListVariable.cls_for(self.fn)
        if obj is None:
            return cls(
                [],
                mutable_local=MutableLocal(),
            )
        elif obj.has_unpack_var_sequence(tx):
            guards = set()
            if obj.source and not is_constant_source(obj.source):
                guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
            return cls(
                list(obj.unpack_var_sequence(tx)),
                mutable_local=MutableLocal(),
                guards=guards,
            ).add_options(self, obj)

    call_iter = _call_iter_tuple_list
    call_tuple = _call_iter_tuple_list
    call_list = _call_iter_tuple_list

    def call_dict(self, tx, arg):
        if isinstance(arg, variables.ConstDictVariable):
            return arg.clone(mutable_local=MutableLocal())

    def call_zip(self, tx, *args):
        options = VariableTracker.propagate(self, args)
        if all(x.has_unpack_var_sequence(tx) for x in args):
            items = [
                variables.TupleVariable(list(item), **options)
                for item in zip(*[arg.unpack_var_sequence(tx) for arg in args])
            ]
            return variables.TupleVariable(items, **options)

    def call_enumerate(self, tx, *args):
        options = VariableTracker.propagate(self, args)
        if len(args) == 1:
            start = 0
        else:
            assert len(args) == 2
            assert isinstance(args[1], variables.ConstantVariable)
            start = args[1].as_python_constant()
        if args[0].has_unpack_var_sequence(tx):
            items = [
                variables.TupleVariable(
                    [variables.ConstantVariable(idx, **options), var],
                    **options,
                )
                for idx, var in enumerate(args[0].unpack_var_sequence(tx), start)
            ]
            return variables.TupleVariable(items, **options)
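
    # Multiplication is special-cased: a list/tuple repeated by a constant produces a
    # new list/tuple variable, a constant times a dynamic shape is routed through
    # __rmul__, and everything else defers to the left operand's __mul__.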

    def call_mul(self, tx, a, b):
        if isinstance(
            a, (variables.ListVariable, variables.TupleVariable)
        ) and isinstance(b, variables.ConstantVariable):
            return a.__class__(
                items=a.items * b.as_python_constant(), mutable_local=MutableLocal()
            ).add_options(self, a, b)
        elif isinstance(
            b, (variables.ListVariable, variables.TupleVariable)
        ) and isinstance(a, variables.ConstantVariable):
            return b.__class__(
                items=b.items * a.as_python_constant(), mutable_local=MutableLocal()
            ).add_options(self, a, b)
        # TODO: this doesn't generalize to other builtin operators.
        elif isinstance(a, variables.ConstantVariable) and isinstance(
            b, DynamicShapeVariable
        ):
            return b.call_method(tx, "__rmul__", [a], {})
        else:
            return a.call_method(tx, "__mul__", [b], {})

    def call_len(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__len__", args[1:], kwargs)

    def call_add(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__add__", args[1:], kwargs)

    def call_sub(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__sub__", args[1:], kwargs)

    def call_truediv(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__truediv__", args[1:], kwargs)

    def call_floordiv(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__floordiv__", args[1:], kwargs)

    def call_iadd(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__iadd__", args[1:], kwargs)

    def call_getitem(self, tx, *args, **kwargs):
        if self.unspec_python_args(*args, **kwargs):
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
        return args[0].call_method(tx, "__getitem__", args[1:], kwargs)

    def call_isinstance(self, tx, arg, isinstance_type):
        arg_type = arg.python_type()
        isinstance_type = isinstance_type.as_python_constant()

        if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
            return variables.ConstantVariable(arg.call_isinstance(isinstance_type))
        # UserDefinedObject with C extensions can have torch.Tensor attributes,
        # so break graph.
        if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
            arg.value, types.MemberDescriptorType
        ):
            unimplemented(
                f"isinstance called on UserDefinedClass {arg} {isinstance_type}"
            )
        try:
            val = issubclass(arg_type, isinstance_type)
        except TypeError:
            val = arg_type is isinstance_type
        return variables.ConstantVariable(val)

    def call_super(self, tx, a, b):
        return variables.SuperVariable(a, b)

    def call_next(self, tx, arg):
        if isinstance(arg, variables.ListIteratorVariable):
            val, next_iter = arg.next_variables()
            tx.replace_all(arg, next_iter)
            return val
        elif isinstance(arg, variables.BaseListVariable):
            return arg.items[0].add_options(self, arg)

    def call_hasattr(self, tx, obj, attr):
        if attr.is_python_constant():
            name = attr.as_python_constant()
            return obj.call_hasattr(tx, name).add_options(self, obj, attr)

    def call_map(self, tx, fn, seq):
        if seq.has_unpack_var_sequence(tx):
            items = [fn.call_function(tx, [x], {}) for x in seq.unpack_var_sequence(tx)]
            return variables.TupleVariable(items).add_options(self, fn, seq)

    def call_sum(self, tx, seq, **kwargs):
        # Special case for sum on tuple of floats and ints
        if (
            isinstance(seq, (variables.ListVariable, variables.TupleVariable))
            and all(
                [
                    isinstance(x, variables.ConstantVariable)
                    and isinstance(x.value, (int, float))
                    for x in seq.items
                ]
            )
            and not kwargs
        ):
            new_list = [x.value for x in seq.items]
            return variables.ConstantVariable(sum(new_list))
        if seq.has_unpack_var_sequence(tx):
            start = kwargs.pop(
                "start", variables.ConstantVariable(0)
            ).as_python_constant()
            assert not kwargs
            items = seq.unpack_var_sequence(tx)[start:]
            return BuiltinVariable(functools.reduce).call_function(
                tx,
                [
                    BuiltinVariable(operator.add),
                    variables.TupleVariable(items),
                    variables.ConstantVariable(0).add_options(self, seq),
                ],
                {},
            )

    def call_reduce(self, tx, function, iterable, initializer=None):
        if iterable.has_unpack_var_sequence(tx):
            items = iterable.unpack_var_sequence(tx)
            if initializer is None:
                value, items = items[0], items[1:]
            else:
                value = initializer
            for element in items:
                value = function.call_function(tx, [value, element], {})
            return value
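
    # getattr() resolution order: pending side-effect writes first, then the optional
    # default (guarded by a hasattr check), then per-variable-type handling; anything
    # that cannot be resolved statically becomes a GetAttrVariable.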

    def call_getattr(
        self, tx, obj: VariableTracker, name_var: VariableTracker, default=None
    ):
        from . import (
            ConstantVariable,
            GetAttrVariable,
            PythonModuleVariable,
            TorchVariable,
            UserFunctionVariable,
        )
        from .builder import VariableBuilder

        options = VariableTracker.propagate(self, obj, name_var)
        guards = options["guards"]
        if not name_var.is_python_constant():
            unimplemented("non-const getattr() name")
        name = name_var.as_python_constant()

        if tx.output.side_effects.is_attribute_mutation(obj):
            try:
                # re-read a pending side effect?
                return tx.output.side_effects.load_attr(obj, name).add_options(options)
            except KeyError:
                pass

        if default is not None:
            hasattr_var = self.call_hasattr(tx, obj, name_var)
            guards.update(hasattr_var.guards)
            assert hasattr_var.as_python_constant() in (True, False)
            if not hasattr_var.as_python_constant():
                return default.add_guards(guards)

        if obj.source:
            source = AttrSource(obj.source, name)
            options["source"] = source
        else:
            source = None

        if isinstance(obj, variables.NNModuleVariable):
            return obj.var_getattr(tx, name).add_options(options)
        elif isinstance(obj, variables.TensorVariable) and name == "grad":
            if source:
                # We are going to be raising this tensor as grapharg. So, ensure
                # that we have real grad value instead of fake tensor value.
                # Walk through the inputs of the subgraph and find if we already
                # have the original tensor stored in the graphargs.
                for grapharg in tx.output.graphargs:
                    if grapharg.source == source.base:
                        example_value = grapharg.example.grad
                        return VariableBuilder(tx, source)(example_value).add_options(
                            options
                        )
                unimplemented("tensor grad")
            else:
                unimplemented("tensor grad")
        elif isinstance(
            obj,
            (
                variables.TensorVariable,
                variables.NamedTupleVariable,
                variables.ConstantVariable,
                variables.UserDefinedClassVariable,
                variables.UserDefinedObjectVariable,
            ),
        ):
            try:
                return (
                    obj.var_getattr(tx, name).clone(source=source).add_options(options)
                )
            except NotImplementedError:
                return GetAttrVariable(obj, name, **options)
        elif isinstance(obj, TorchVariable):
            member = getattr(obj.value, name)
            if is_allowed(member):
                return TorchVariable(member, **options)
            elif ConstantVariable.is_literal(member):
                return ConstantVariable(member, **options)
            else:
                return VariableBuilder(tx, source)(member).add_guards(guards)
        elif isinstance(obj, (PythonModuleVariable, DummyModule)):
            member = obj.value.__dict__[name]

            if config.replay_record_enabled:
                tx.exec_recorder.record_module_access(obj.value, name, member)

            return VariableBuilder(tx, source)(member).add_guards(guards)
        elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
            return ConstantVariable(
                getattr(obj.fn, name), **VariableTracker.propagate(obj)
            )
        else:
            try:
                return (
                    obj.var_getattr(tx, name).clone(source=source).add_options(options)
                )
            except NotImplementedError:
                return GetAttrVariable(obj, name, **options)

    def call_setattr(
        self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker
    ):
        if isinstance(obj, (variables.BlackHoleVariable, variables.DataClassVariable)):
            return obj.call_method(tx, "__setattr__", [name_var, val], {})
        elif (
            tx.output.side_effects.is_attribute_mutation(obj)
            and name_var.is_python_constant()
        ):
            tx.output.side_effects.store_attr(obj, name_var.as_python_constant(), val)
            return val.add_options(self, obj, name_var)
        elif isinstance(obj, variables.UserDefinedObjectVariable):
            unimplemented(
                f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}"
            )
        elif isinstance(obj, variables.NNModuleVariable):
            obj.convert_to_unspecialized(tx)

    def call_type(self, tx, obj: VariableTracker):
        from .builder import VariableBuilder

        try:
            py_type = obj.python_type()
        except NotImplementedError:
            py_type = None

        if istype(obj, variables.TupleVariable):
            return BuiltinVariable(py_type).add_options(self, obj)

        if py_type is not None and obj.source:
            return VariableBuilder(tx, TypeSource(obj.source))(py_type).add_options(
                self, obj
            )

        unimplemented(f"type({obj})")

    def call_reversed(self, tx, obj: VariableTracker):
        if obj.has_unpack_var_sequence(tx):
            items = list(reversed(obj.unpack_var_sequence(tx)))
            return variables.TupleVariable(
                items, **VariableTracker.propagate(self, obj)
            )

    def call_chain(self, tx, *args):
        if all(obj.has_unpack_var_sequence(tx) for obj in args):
            items = []
            for obj in args:
                items.extend(obj.unpack_var_sequence(tx))
            return variables.TupleVariable(
                items, **VariableTracker.propagate(self, *args)
            )

    def call_islice(self, tx, iterable, *args):
        if iterable.has_unpack_var_sequence(tx) and all(
            x.is_python_constant() for x in args
        ):
            const_args = [x.as_python_constant() for x in args]
            items = iterable.unpack_var_sequence(tx)
            items = list(itertools.islice(items, *const_args))
            return variables.TupleVariable(
                items, **VariableTracker.propagate(self, iterable, *args)
            )

    def call_id(self, tx, *args):
        if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable):
            nn_mod_variable = args[0]
            mod = tx.output.get_submodule(nn_mod_variable.module_key)
            return variables.ConstantVariable(id(mod))
        else:
            unimplemented(f"call_id with args {args}")
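

# The block below is an illustrative sketch only, not part of Dynamo: it reproduces,
# on a hypothetical toy class, the "look up call_<name> and verify its signature"
# dispatch pattern that BuiltinVariable.call_function uses above. It runs only when
# this file is executed directly.
if __name__ == "__main__":

    class _ToyDispatcher:
        """Dispatches a builtin to a handler named call_<builtin>, if one exists."""

        def __init__(self, fn):
            self.fn = fn

        def call(self, *args):
            handler = getattr(self, f"call_{self.fn.__name__}", None)
            if handler is not None:
                try:
                    # Reject handlers whose signature does not accept these args,
                    # mirroring the inspect.signature(...).bind(...) check above.
                    inspect.signature(handler).bind(*args)
                except TypeError:
                    handler = None
            if handler is not None:
                return handler(*args)
            # Fallback: "constant fold" by calling the builtin directly.
            return self.fn(*args)

        def call_min(self, a, b):
            return min(a, b)

    assert _ToyDispatcher(min).call(3, 5) == 3  # dispatched to call_min
    assert _ToyDispatcher(abs).call(-7) == 7  # no call_abs, falls back to abs()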