import functools
import inspect
import itertools
import logging
import math
import operator
import types
from typing import Dict, List

import numpy as np
import torch

from .. import config, variables
from ..allowed_functions import is_allowed
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder
from ..replay_record import DummyModule
from ..source import AttrSource, is_constant_source, TypeSource
from ..utils import (
    check_constant_args,
    check_unspec_python_args,
    istype,
    proxy_args_kwargs,
    specialize_args_kwargs,
)
from .base import MutableLocal, VariableTracker
from .dicts import ConstDictVariable
from .tensor import DynamicShapeVariable, FakeItemVariable

log = logging.getLogger(__name__)


class BuiltinVariable(VariableTracker):
    @staticmethod
    @functools.lru_cache(None)
    def _constant_fold_functions():
        fns = {
            abs, all, any, bool, callable, chr, dict, divmod, float, int, len,
            list, max, min, ord, pow, repr, round, set, str, str.format, sum,
            tuple, type,
            operator.pos, operator.neg, operator.not_, operator.invert,
            operator.pow, operator.mul, operator.matmul, operator.floordiv,
            operator.truediv, operator.mod, operator.add, operator.sub,
            operator.getitem, operator.lshift, operator.rshift, operator.and_,
            operator.or_, operator.xor, operator.ipow, operator.imul,
            operator.imatmul, operator.ifloordiv, operator.itruediv,
            operator.imod, operator.iadd, operator.isub, operator.ilshift,
            operator.irshift, operator.iand, operator.ixor, operator.ior,
            operator.index,
        }
        fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
        return fns

    def can_constant_fold_through(self):
        return self.fn in self._constant_fold_functions()

    @staticmethod
    @functools.lru_cache(None)
    def _fx_graph_functions():
        fns = {
            operator.pos, operator.neg, operator.not_, operator.invert,
            operator.pow, operator.mul, operator.matmul, operator.floordiv,
            operator.truediv, operator.mod, operator.add, operator.sub,
            operator.getitem, operator.lshift, operator.rshift, operator.and_,
            operator.or_, operator.xor, operator.ipow, operator.imul,
            operator.imatmul, operator.ifloordiv, operator.itruediv,
            operator.imod, operator.iadd, operator.isub, operator.ilshift,
            operator.irshift, operator.iand, operator.ixor, operator.ior,
        }
        return fns

    def can_insert_in_graph(self):
        return self.fn in self._fx_graph_functions()

    def __init__(self, fn, **kwargs):
        super(BuiltinVariable, self).__init__(**kwargs)
        self.fn = fn

    def __str__(self):
        if self.fn is None:
            name = "None"
        else:
            name = self.fn.__name__
        return f"{self.__class__.__name__}({name})"

    def python_type(self):
        return type(self.fn)

    def as_python_constant(self):
        return self.fn

    def reconstruct(self, codegen):
        name = self.fn.__name__
        assert self.fn.__module__ == "builtins"
        assert name not in codegen.tx.f_globals, "shadowed global"
        return [codegen.create_load_global(name, add=True)]

    def constant_args(self, *args, **kwargs):
        return check_constant_args(args, kwargs)

    def tensor_args(self, *args, **kwargs):
        return any(
            isinstance(i, variables.TensorVariable)
            for i in itertools.chain(args, kwargs.values())
        ) and not any(
            isinstance(i, variables.GetAttrVariable)
            for i in itertools.chain(args, kwargs.values())
        )

    def unspec_numpy_args(self, *args, **kwargs):
        return all(
            isinstance(
                i,
                (
                    variables.UnspecializedNumpyVariable,
                    variables.UnspecializedPythonVariable,
                    variables.ConstantVariable,
                ),
            )
            for i in itertools.chain(args, kwargs.values())
        ) and any(
            isinstance(x, variables.UnspecializedNumpyVariable)
            for x in itertools.chain(args, kwargs.values())
        )

    def unspec_python_args(self, *args, **kwargs):
        return check_unspec_python_args(args, kwargs)

    @staticmethod
    def unwrap_unspec_args_kwargs(args, kwargs):
        unwrapped_args = []
        unwrapped_kwargs = {}
        for x in args:
            if isinstance(
                x,
                (
                    variables.UnspecializedNumpyVariable,
                    variables.UnspecializedPythonVariable,
                ),
            ):
                unwrapped_args.append(x.raw_value)
            else:
                unwrapped_args.append(x.as_python_constant())
        for k, v in kwargs.items():
            if isinstance(
                v,
                (
                    variables.UnspecializedNumpyVariable,
                    variables.UnspecializedPythonVariable,
                ),
            ):
                unwrapped_kwargs.update({k: v.raw_value})
            else:
                unwrapped_kwargs.update({k: v.as_python_constant()})
        return unwrapped_args, unwrapped_kwargs

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        constant_args = check_constant_args(args, kwargs)
        tensor_args = self.tensor_args(*args, **kwargs)
        unspec_python_args = self.unspec_python_args(*args, **kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())
        has_constant_handler = self.can_constant_fold_through() and (
            constant_args or unspec_python_args
        )
        assert isinstance(args, list)
        assert isinstance(kwargs, dict)

        if (
            self.fn is operator.getitem
            and len(args) == 2
            and isinstance(args[1], variables.TensorVariable)
            and args[1].dtype == torch.bool
            and not config.dynamic_shapes
        ):
            unimplemented("dynamic Tensor.__getitem__(bool[])")

        # args[0] is list and args[1] is unspec
        if self.fn is operator.getitem and not isinstance(
            args[0], variables.TensorVariable
        ):
            tensor_args = False
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)

        if (
            self.can_insert_in_graph()
            and tensor_args
            and not (
                self.fn is operator.getitem
                and isinstance(args[0], ConstDictVariable)
                and isinstance(args[1], variables.TensorVariable)
            )
        ):
            try:
                fn = self.fn
                if self.fn is operator.iadd and isinstance(
                    args[0], variables.ConstantVariable
                ):
                    # Work around weird bug in hf_T5
                    fn, args = operator.add, [args[1], args[0]]

                proxy = tx.output.create_proxy(
                    "call_function", fn, *proxy_args_kwargs(args, kwargs), current_tx=tx
                )
                if any(isinstance(arg, FakeItemVariable) for arg in args):
                    return variables.FakeItemVariable.create(
                        tx,
                        proxy,
                        **options,
                    )
                elif self.unspec_numpy_args(*args, **kwargs):
                    _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
                    raw_value = self.fn(*_args, **_kwargs)
                    return variables.UnspecializedNumpyVariable.create(
                        tx,
                        proxy,
                        raw_value=raw_value,
                        **options,
                    )
                elif self.unspec_python_args(*args, **kwargs):
                    _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
                    raw_value = self.fn(*_args, **_kwargs)
                    need_unwrap = any(
                        x.need_unwrap
                        for x in itertools.chain(args, kwargs.values())
                        if isinstance(x, variables.UnspecializedPythonVariable)
                    )
                    return variables.UnspecializedPythonVariable.create(
                        tx,
                        proxy,
                        raw_value=raw_value,
                        need_unwrap=need_unwrap,
                        **options,
                    )
                else:
                    # Workaround for vision_maskrcnn due to precision differences:
                    # specialize the dividend when a float is divided by a tensor.
                    if self.fn is operator.truediv and isinstance(
                        args[0], variables.UnspecializedPythonVariable
                    ):
                        args[0] = args[0].convert_to_constant(tx)
                    return variables.TensorVariable.create(tx, proxy, **options)
            except NotImplementedError:
                unimplemented(f"partial tensor op: {self} {args} {kwargs}")

        # Handle cases like int(torch.seed())
        if self.fn is int and isinstance(args[0], DynamicShapeVariable):
            return args[0]

        handler = getattr(self, f"call_{self.fn.__name__}", None)
        if handler:
            try:
                inspect.signature(handler).bind(tx, *args, **kwargs)
            except TypeError as exc:
                log.warning(f"incorrect arg count {handler} {exc}")
                handler = None

        if handler:
            try:
                result = handler(tx, *args, **kwargs)
                if result is not None:
                    return result.add_options(options)
            except Unsupported as exc:
                if not has_constant_handler:
                    raise
                # Actually, we will handle this just fine
                exc.remove_from_stats()

        if has_constant_handler:
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
            # constant fold
            return variables.ConstantVariable(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
                **options,
            )

        return super().call_function(tx, args, kwargs)

    def _call_min_max(self, tx, a, b):
        if self.tensor_args(a, b):
            if not isinstance(a, variables.TensorVariable):
                a, b = b, a
            assert isinstance(a, variables.TensorVariable)

            # 1. the result of an item() call is a scalar, so convert it to a tensor
            # 2. a dynamic shape should be resolved to a tensor
            if isinstance(a, (FakeItemVariable, DynamicShapeVariable)):
                a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {})

            # convert min/max to torch ops
            if b.is_python_constant():
                kwargs = {"min": b} if (self.fn is max) else {"max": b}
                result = variables.TorchVariable(torch.clamp).call_function(
                    tx, [a], kwargs
                )
            else:
                fn = {max: torch.maximum, min: torch.minimum}[self.fn]
                result = variables.TorchVariable(fn).call_function(tx, [a, b], {})

            # return unspec if both a, b are unspec or const
            if all(
                isinstance(
                    i,
                    (
                        variables.UnspecializedNumpyVariable,
                        variables.UnspecializedPythonVariable,
                        variables.ConstantVariable,
                    ),
                )
                for i in [a, b]
            ):
                if any(isinstance(val, FakeItemVariable) for val in [a, b]):
                    return variables.FakeItemVariable.from_tensor_variable(result)

                if b.is_python_constant():
                    raw_b = b.as_python_constant()
                else:
                    raw_b = b.raw_value
                if self.fn is max:
                    raw_res = max(a.raw_value, raw_b)
                else:
                    raw_res = min(a.raw_value, raw_b)

                if isinstance(raw_res, np.number):
                    return variables.UnspecializedNumpyVariable.from_tensor_variable(
                        result, raw_res
                    )
                else:
                    need_unwrap = any(
                        x.need_unwrap
                        for x in [a, b]
                        if isinstance(x, variables.UnspecializedPythonVariable)
                    )
                    return variables.UnspecializedPythonVariable.from_tensor_variable(
                        result, raw_res, need_unwrap
                    )
            # otherwise return tensor
            else:
                return result
        elif isinstance(a, variables.ConstantVariable) and isinstance(
            b, variables.ConstantVariable
        ):
            if self.fn is max:
                return variables.ConstantVariable(max(a.value, b.value))
            else:
                return variables.ConstantVariable(min(a.value, b.value))
        else:
            unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}")

    call_min = _call_min_max
    call_max = _call_min_max

    def call_range(self, tx, *args, **kwargs):
        if self.unspec_python_args(*args, **kwargs) or self.constant_args(
            *args, **kwargs
        ):
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
            return variables.RangeVariable(
                value=range(
                    *[x.value for x in args],
                    **{k: v.value for k, v in kwargs.items()},
                ),
            )

    def call_slice(self, tx, *args):
        return variables.SliceVariable(args)

    def _call_iter_tuple_list(self, tx, obj=None):
        cls = variables.BaseListVariable.cls_for(self.fn)
        if obj is None:
            return cls(
                [],
                mutable_local=MutableLocal(),
            )
        elif obj.has_unpack_var_sequence(tx):
            guards = set()
            if obj.source and not is_constant_source(obj.source):
                guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
            return cls(
                list(obj.unpack_var_sequence(tx)),
                mutable_local=MutableLocal(),
                guards=guards,
            ).add_options(self, obj)

    call_iter = _call_iter_tuple_list
    call_tuple = _call_iter_tuple_list
    call_list = _call_iter_tuple_list
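
    # The call_<name> handlers in this class are not invoked directly:
    # call_function() looks them up via getattr(self, f"call_{self.fn.__name__}",
    # None) and falls back to constant folding (or the generic VariableTracker
    # path) when no handler exists or a handler returns None.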

    def call_dict(self, tx, arg):
        if isinstance(arg, variables.ConstDictVariable):
            return arg.clone(mutable_local=MutableLocal())

    def call_zip(self, tx, *args):
        options = VariableTracker.propagate(self, args)
        if all(x.has_unpack_var_sequence(tx) for x in args):
            items = [
                variables.TupleVariable(list(item), **options)
                for item in zip(*[arg.unpack_var_sequence(tx) for arg in args])
            ]
            return variables.TupleVariable(items, **options)

    def call_enumerate(self, tx, *args):
        options = VariableTracker.propagate(self, args)
        if len(args) == 1:
            start = 0
        else:
            assert len(args) == 2
            assert isinstance(args[1], variables.ConstantVariable)
            start = args[1].as_python_constant()
        if args[0].has_unpack_var_sequence(tx):
            items = [
                variables.TupleVariable(
                    [variables.ConstantVariable(idx, **options), var],
                    **options,
                )
                for idx, var in enumerate(args[0].unpack_var_sequence(tx), start)
            ]
            return variables.TupleVariable(items, **options)

    def call_mul(self, tx, a, b):
        if isinstance(
            a, (variables.ListVariable, variables.TupleVariable)
        ) and isinstance(b, variables.ConstantVariable):
            return a.__class__(
                items=a.items * b.as_python_constant(), mutable_local=MutableLocal()
            ).add_options(self, a, b)
        elif isinstance(
            b, (variables.ListVariable, variables.TupleVariable)
        ) and isinstance(a, variables.ConstantVariable):
            return b.__class__(
                items=b.items * a.as_python_constant(), mutable_local=MutableLocal()
            ).add_options(self, a, b)
        else:
            return a.call_method(tx, "__mul__", [b], {})

    def call_len(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__len__", args[1:], kwargs)

    def call_add(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__add__", args[1:], kwargs)

    def call_sub(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__sub__", args[1:], kwargs)

    def call_truediv(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__truediv__", args[1:], kwargs)

    def call_floordiv(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__floordiv__", args[1:], kwargs)

    def call_iadd(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__iadd__", args[1:], kwargs)

    def call_getitem(self, tx, *args, **kwargs):
        if self.unspec_python_args(*args, **kwargs):
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
        return args[0].call_method(tx, "__getitem__", args[1:], kwargs)

    def call_isinstance(self, tx, arg, isinstance_type):
        arg_type = arg.python_type()
        isinstance_type = isinstance_type.as_python_constant()

        if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
            return variables.ConstantVariable(arg.call_isinstance(isinstance_type))
        # UserDefinedObject with C extensions can have torch.Tensor attributes,
        # so break graph.
        if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
            arg.value, types.MemberDescriptorType
        ):
            unimplemented(
                f"isinstance called on UserDefinedClass {arg} {isinstance_type}"
            )

        try:
            val = issubclass(arg_type, isinstance_type)
        except TypeError:
            val = arg_type is isinstance_type
        return variables.ConstantVariable(val)

    def call_super(self, tx, a, b):
        return variables.SuperVariable(a, b)

    def call_next(self, tx, arg):
        if isinstance(arg, variables.ListIteratorVariable):
            val, next_iter = arg.next_variables()
            tx.replace_all(arg, next_iter)
            return val
        elif isinstance(arg, variables.BaseListVariable):
            return arg.items[0].add_options(self, arg)

    def call_hasattr(self, tx, obj, attr):
        if attr.is_python_constant():
            name = attr.as_python_constant()
            return obj.call_hasattr(tx, name).add_options(self, obj, attr)

    def call_map(self, tx, fn, seq):
        if seq.has_unpack_var_sequence(tx):
            items = [fn.call_function(tx, [x], {}) for x in seq.unpack_var_sequence(tx)]
            return variables.TupleVariable(items).add_options(self, fn, seq)

    def call_sum(self, tx, seq, **kwargs):
        # Special case for sum on a tuple/list of floats and ints
        if (
            isinstance(seq, (variables.ListVariable, variables.TupleVariable))
            and all(
                isinstance(x, variables.ConstantVariable)
                and isinstance(x.value, (int, float))
                for x in seq.items
            )
            and not kwargs
        ):
            new_list = [x.value for x in seq.items]
            return variables.ConstantVariable(sum(new_list))
        if seq.has_unpack_var_sequence(tx):
            start = kwargs.pop(
                "start", variables.ConstantVariable(0)
            ).as_python_constant()
            assert not kwargs
            items = seq.unpack_var_sequence(tx)[start:]
            return BuiltinVariable(functools.reduce).call_function(
                tx,
                [
                    BuiltinVariable(operator.add),
                    variables.TupleVariable(items),
                    variables.ConstantVariable(0).add_options(self, seq),
                ],
                {},
            )

    def call_reduce(self, tx, function, iterable, initializer=None):
        if iterable.has_unpack_var_sequence(tx):
            items = iterable.unpack_var_sequence(tx)
            if initializer is None:
                value, items = items[0], items[1:]
            else:
                value = initializer
            for element in items:
                value = function.call_function(tx, [value, element], {})
            return value

    def call_getattr(
        self, tx, obj: VariableTracker, name_var: VariableTracker, default=None
    ):
        from . import (
            ConstantVariable,
            GetAttrVariable,
            PythonModuleVariable,
            TorchVariable,
            UserFunctionVariable,
        )
        from .builder import VariableBuilder

        options = VariableTracker.propagate(self, obj, name_var)
        guards = options["guards"]

        if not name_var.is_python_constant():
            unimplemented("non-const getattr() name")
        name = name_var.as_python_constant()

        if tx.output.side_effects.is_attribute_mutation(obj):
            try:
                # re-read a pending side effect?
                return tx.output.side_effects.load_attr(obj, name).add_options(options)
            except KeyError:
                pass

        if default is not None:
            hasattr_var = self.call_hasattr(tx, obj, name_var)
            guards.update(hasattr_var.guards)
            assert hasattr_var.as_python_constant() in (True, False)
            if not hasattr_var.as_python_constant():
                return default.add_guards(guards)

        if obj.source:
            source = AttrSource(obj.source, name)
            options["source"] = source
        else:
            source = None

        if isinstance(obj, variables.NNModuleVariable):
            return obj.var_getattr(tx, name).add_options(options)
        elif isinstance(obj, variables.TensorVariable) and name == "grad":
            if source:
                # We are going to be raising this tensor as a grapharg, so ensure
                # that we have the real grad value instead of a fake tensor value.
                # Walk through the inputs of the subgraph and check if we already
                # have the original tensor stored in the graphargs.
                for grapharg in tx.output.graphargs:
                    if grapharg.source == source.base:
                        example_value = grapharg.example.grad
                        return VariableBuilder(tx, source)(example_value).add_options(
                            options
                        )
                unimplemented("tensor grad")
            else:
                unimplemented("tensor grad")
        elif isinstance(
            obj,
            (
                variables.TensorVariable,
                variables.NamedTupleVariable,
                variables.ConstantVariable,
                variables.UserDefinedClassVariable,
                variables.UserDefinedObjectVariable,
            ),
        ):
            try:
                return (
                    obj.var_getattr(tx, name).clone(source=source).add_options(options)
                )
            except NotImplementedError:
                return GetAttrVariable(obj, name, **options)
        elif isinstance(obj, TorchVariable):
            member = getattr(obj.value, name)
            if is_allowed(member):
                return TorchVariable(member, **options)
            elif ConstantVariable.is_literal(member):
                return ConstantVariable(member, **options)
            else:
                return VariableBuilder(tx, source)(member).add_guards(guards)
        elif isinstance(obj, (PythonModuleVariable, DummyModule)):
            member = obj.value.__dict__[name]

            if config.replay_record_enabled:
                tx.exec_recorder.record_module_access(obj.value, name, member)

            return VariableBuilder(tx, source)(member).add_guards(guards)
        elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
            return ConstantVariable(
                getattr(obj.fn, name), **VariableTracker.propagate(obj)
            )
        else:
            try:
                return (
                    obj.var_getattr(tx, name).clone(source=source).add_options(options)
                )
            except NotImplementedError:
                return GetAttrVariable(obj, name, **options)

    def call_setattr(
        self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker
    ):
        if isinstance(obj, (variables.BlackHoleVariable, variables.DataClassVariable)):
            return obj.call_method(tx, "__setattr__", [name_var, val], {})
        elif (
            tx.output.side_effects.is_attribute_mutation(obj)
            and name_var.is_python_constant()
        ):
            tx.output.side_effects.store_attr(obj, name_var.as_python_constant(), val)
            return val.add_options(self, obj, name_var)
        elif isinstance(obj, variables.UserDefinedObjectVariable):
            unimplemented(
                f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}"
            )
        elif isinstance(obj, variables.NNModuleVariable):
            obj.convert_to_unspecialized(tx)

    def call_type(self, tx, obj: VariableTracker):
        from .builder import VariableBuilder

        try:
            py_type = obj.python_type()
        except NotImplementedError:
            py_type = None

        if istype(obj, variables.TupleVariable):
            return BuiltinVariable(py_type).add_options(self, obj)

        if py_type is not None and obj.source:
            return VariableBuilder(tx, TypeSource(obj.source))(py_type).add_options(
                self, obj
            )

        unimplemented(f"type({obj})")

    def call_reversed(self, tx, obj: VariableTracker):
        if obj.has_unpack_var_sequence(tx):
            items = list(reversed(obj.unpack_var_sequence(tx)))
            return variables.TupleVariable(
                items, **VariableTracker.propagate(self, obj)
            )

    def call_chain(self, tx, *args):
        if all(obj.has_unpack_var_sequence(tx) for obj in args):
            items = []
            for obj in args:
                items.extend(obj.unpack_var_sequence(tx))
            return variables.TupleVariable(
                items, **VariableTracker.propagate(self, *args)
            )

    def call_islice(self, tx, iterable, *args):
        if iterable.has_unpack_var_sequence(tx) and all(
            x.is_python_constant() for x in args
        ):
            const_args = [x.as_python_constant() for x in args]
            items = iterable.unpack_var_sequence(tx)
            items = list(itertools.islice(items, *const_args))
            return variables.TupleVariable(
                items, **VariableTracker.propagate(self, iterable, *args)
            )

    def call_id(self, tx, *args):
        if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable):
            nn_mod_variable = args[0]
            mod = tx.output.get_submodule(nn_mod_variable.module_key)
            return variables.ConstantVariable(id(mod))
        else:
            unimplemented(f"call_id with args {args}")
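
# Usage sketch (illustrative only, not exercised by this module itself): during
# symbolic tracing the translator wraps a builtin it encounters and routes the
# call through call_function.  With purely constant arguments the call is
# constant folded, roughly:
#
#     add_var = BuiltinVariable(operator.add)
#     out = add_var.call_function(
#         tx, [variables.ConstantVariable(1), variables.ConstantVariable(2)], {}
#     )
#     # out is a ConstantVariable holding 3
#
# With TensorVariable arguments the same call instead emits an FX
# "call_function" node via tx.output.create_proxy and returns a new
# TensorVariable.  Here `tx` stands for the active instruction translator and
# is assumed to be supplied by the surrounding tracing machinery.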