Mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
We want to get to a point where most UserErrors link to exportdb examples. This PR makes passing case names non-optional, to make that intent clearer and to encourage developers who raise UserErrors to create or point to examples that make fixing such errors more obvious for users. In addition, sometimes multiple examples are relevant to an error, so this PR also enables passing multiple case names.

Retry of #110733, which was reverted due to a land race.

Differential Revision: [D50087148](https://our.internmc.facebook.com/intern/diff/D50087148/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110878
Approved by: https://github.com/gmagogsfm, https://github.com/tugsbayasgalan
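For illustration, a minimal sketch of the call pattern this change standardizes, using the UserError/UserErrorType imports and case names that already appear in this file (the pairing of case names below is illustrative, not an actual call site):

    raise UserError(
        UserErrorType.STANDARD_LIBRARY,
        "Calling round() on symbolic value is not supported. "
        "You can use floor() to implement this functionality",
        # case_names is now required and can point to several exportdb examples
        case_names=["dynamic_shape_round", "type_reflection_method"],
    )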
1531 lines
56 KiB
Python
import functools
import inspect
import itertools
import logging
import math
import operator
import types
from typing import Dict, List

import torch
from torch import sym_float, sym_int

from .. import config, polyfill, variables
from ..allowed_functions import is_allowed
from ..exc import (
    AttributeMutationError,
    unimplemented,
    Unsupported,
    UserError,
    UserErrorType,
)
from ..guards import GuardBuilder
from ..replay_record import DummyModule
from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource
from ..utils import (
    build_checkpoint_variable,
    check_constant_args,
    check_numpy_ndarray_args,
    check_unspec_python_args,
    get_fake_value,
    guard_if_dyn,
    is_utils_checkpoint,
    istype,
    numpy_operator_wrapper,
    proxy_args_kwargs,
    specialize_args_kwargs,
)
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable, EnumVariable
from .dicts import ConstDictVariable
from .lists import (
    BaseListVariable,
    ListIteratorVariable,
    ListVariable,
    SetVariable,
    SizeVariable,
    TupleIteratorVariable,
    TupleVariable,
)
from .tensor import FakeItemVariable, SymNodeVariable, UnspecializedPythonVariable
from .user_defined import UserDefinedVariable

log = logging.getLogger(__name__)


class BuiltinVariable(VariableTracker):
|
|
@staticmethod
|
|
@functools.lru_cache(None)
|
|
def _constant_fold_functions():
|
|
fns = {
|
|
abs,
|
|
all,
|
|
any,
|
|
bool,
|
|
callable,
|
|
chr,
|
|
divmod,
|
|
float,
|
|
int,
|
|
len,
|
|
max,
|
|
min,
|
|
ord,
|
|
pow,
|
|
repr,
|
|
round,
|
|
str,
|
|
str.format,
|
|
sum,
|
|
type,
|
|
operator.pos,
|
|
operator.neg,
|
|
operator.not_,
|
|
operator.invert,
|
|
operator.pow,
|
|
operator.mul,
|
|
operator.matmul,
|
|
operator.floordiv,
|
|
operator.truediv,
|
|
operator.mod,
|
|
operator.add,
|
|
operator.sub,
|
|
operator.getitem,
|
|
operator.lshift,
|
|
operator.rshift,
|
|
operator.and_,
|
|
operator.or_,
|
|
operator.xor,
|
|
operator.ipow,
|
|
operator.imul,
|
|
operator.imatmul,
|
|
operator.ifloordiv,
|
|
operator.itruediv,
|
|
operator.imod,
|
|
operator.iadd,
|
|
operator.isub,
|
|
operator.ilshift,
|
|
operator.irshift,
|
|
operator.iand,
|
|
operator.ixor,
|
|
operator.ior,
|
|
operator.index,
|
|
}
|
|
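        # Also treat every function in the math module (anything with the same
        # C builtin-function type as math.sqrt) as constant-foldable.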
fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
|
|
return fns
|
|
|
|
def can_constant_fold_through(self):
|
|
return self.fn in self._constant_fold_functions()
|
|
|
|
@staticmethod
|
|
@functools.lru_cache(None)
|
|
def _fx_graph_functions():
|
|
fns = {
|
|
operator.pos,
|
|
operator.neg,
|
|
operator.not_,
|
|
operator.invert,
|
|
operator.pow,
|
|
operator.mul,
|
|
operator.matmul,
|
|
operator.floordiv,
|
|
operator.truediv,
|
|
operator.mod,
|
|
operator.add,
|
|
operator.sub,
|
|
operator.getitem,
|
|
operator.lshift,
|
|
operator.rshift,
|
|
operator.and_,
|
|
operator.or_,
|
|
operator.xor,
|
|
operator.ipow,
|
|
operator.imul,
|
|
operator.imatmul,
|
|
operator.ifloordiv,
|
|
operator.itruediv,
|
|
operator.imod,
|
|
operator.iadd,
|
|
operator.isub,
|
|
operator.ilshift,
|
|
operator.irshift,
|
|
operator.iand,
|
|
operator.ixor,
|
|
operator.ior,
|
|
}
|
|
return fns
|
|
|
|
@staticmethod
|
|
@functools.lru_cache(None)
|
|
def _binops():
|
|
# function -> ([forward name, reverse name, in-place name], in-place op)
|
|
fns = {
|
|
operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd),
|
|
operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub),
|
|
operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul),
|
|
operator.truediv: (
|
|
["__truediv__", "__rtruediv__", "__itruediv__"],
|
|
operator.itruediv,
|
|
),
|
|
operator.floordiv: (
|
|
["__floordiv__", "__rfloordiv__", "__ifloordiv__"],
|
|
operator.ifloordiv,
|
|
),
|
|
operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod),
|
|
pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
|
|
operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
|
|
operator.lshift: (
|
|
["__lshift__", "__rlshift__", "__ilshift__"],
|
|
operator.ilshift,
|
|
),
|
|
operator.rshift: (
|
|
["__rshift__", "__rrshift__", "__irshift__"],
|
|
operator.irshift,
|
|
),
|
|
            # NB: The following binary operators are not supported for now, since the
            # corresponding magic methods aren't defined on SymInt / SymFloat:
|
|
# operator.matmul
|
|
# divmod
|
|
# operator.and_
|
|
# operator.or_
|
|
# operator.xor
|
|
}
|
|
return fns
|
|
|
|
@staticmethod
|
|
@functools.lru_cache(None)
|
|
def _binop_handlers():
|
|
# Multiple dispatch mechanism defining custom binop behavior for certain type
|
|
# combinations. Handlers are attempted in order, and will be used if the type checks
|
|
# match. They are expected to have the signature:
|
|
# fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker
|
|
|
|
# Override table contains: op_fn -> [list of handlers]
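        #
        # For example (illustrative, mirroring the registrations below): for
        # operator.add the table ends up holding entries such as
        #   ((UserDefinedVariable, VariableTracker), user_defined_handler)
        #   ((TupleVariable, TupleVariable), tuple_add_handler)
        # and _find_binop_handler() walks the list in order, returning the first
        # handler whose isinstance() checks match both arguments.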
|
|
op_handlers = {}
|
|
for (
|
|
op,
|
|
(magic_method_names, in_place_op),
|
|
) in BuiltinVariable._binops().items():
|
|
op_handlers[op] = []
|
|
op_handlers[in_place_op] = []
|
|
|
|
forward_name, reverse_name, inplace_name = magic_method_names
|
|
|
|
# User-defined args (highest precedence)
|
|
def user_defined_handler(
|
|
tx,
|
|
a,
|
|
b,
|
|
options,
|
|
forward_name=forward_name,
|
|
reverse_name=reverse_name,
|
|
):
|
|
# Manually handle reversing logic if needed (e.g. call __radd__)
|
|
|
|
# TODO: If we expand this to handle tensor args, we need to manually
|
|
# handle cases like this:
|
|
#
|
|
# class A(int):
|
|
# def __radd__(self, other):
|
|
# print("woof")
|
|
# torch.randn(3) + A(3)
|
|
#
|
|
# In this example, A.__radd__() is not called -> nothing is printed, because
|
|
# Tensor.__add__ only does a subtype test against int, ignoring the subclass.
|
|
# To be fully correct, we should not call A.__radd__() here, and there may be
|
|
# other cases to reason about and add exceptions for.
|
|
if isinstance(a, UserDefinedVariable):
|
|
return a.call_method(tx, forward_name, [b], {})
|
|
else:
|
|
return b.call_method(tx, reverse_name, [a], {})
|
|
|
|
op_handlers[op].append(
|
|
((UserDefinedVariable, VariableTracker), user_defined_handler)
|
|
)
|
|
op_handlers[op].append(
|
|
((VariableTracker, UserDefinedVariable), user_defined_handler)
|
|
)
|
|
|
|
def user_defined_inplace_handler(
|
|
tx, a, b, options, forward_name=inplace_name
|
|
):
|
|
return a.call_method(tx, forward_name, [b], {})
|
|
|
|
op_handlers[in_place_op].append(
|
|
((UserDefinedVariable, VariableTracker), user_defined_inplace_handler)
|
|
)
|
|
op_handlers[in_place_op].append(
|
|
((VariableTracker, UserDefinedVariable), user_defined_inplace_handler)
|
|
)
|
|
|
|
# Dynamic shape args
|
|
def dynamic_handler(tx, a, b, options, fn=op):
|
|
from .builder import wrap_fx_proxy
|
|
|
|
return wrap_fx_proxy(
|
|
tx,
|
|
tx.output.create_proxy(
|
|
"call_function", fn, *proxy_args_kwargs([a, b], {})
|
|
),
|
|
**options,
|
|
)
|
|
|
|
op_handlers[op].append(
|
|
((SymNodeVariable, VariableTracker), dynamic_handler)
|
|
)
|
|
op_handlers[op].append(
|
|
((VariableTracker, SymNodeVariable), dynamic_handler)
|
|
)
|
|
|
|
            # NB: Prefer the out-of-place op when calling the in-place op, to generate a valid graph
|
|
op_handlers[in_place_op].append(
|
|
((SymNodeVariable, VariableTracker), dynamic_handler)
|
|
)
|
|
op_handlers[in_place_op].append(
|
|
((VariableTracker, SymNodeVariable), dynamic_handler)
|
|
)
|
|
|
|
# Special cases - lower precedence but still prefer these over constant folding
|
|
|
|
# List-like addition (e.g. [1, 2] + [3, 4])
|
|
def tuple_add_handler(tx, a, b, options):
|
|
return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
|
|
|
|
def size_add_handler(tx, a, b, options):
|
|
return SizeVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
|
|
|
|
list_like_addition_handlers = [
|
|
# NB: Prefer the tuple-specific logic over base logic because of
|
|
# some SizeVariable weirdness. Specifically, the tuple-specific logic
|
|
# drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
|
|
(
|
|
(SizeVariable, SizeVariable),
|
|
size_add_handler,
|
|
),
|
|
(
|
|
(TupleVariable, TupleVariable),
|
|
tuple_add_handler,
|
|
),
|
|
(
|
|
(TupleVariable, ConstantVariable),
|
|
tuple_add_handler,
|
|
),
|
|
(
|
|
(ConstantVariable, TupleVariable),
|
|
lambda tx, a, b, options: TupleVariable(
|
|
list(a.unpack_var_sequence(tx)) + b.items, **options
|
|
),
|
|
),
|
|
(
|
|
(BaseListVariable, BaseListVariable),
|
|
lambda tx, a, b, options: type(a)(a.items + b.items, **options),
|
|
),
|
|
]
|
|
op_handlers[operator.add].extend(list_like_addition_handlers)
|
|
|
|
def list_iadd_handler(tx, a, b, options):
|
|
if not a.mutable_local or not b.has_unpack_var_sequence(tx):
|
|
# Handler doesn't apply
|
|
return None
|
|
|
|
return tx.replace_all(
|
|
a,
|
|
ListVariable(
|
|
list(a.items) + list(b.unpack_var_sequence(tx)),
|
|
regen_guards=False,
|
|
**options,
|
|
),
|
|
)
|
|
|
|
list_like_iadd_handlers = [
|
|
(
|
|
(ListVariable, VariableTracker),
|
|
list_iadd_handler,
|
|
),
|
|
(
|
|
(TupleVariable, TupleVariable),
|
|
tuple_add_handler,
|
|
),
|
|
(
|
|
(TupleVariable, ConstantVariable),
|
|
tuple_add_handler,
|
|
),
|
|
]
|
|
op_handlers[operator.iadd].extend(list_like_iadd_handlers)
|
|
|
|
# List-like expansion (e.g. [1, 2, 3] * 3)
|
|
def expand_list_like(tx, lst, const, options):
|
|
return lst.__class__(
|
|
items=lst.items * const.as_python_constant(),
|
|
mutable_local=MutableLocal(),
|
|
**options,
|
|
)
|
|
|
|
list_like_expansion_handlers = [
|
|
((ListVariable, ConstantVariable), expand_list_like),
|
|
((TupleVariable, ConstantVariable), expand_list_like),
|
|
(
|
|
(ConstantVariable, ListVariable),
|
|
lambda tx, a, b, options: expand_list_like(tx, b, a, options),
|
|
),
|
|
(
|
|
(ConstantVariable, TupleVariable),
|
|
lambda tx, a, b, options: expand_list_like(tx, b, a, options),
|
|
),
|
|
]
|
|
op_handlers[operator.mul].extend(list_like_expansion_handlers)
|
|
|
|
return op_handlers
|
|
|
|
@staticmethod
|
|
def _find_binop_handler(op, a, b):
|
|
handlers = BuiltinVariable._binop_handlers()
|
|
if op not in handlers:
|
|
return None
|
|
|
|
# Return first handler that matches the type checks
|
|
for (type1, type2), handler in handlers[op]:
|
|
if isinstance(a, type1) and isinstance(b, type2):
|
|
return handler
|
|
|
|
return None
|
|
|
|
def can_insert_in_graph(self):
|
|
return self.fn in self._fx_graph_functions()
|
|
|
|
def __init__(self, fn, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.fn = fn
|
|
|
|
def __str__(self):
|
|
if self.fn is None:
|
|
name = "None"
|
|
else:
|
|
name = self.fn.__name__
|
|
|
|
return f"{self.__class__.__name__}({name})"
|
|
|
|
def python_type(self):
|
|
return type(self.fn)
|
|
|
|
def as_python_constant(self):
|
|
return self.fn
|
|
|
|
def as_proxy(self):
|
|
DTYPE = {
|
|
bool: torch.bool,
|
|
int: torch.int64,
|
|
float: torch.float64,
|
|
}
|
|
if self.fn in DTYPE:
|
|
return DTYPE[self.fn]
|
|
return super().as_proxy()
|
|
|
|
def reconstruct(self, codegen):
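        # Rebuild the builtin by re-loading it by name from the global scope;
        # the asserts below check that it really is an unshadowed builtin.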
|
|
name = self.fn.__name__
|
|
assert self.fn.__module__ == "builtins"
|
|
assert name not in codegen.tx.f_globals, "shadowed global"
|
|
return [codegen.create_load_global(name, False, add=True)]
|
|
|
|
def constant_args(self, *args, **kwargs):
|
|
return check_constant_args(args, kwargs)
|
|
|
|
def tensor_args(self, *args, **kwargs):
|
|
return any(
|
|
isinstance(i, variables.TensorVariable)
|
|
for i in itertools.chain(args, kwargs.values())
|
|
) and not any(
|
|
isinstance(i, variables.GetAttrVariable)
|
|
for i in itertools.chain(args, kwargs.values())
|
|
)
|
|
|
|
def unspec_python_args(self, *args, **kwargs):
|
|
return check_unspec_python_args(args, kwargs)
|
|
|
|
    @staticmethod
    def unwrap_unspec_args_kwargs(args, kwargs):
        unwrapped_args = []
        unwrapped_kwargs = {}
        for x in args:
            if isinstance(
                x,
                (variables.UnspecializedPythonVariable,),
            ):
                unwrapped_args.append(x.raw_value)
            else:
                unwrapped_args.append(x.as_python_constant())
        for k, v in kwargs.items():
            if isinstance(
                v,
                (variables.UnspecializedPythonVariable,),
            ):
                unwrapped_kwargs.update({k: v.raw_value})
            else:
                unwrapped_kwargs.update({k: v.as_python_constant()})
        return unwrapped_args, unwrapped_kwargs
|
|
|
|
def call_function(
|
|
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
|
|
) -> "VariableTracker":
|
|
from . import UserFunctionVariable
|
|
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
|
|
|
|
constant_args = check_constant_args(args, kwargs)
|
|
tensor_args = self.tensor_args(*args, **kwargs)
|
|
unspec_python_args = self.unspec_python_args(*args, **kwargs)
|
|
options = VariableTracker.propagate(self, args, kwargs.values())
|
|
has_constant_handler = self.can_constant_fold_through() and (
|
|
constant_args or unspec_python_args
|
|
)
|
|
assert isinstance(args, (list, tuple))
|
|
assert isinstance(kwargs, dict)
|
|
|
|
        # e.g. args[0] is a list and args[1] is unspec: getitem on a non-tensor
        # container shouldn't be traced as a tensor op, so specialize the args.
|
|
if self.fn is operator.getitem and not isinstance(
|
|
args[0], variables.TensorVariable
|
|
):
|
|
tensor_args = False
|
|
args, kwargs = specialize_args_kwargs(tx, args, kwargs)
|
|
|
|
if (
|
|
self.can_insert_in_graph()
|
|
and tensor_args
|
|
and not (
|
|
self.fn is operator.getitem
|
|
and isinstance(args[0], ConstDictVariable)
|
|
and isinstance(args[1], variables.TensorVariable)
|
|
)
|
|
):
|
|
try:
|
|
fn = self.fn
|
|
if self.fn is operator.iadd and isinstance(
|
|
args[0], variables.ConstantVariable
|
|
):
|
|
# Work around weird bug in hf_T5
|
|
fn, args = operator.add, [args[1], args[0]]
|
|
|
|
if self.fn is operator.getitem and isinstance(args[1], SymNodeVariable):
|
|
# Standard indexing will force specialization due to
|
|
# __index__. Rewrite as a regular torch op which will
|
|
# trace fine
|
|
fn, args = torch.select, [
|
|
args[0],
|
|
variables.ConstantVariable.create(0),
|
|
args[1],
|
|
]
|
|
|
|
# Interaction between ndarray and tensors:
|
|
# We prefer the tensor op whenever there are tensors involved
|
|
if check_numpy_ndarray_args(args, kwargs) and not any(
|
|
type(arg) == variables.TensorVariable for arg in args
|
|
):
|
|
proxy = tx.output.create_proxy(
|
|
"call_function",
|
|
numpy_operator_wrapper(self.fn),
|
|
*proxy_args_kwargs(args, kwargs),
|
|
)
|
|
|
|
return wrap_fx_proxy_cls(
|
|
variables.NumpyNdarrayVariable, tx, proxy, **options
|
|
)
|
|
|
|
proxy = tx.output.create_proxy(
|
|
"call_function",
|
|
fn,
|
|
*proxy_args_kwargs(args, kwargs),
|
|
)
|
|
if any(isinstance(arg, FakeItemVariable) for arg in args):
|
|
return wrap_fx_proxy_cls(
|
|
FakeItemVariable,
|
|
tx,
|
|
proxy,
|
|
**options,
|
|
)
|
|
elif self.unspec_python_args(*args, **kwargs):
|
|
_args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
|
|
raw_value = self.fn(*_args, **_kwargs)
|
|
|
|
need_unwrap = any(
|
|
x.need_unwrap
|
|
for x in itertools.chain(args, kwargs.values())
|
|
if isinstance(x, variables.UnspecializedPythonVariable)
|
|
)
|
|
|
|
return wrap_fx_proxy_cls(
|
|
UnspecializedPythonVariable,
|
|
tx,
|
|
proxy,
|
|
raw_value=raw_value,
|
|
need_unwrap=need_unwrap,
|
|
**options,
|
|
)
|
|
elif all(isinstance(x, SymNodeVariable) for x in args):
|
|
return SymNodeVariable.create(tx, proxy, None, **options)
|
|
else:
|
|
                    # Workaround for vision_maskrcnn due to a precision difference:
                    # specialize the dividend when a float is divided by a tensor
|
|
if self.fn is operator.truediv and isinstance(
|
|
args[0], variables.UnspecializedPythonVariable
|
|
):
|
|
args[0] = args[0].convert_to_constant(tx)
|
|
return wrap_fx_proxy(tx, proxy, **options)
|
|
|
|
except NotImplementedError:
|
|
unimplemented(f"partial tensor op: {self} {args} {kwargs}")
|
|
|
|
# Handle cases like int(torch.seed())
|
|
# Also handle sym_float to sym_int cases
|
|
if self.fn in (int, float) and isinstance(args[0], SymNodeVariable):
|
|
fn_ = sym_int if self.fn is int else sym_float
|
|
out = wrap_fx_proxy(
|
|
tx=tx,
|
|
proxy=tx.output.create_proxy(
|
|
"call_function",
|
|
fn_,
|
|
(args[0].as_proxy(),),
|
|
{},
|
|
),
|
|
**options,
|
|
)
|
|
return out
|
|
|
|
# Handle `str` on a user defined function
|
|
if self.fn == str and args and isinstance(args[0], (UserFunctionVariable)):
|
|
return variables.ConstantVariable.create(value=str(args[0].fn))
|
|
|
|
# Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.)
|
|
# NB: Tensor args are handled above and not here
|
|
if len(kwargs) == 0 and len(args) == 2:
|
|
# Try to find a handler for the arg types; otherwise, fall through to constant handler
|
|
binop_handler = BuiltinVariable._find_binop_handler(
|
|
self.fn, args[0], args[1]
|
|
)
|
|
if binop_handler:
|
|
res = binop_handler(tx, args[0], args[1], options)
|
|
if res is not None:
|
|
return res
|
|
|
|
handler = getattr(self, f"call_{self.fn.__name__}", None)
|
|
if handler:
|
|
try:
|
|
inspect.signature(handler).bind(tx, *args, **kwargs)
|
|
except TypeError as exc:
|
|
if not has_constant_handler:
|
|
log.warning(
|
|
"incorrect arg count %s %s and no constant handler",
|
|
handler,
|
|
exc,
|
|
)
|
|
handler = None
|
|
|
|
if handler:
|
|
try:
|
|
result = handler(tx, *args, **kwargs)
|
|
if result is not None:
|
|
return result.add_options(options)
|
|
except Unsupported as exc:
|
|
if not has_constant_handler:
|
|
raise
|
|
# Actually, we will handle this just fine
|
|
exc.remove_from_stats()
|
|
|
|
if has_constant_handler:
|
|
args, kwargs = specialize_args_kwargs(tx, args, kwargs)
|
|
# constant fold
|
|
return variables.ConstantVariable.create(
|
|
self.as_python_constant()(
|
|
*[x.as_python_constant() for x in args],
|
|
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
|
),
|
|
**options,
|
|
)
|
|
|
|
if self.fn is round:
|
|
if len(args) > 0 and isinstance(args[0], SymNodeVariable):
|
|
raise UserError(
|
|
UserErrorType.STANDARD_LIBRARY,
|
|
"Calling round() on symbolic value is not supported. "
|
|
"You can use floor() to implement this functionality",
|
|
case_names=["dynamic_shape_round"],
|
|
)
|
|
return super().call_function(tx, args, kwargs)
|
|
|
|
def _call_min_max(self, tx, *args):
|
|
if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
|
|
# expand iterable
|
|
items = args[0].unpack_var_sequence(tx)
|
|
return self._call_min_max_seq(tx, items)
|
|
elif len(args) == 2:
|
|
return self._call_min_max_binary(tx, args[0], args[1])
|
|
elif len(args) > 2:
|
|
return self._call_min_max_seq(tx, args)
|
|
|
|
def _call_min_max_seq(self, tx, items):
|
|
assert len(items) > 0
|
|
if len(items) == 1:
|
|
return items[0]
|
|
|
|
return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)
|
|
|
|
def _call_min_max_binary(self, tx, a, b):
|
|
if self.tensor_args(a, b):
|
|
if not isinstance(a, variables.TensorVariable):
|
|
a, b = b, a
|
|
assert isinstance(a, variables.TensorVariable)
|
|
|
|
            # the result of an item() call is a scalar; convert it to a tensor
|
|
if isinstance(a, FakeItemVariable):
|
|
a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {})
|
|
|
|
            # Dynamic input does not get resolved; rather, it gets stored as a call_function node
|
|
if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
|
|
from .builder import wrap_fx_proxy_cls
|
|
|
|
return wrap_fx_proxy_cls(
|
|
type(a),
|
|
tx=tx,
|
|
proxy=tx.output.create_proxy(
|
|
"call_function",
|
|
self.fn,
|
|
*proxy_args_kwargs([a, b], {}),
|
|
),
|
|
**VariableTracker.propagate(self, [a, b]),
|
|
)
|
|
|
|
# convert min/max to torch ops
|
|
if b.is_python_constant():
|
|
if isinstance(a, variables.NumpyNdarrayVariable):
|
|
import numpy as np
|
|
|
|
fn = variables.NumpyVariable(np.clip)
|
|
else:
|
|
fn = variables.TorchVariable(torch.clamp)
|
|
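                # max(a, b) is a clamp from below (min=b); min(a, b) is a clamp from above (max=b)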
kwargs = {"min": b} if (self.fn is max) else {"max": b}
|
|
result = fn.call_function(tx, [a], kwargs)
|
|
else:
|
|
if isinstance(a, variables.NumpyNdarrayVariable):
|
|
import numpy as np
|
|
|
|
fn = {max: np.maximum, min: np.minimum}[self.fn]
|
|
fn = variables.NumpyVariable(fn)
|
|
else:
|
|
fn = {max: torch.maximum, min: torch.minimum}[self.fn]
|
|
fn = variables.TorchVariable(fn)
|
|
result = fn.call_function(tx, [a, b], {})
|
|
|
|
# return unspec if both a, b are unspec or const
|
|
if all(
|
|
isinstance(
|
|
i,
|
|
(
|
|
variables.UnspecializedPythonVariable,
|
|
variables.ConstantVariable,
|
|
),
|
|
)
|
|
for i in [a, b]
|
|
):
|
|
if any(isinstance(val, FakeItemVariable) for val in [a, b]):
|
|
return variables.FakeItemVariable.from_tensor_variable(result)
|
|
|
|
if b.is_python_constant():
|
|
raw_b = b.as_python_constant()
|
|
else:
|
|
raw_b = b.raw_value
|
|
if self.fn is max:
|
|
raw_res = max(a.raw_value, raw_b)
|
|
else:
|
|
raw_res = min(a.raw_value, raw_b)
|
|
|
|
need_unwrap = any(
|
|
x.need_unwrap
|
|
for x in [a, b]
|
|
if isinstance(x, variables.UnspecializedPythonVariable)
|
|
)
|
|
return variables.UnspecializedPythonVariable.from_tensor_variable(
|
|
result, raw_res, need_unwrap
|
|
)
|
|
# otherwise return tensor
|
|
else:
|
|
return result
|
|
elif isinstance(a, variables.ConstantVariable) and isinstance(
|
|
b, variables.ConstantVariable
|
|
):
|
|
if self.fn is max:
|
|
return variables.ConstantVariable.create(max(a.value, b.value))
|
|
else:
|
|
return variables.ConstantVariable.create(min(a.value, b.value))
|
|
elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
|
|
proxy = tx.output.create_proxy(
|
|
"call_function", self.fn, *proxy_args_kwargs([a, b], {})
|
|
)
|
|
return SymNodeVariable.create(tx, proxy, None)
|
|
else:
|
|
unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}")
|
|
|
|
call_min = _call_min_max
|
|
call_max = _call_min_max
|
|
|
|
def call_abs(self, tx, arg: "VariableTracker"):
|
|
# Call arg.__abs__()
|
|
abs_method = BuiltinVariable(getattr).call_function(
|
|
tx, [arg, ConstantVariable.create("__abs__")], {}
|
|
)
|
|
return abs_method.call_function(tx, [], {})
|
|
|
|
def call_range(self, tx, *args):
|
|
if self.unspec_python_args(*args) or self.constant_args(*args):
|
|
args, _ = specialize_args_kwargs(tx, args, {})
|
|
return variables.RangeVariable(args)
|
|
elif self._dynamic_args(*args):
|
|
args = [
|
|
variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args
|
|
]
|
|
return variables.RangeVariable(args)
|
|
# None no-ops this handler and lets the driving function proceed
|
|
return None
|
|
|
|
def _dynamic_args(self, *args, **kwargs):
|
|
return any(isinstance(x, SymNodeVariable) for x in args) or any(
|
|
isinstance(x, SymNodeVariable) for x in kwargs.values()
|
|
)
|
|
|
|
def call_slice(self, tx, *args):
|
|
return variables.SliceVariable(args)
|
|
|
|
def _dyn_proxy(self, tx, *args, **kwargs):
|
|
from .builder import wrap_fx_proxy
|
|
|
|
options = VariableTracker.propagate(self, args, kwargs.values())
|
|
return wrap_fx_proxy(
|
|
tx,
|
|
tx.output.create_proxy(
|
|
"call_function", self.fn, *proxy_args_kwargs(args, kwargs)
|
|
),
|
|
**options,
|
|
)
|
|
|
|
def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
|
|
if self._dynamic_args(*args, **kwargs):
|
|
return self._dyn_proxy(tx, *args, **kwargs)
|
|
cls = variables.BaseListVariable.cls_for(self.fn)
|
|
        if obj is None:
            # Note: the SetVariable case and the list/tuple case build the same
            # empty variable here.
            return cls(
                [],
                mutable_local=MutableLocal(),
            )
|
|
elif obj.has_unpack_var_sequence(tx):
|
|
guards = set()
|
|
if obj.source and not is_constant_source(obj.source):
|
|
if isinstance(obj, TupleIteratorVariable):
|
|
guards.add(obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN))
|
|
else:
|
|
guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
|
|
            # The SetVariable case and the general case are built identically here.
            return cls(
                list(obj.unpack_var_sequence(tx)),
                mutable_local=MutableLocal(),
                guards=guards,
            ).add_options(self, obj)
|
|
|
|
call_iter = _call_iter_tuple_list
|
|
call_tuple = _call_iter_tuple_list
|
|
call_list = _call_iter_tuple_list
|
|
call_set = _call_iter_tuple_list
|
|
|
|
@staticmethod
|
|
def is_supported_call_dict_arg(tx, arg):
|
|
return (
|
|
arg is None
|
|
or isinstance(arg, ConstDictVariable)
|
|
or (
|
|
isinstance(
|
|
arg,
|
|
(
|
|
ListVariable,
|
|
TupleVariable,
|
|
ListIteratorVariable,
|
|
),
|
|
)
|
|
and all(
|
|
isinstance(x, (ListVariable, TupleVariable))
|
|
and isinstance(
|
|
x.unpack_var_sequence(tx)[0], (ConstantVariable, EnumVariable)
|
|
)
|
|
for x in arg.unpack_var_sequence(tx)
|
|
)
|
|
)
|
|
)
|
|
|
|
def call_callable(self, tx, arg):
|
|
from .functions import BaseUserFunctionVariable
|
|
|
|
if isinstance(
|
|
arg, (variables.UserDefinedClassVariable, BaseUserFunctionVariable)
|
|
):
|
|
return variables.ConstantVariable.create(True).add_options(arg)
|
|
|
|
@staticmethod
|
|
def call_dict_helper(tx, user_cls, arg, **options):
|
|
if arg is None or isinstance(arg, dict):
|
|
return ConstDictVariable(
|
|
arg if arg is not None else {}, user_cls, mutable_local=MutableLocal()
|
|
).add_options(options)
|
|
elif isinstance(arg, variables.ConstDictVariable):
|
|
return arg.clone(
|
|
user_cls=user_cls, mutable_local=MutableLocal()
|
|
).add_options(options)
|
|
elif isinstance(
|
|
arg,
|
|
(
|
|
ListVariable,
|
|
TupleVariable,
|
|
ListIteratorVariable,
|
|
),
|
|
):
|
|
items = user_cls()
|
|
for x in arg.unpack_var_sequence(tx):
|
|
k = x.unpack_var_sequence(tx)[0].as_python_constant()
|
|
v = x.unpack_var_sequence(tx)[1]
|
|
items.update({k: v})
|
|
return ConstDictVariable(
|
|
items, user_cls, mutable_local=MutableLocal()
|
|
).add_options(options)
|
|
else:
|
|
raise AssertionError("call_dict_helper with illegal arg")
|
|
|
|
def call_cast(self, _, *args, **kwargs):
|
|
if len(args) == 2:
|
|
return args[1]
|
|
|
|
unimplemented(f"unsupported args to builtin cast(): {args} {kwargs}")
|
|
|
|
def call_dict(self, tx, *args, **kwargs):
|
|
if not (args or kwargs):
|
|
return self.call_dict_helper(tx, dict, None)
|
|
elif (
|
|
not kwargs
|
|
and len(args) == 1
|
|
and self.is_supported_call_dict_arg(tx, args[0])
|
|
):
|
|
return self.call_dict_helper(tx, dict, args[0])
|
|
elif not args and kwargs:
|
|
return variables.ConstDictVariable(
|
|
dict(kwargs), user_cls=dict, mutable_local=MutableLocal()
|
|
)
|
|
else:
|
|
unimplemented(f"dict(): {args} {kwargs}")
|
|
|
|
def call_zip(self, tx, *args):
|
|
options = VariableTracker.propagate(self, args)
|
|
if all(x.has_unpack_var_sequence(tx) for x in args):
|
|
items = [
|
|
variables.TupleVariable(list(item), **options)
|
|
for item in zip(*[arg.unpack_var_sequence(tx) for arg in args])
|
|
]
|
|
return variables.TupleVariable(items, **options)
|
|
|
|
def call_enumerate(self, tx, *args):
|
|
options = VariableTracker.propagate(self, args)
|
|
if len(args) == 1:
|
|
start = 0
|
|
else:
|
|
assert len(args) == 2
|
|
assert isinstance(args[1], variables.ConstantVariable)
|
|
start = args[1].as_python_constant()
|
|
if args[0].has_unpack_var_sequence(tx):
|
|
items = [
|
|
variables.TupleVariable(
|
|
[variables.ConstantVariable.create(idx, **options), var],
|
|
**options,
|
|
)
|
|
for idx, var in enumerate(args[0].unpack_var_sequence(tx), start)
|
|
]
|
|
return variables.TupleVariable(items, **options)
|
|
|
|
def call_len(self, tx, *args, **kwargs):
|
|
return args[0].call_method(tx, "__len__", args[1:], kwargs)
|
|
|
|
def call_getitem(self, tx, *args, **kwargs):
|
|
if self.unspec_python_args(*args, **kwargs):
|
|
args, kwargs = specialize_args_kwargs(tx, args, kwargs)
|
|
return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
|
|
|
|
def call_isinstance(self, tx, arg, isinstance_type):
|
|
arg_type = arg.python_type()
|
|
|
|
isinstance_type = isinstance_type.as_python_constant()
|
|
|
|
if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
|
|
return variables.ConstantVariable.create(
|
|
arg.call_isinstance(isinstance_type)
|
|
)
|
|
# UserDefinedObject with C extensions can have torch.Tensor attributes,
|
|
# so break graph.
|
|
if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
|
|
arg.value, types.MemberDescriptorType
|
|
):
|
|
unimplemented(
|
|
f"isinstance called on UserDefinedClass {arg} {isinstance_type}"
|
|
)
|
|
# handle __instancecheck__ defined in user class
|
|
if (
|
|
isinstance(arg, variables.UserDefinedObjectVariable)
|
|
and "__instancecheck__" in isinstance_type.__class__.__dict__
|
|
):
|
|
return variables.ConstantVariable.create(
|
|
isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value)
|
|
)
|
|
|
|
try:
|
|
val = issubclass(arg_type, isinstance_type)
|
|
except TypeError:
|
|
val = arg_type is isinstance_type
|
|
return variables.ConstantVariable.create(val)
|
|
|
|
def call_super(self, tx, a, b):
|
|
return variables.SuperVariable(a, b)
|
|
|
|
def call_next(self, tx, arg):
|
|
if isinstance(arg, variables.ListIteratorVariable):
|
|
val, next_iter = arg.next_variables()
|
|
tx.replace_all(arg, next_iter)
|
|
return val
|
|
elif isinstance(arg, variables.BaseListVariable):
|
|
return arg.items[0].add_options(self, arg)
|
|
|
|
def call_hasattr(self, tx, obj, attr):
|
|
if attr.is_python_constant():
|
|
name = attr.as_python_constant()
|
|
return obj.call_hasattr(tx, name).add_options(self, obj, attr)
|
|
|
|
def call_map(self, tx, fn, seq):
|
|
if seq.has_unpack_var_sequence(tx):
|
|
items = [fn.call_function(tx, [x], {}) for x in seq.unpack_var_sequence(tx)]
|
|
return variables.TupleVariable(items).add_options(self, fn, seq)
|
|
|
|
def call_sum(self, tx, seq, **kwargs):
|
|
        # Special case for sum() on a list/tuple of ints and floats
|
|
if (
|
|
isinstance(seq, (variables.ListVariable, variables.TupleVariable))
|
|
and all(
|
|
isinstance(x, variables.ConstantVariable)
|
|
and isinstance(x.value, (int, float))
|
|
for x in seq.items
|
|
)
|
|
and not kwargs
|
|
):
|
|
new_list = [x.value for x in seq.items]
|
|
return variables.ConstantVariable.create(sum(new_list))
|
|
        if seq.has_unpack_var_sequence(tx):
            # sum(seq, start) is start + item0 + item1 + ...; lower it to
            # functools.reduce(operator.add, items, start) so each addition goes
            # through the regular binop handling.
            start = kwargs.pop(
                "start", variables.ConstantVariable.create(0)
            ).as_python_constant()
            assert not kwargs
            items = seq.unpack_var_sequence(tx)
            return BuiltinVariable(functools.reduce).call_function(
                tx,
                [
                    BuiltinVariable(operator.add),
                    variables.TupleVariable(items),
                    variables.ConstantVariable.create(start).add_options(self, seq),
                ],
                {},
            )
|
|
|
|
def call_reduce(self, tx, function, iterable, initializer=None):
|
|
if iterable.has_unpack_var_sequence(tx):
|
|
items = iterable.unpack_var_sequence(tx)
|
|
if initializer is None:
|
|
value, items = items[0], items[1:]
|
|
else:
|
|
value = initializer
|
|
for element in items:
|
|
value = function.call_function(tx, [value, element], {})
|
|
return value
|
|
|
|
def call_getattr(
|
|
self, tx, obj: VariableTracker, name_var: VariableTracker, default=None
|
|
):
|
|
from . import (
|
|
ConstantVariable,
|
|
GetAttrVariable,
|
|
PythonModuleVariable,
|
|
TorchVariable,
|
|
UserFunctionVariable,
|
|
)
|
|
from .builder import SourcelessBuilder, VariableBuilder
|
|
|
|
options = VariableTracker.propagate(self, obj, name_var)
|
|
guards = options["guards"]
|
|
name = name_var.as_python_constant()
|
|
|
|
if not name_var.is_python_constant():
|
|
unimplemented("non-const getattr() name")
|
|
|
|
if tx.output.side_effects.is_attribute_mutation(obj):
|
|
try:
|
|
# re-read a pending side effect?
|
|
return tx.output.side_effects.load_attr(obj, name).add_options(options)
|
|
except KeyError:
|
|
pass
|
|
|
|
if default is not None:
|
|
hasattr_var = self.call_hasattr(tx, obj, name_var)
|
|
guards.update(hasattr_var.guards)
|
|
assert hasattr_var.as_python_constant() in (True, False)
|
|
if not hasattr_var.as_python_constant():
|
|
return default.add_guards(guards)
|
|
|
|
if obj.source:
|
|
source = AttrSource(obj.source, name)
|
|
options["source"] = source
|
|
else:
|
|
source = None
|
|
|
|
if name == "__bases__":
|
|
try:
|
|
value = obj.as_python_constant()
|
|
if isinstance(value, type):
|
|
bases = value.__bases__
|
|
if source is not None:
|
|
tuple_args = [
|
|
VariableBuilder(tx, GetItemSource(source, i))(b)
|
|
for i, b in enumerate(bases)
|
|
]
|
|
elif len(bases) == 1 and (
|
|
bases[0] is object or bases[0] is torch._C.TensorBase
|
|
):
|
|
tuple_args = [SourcelessBuilder()(tx, bases[0])]
|
|
else:
|
|
unimplemented(f"unexpected sourceless type bases: {bases}")
|
|
|
|
return variables.TupleVariable(tuple_args, **options)
|
|
except NotImplementedError:
|
|
pass
|
|
|
|
if isinstance(obj, variables.NNModuleVariable):
|
|
return obj.var_getattr(tx, name).add_options(options)
|
|
elif isinstance(obj, variables.TensorVariable) and name == "grad":
|
|
if source:
|
|
# We are going to be raising this tensor as grapharg. So, ensure
|
|
# that we have real grad value instead of fake tensor value.
|
|
# Walk through the inputs of the subgraph and find if we already
|
|
# have the original tensor stored in the graphargs.
|
|
for grapharg in tx.output.graphargs:
|
|
if grapharg.source == source.base:
|
|
example_value = grapharg.example.grad
|
|
return VariableBuilder(tx, source)(example_value).add_options(
|
|
options
|
|
)
|
|
unimplemented("tensor grad")
|
|
else:
|
|
unimplemented("tensor grad")
|
|
elif isinstance(
|
|
obj,
|
|
(
|
|
variables.TensorVariable,
|
|
variables.NamedTupleVariable,
|
|
variables.ConstantVariable,
|
|
variables.UserDefinedClassVariable,
|
|
variables.UserDefinedObjectVariable,
|
|
),
|
|
):
|
|
try:
|
|
return (
|
|
obj.var_getattr(tx, name).clone(source=source).add_options(options)
|
|
)
|
|
except NotImplementedError:
|
|
return GetAttrVariable(obj, name, **options)
|
|
elif isinstance(obj, TorchVariable):
|
|
member = getattr(obj.value, name)
|
|
if is_utils_checkpoint(member):
|
|
options["source"] = source
|
|
return build_checkpoint_variable(**options)
|
|
elif is_allowed(member):
|
|
return TorchVariable(member, **options)
|
|
elif ConstantVariable.is_literal(member):
|
|
return ConstantVariable.create(member, **options)
|
|
else:
|
|
return VariableBuilder(tx, source)(member).add_guards(guards)
|
|
elif isinstance(obj, (PythonModuleVariable, DummyModule)):
|
|
member = obj.value.__dict__[name]
|
|
|
|
if config.replay_record_enabled:
|
|
tx.exec_recorder.record_module_access(obj.value, name, member)
|
|
|
|
return VariableBuilder(tx, source)(member).add_guards(guards)
|
|
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
|
|
return ConstantVariable.create(
|
|
getattr(obj.fn, name), **VariableTracker.propagate(obj)
|
|
)
|
|
else:
|
|
try:
|
|
return (
|
|
obj.var_getattr(tx, name).clone(source=source).add_options(options)
|
|
)
|
|
except NotImplementedError:
|
|
return GetAttrVariable(obj, name, **options)
|
|
|
|
def call_setattr(
|
|
self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker
|
|
):
|
|
from .distributed import PlacementVariable
|
|
|
|
if isinstance(
|
|
obj,
|
|
(
|
|
variables.DataClassVariable,
|
|
variables.CustomizedDictVariable,
|
|
PlacementVariable,
|
|
),
|
|
):
|
|
return obj.call_method(tx, "__setattr__", [name_var, val], {})
|
|
elif (
|
|
tx.output.side_effects.is_attribute_mutation(obj)
|
|
and name_var.is_python_constant()
|
|
):
|
|
tx.output.side_effects.store_attr(obj, name_var.as_python_constant(), val)
|
|
return val.add_options(self, obj, name_var)
|
|
elif isinstance(obj, variables.UserDefinedObjectVariable):
|
|
unimplemented(
|
|
f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}"
|
|
)
|
|
elif isinstance(obj, variables.NNModuleVariable):
|
|
if not tx.output.is_root_tracer():
|
|
raise AttributeMutationError(
|
|
"Can't inplace modify module params/buffers inside HigherOrderOp"
|
|
)
|
|
if name_var.is_python_constant() and isinstance(
|
|
val, variables.TensorVariable
|
|
):
|
|
assigning_fake_val = get_fake_value(val.as_proxy().node, tx)
|
|
|
|
try:
|
|
getattr_var = obj.var_getattr(tx, name_var.as_python_constant())
|
|
except AttributeError:
|
|
getattr_var = None
|
|
|
|
if isinstance(getattr_var, variables.TensorVariable):
|
|
                    # get_fake_value will return a real tensor here because it's an
                    # attribute on the module (get_attr node)
|
|
existing_attr = get_fake_value(getattr_var.as_proxy().node, tx)
|
|
existing_fake_attr = (
|
|
variables.builder.wrap_to_fake_tensor_and_record(
|
|
existing_attr, tx, source=getattr_var.source, is_tensor=True
|
|
)
|
|
)
|
|
|
|
                    # same tensor identity, so the setattr is a no-op
|
|
mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__")
|
|
if (
|
|
existing_fake_attr is assigning_fake_val
|
|
and mod_setattr is torch.nn.Module.__setattr__
|
|
):
|
|
return getattr_var
|
|
|
|
obj.convert_to_unspecialized(tx)
|
|
        # FIXME (tmanlaibaatar) this is an utter hack to unblock HuggingFace export.
        # Export generally doesn't want to allow mutations on objects directly, but
        # we don't have a good way to enforce that right now, so for now we treat it
        # as undefined behaviour and just set attributes directly on the
        # PretrainedConfig object.
|
|
elif isinstance(obj, variables.dicts.HFPretrainedConfigVariable) and tx.export:
|
|
if name_var.is_python_constant() and isinstance(
|
|
val, variables.ConstantVariable
|
|
):
|
|
setattr(
|
|
obj.obj, name_var.as_python_constant(), val.as_python_constant()
|
|
)
|
|
return ConstantVariable(None)
|
|
|
|
def call_delattr(self, tx, obj: VariableTracker, name_var: VariableTracker):
|
|
return self.call_setattr(tx, obj, name_var, variables.DeletedVariable())
|
|
|
|
def call_type(self, tx, obj: VariableTracker):
|
|
from .builder import VariableBuilder
|
|
|
|
try:
|
|
py_type = obj.python_type()
|
|
except NotImplementedError:
|
|
py_type = None
|
|
|
|
if istype(obj, variables.TupleVariable):
|
|
return BuiltinVariable(py_type).add_options(self, obj)
|
|
|
|
if py_type is not None and obj.source:
|
|
return VariableBuilder(tx, TypeSource(obj.source))(py_type).add_options(
|
|
self, obj
|
|
)
|
|
|
|
if py_type is not None:
|
|
return ConstantVariable.create(py_type)
|
|
|
|
raise UserError(
|
|
UserErrorType.ANTI_PATTERN,
|
|
f"Can't call type() on generated custom object {obj}. "
|
|
"Please use __class__ instead",
|
|
case_names=["type_reflection_method"],
|
|
)
|
|
|
|
def call_reversed(self, tx, obj: VariableTracker):
|
|
if obj.has_unpack_var_sequence(tx):
|
|
items = list(reversed(obj.unpack_var_sequence(tx)))
|
|
return variables.TupleVariable(
|
|
items, **VariableTracker.propagate(self, obj)
|
|
)
|
|
|
|
def call_sorted(self, tx, obj: VariableTracker, **kwargs):
|
|
if (
|
|
obj.has_unpack_var_sequence(tx)
|
|
and not isinstance(obj, variables.TensorVariable)
|
|
and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx))
|
|
):
|
|
function = kwargs.pop("key", None)
|
|
reverse = kwargs.pop(
|
|
"reverse", ConstantVariable.create(False)
|
|
).as_python_constant()
|
|
assert len(kwargs) == 0
|
|
if function:
|
|
items = sorted(
|
|
obj.unpack_var_sequence(tx),
|
|
key=lambda x: function.call_function(
|
|
tx, [x], {}
|
|
).as_python_constant(),
|
|
reverse=reverse,
|
|
)
|
|
else:
|
|
items = sorted(
|
|
obj.unpack_var_sequence(tx),
|
|
key=lambda x: x.as_python_constant(),
|
|
reverse=reverse,
|
|
)
|
|
return variables.ListVariable(items, **VariableTracker.propagate(self, obj))
|
|
|
|
def call_chain(self, tx, *args):
|
|
if all(obj.has_unpack_var_sequence(tx) for obj in args):
|
|
items = []
|
|
for obj in args:
|
|
items.extend(obj.unpack_var_sequence(tx))
|
|
return variables.TupleVariable(
|
|
items, **VariableTracker.propagate(self, *args)
|
|
)
|
|
|
|
def call_islice(self, tx, iterable, *args):
|
|
if iterable.has_unpack_var_sequence(tx) and all(
|
|
x.is_python_constant() for x in args
|
|
):
|
|
const_args = [x.as_python_constant() for x in args]
|
|
items = iterable.unpack_var_sequence(tx)
|
|
items = list(itertools.islice(items, *const_args))
|
|
return variables.TupleVariable(
|
|
items, **VariableTracker.propagate(self, iterable, *args)
|
|
)
|
|
|
|
# neg is a constant fold function, so we only get here if constant fold is not valid
|
|
def call_neg(self, tx, a):
|
|
if isinstance(a, SymNodeVariable):
|
|
return SymNodeVariable.create(
|
|
tx,
|
|
(operator.neg)(a.as_proxy()),
|
|
sym_num=None,
|
|
)
|
|
# None no-ops this handler and lets the driving function proceed
|
|
return None
|
|
|
|
def call_id(self, tx, *args):
|
|
if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable):
|
|
nn_mod_variable = args[0]
|
|
mod = tx.output.get_submodule(nn_mod_variable.module_key)
|
|
return variables.ConstantVariable.create(id(mod))
|
|
else:
|
|
unimplemented(f"call_id with args {args}")
|
|
|
|
def _comparison(self, tx, left, right):
|
|
"""
|
|
Used to implement comparison operators for different types.
|
|
For example, list1 < list2 is implemented differently from tensor1 < tensor2
|
|
"""
|
|
from . import (
|
|
BaseListVariable,
|
|
ConstantVariable,
|
|
NNModuleVariable,
|
|
TensorVariable,
|
|
UserDefinedObjectVariable,
|
|
UserFunctionVariable,
|
|
)
|
|
from .lists import SizeVariable
|
|
from .tensor import (
|
|
supported_const_comparison_ops,
|
|
supported_tensor_comparison_ops,
|
|
)
|
|
|
|
op = self.fn
|
|
|
|
def _unimplemented():
|
|
unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}")
|
|
|
|
if (
|
|
all(
|
|
isinstance(x, (NNModuleVariable, ConstantVariable))
|
|
for x in [left, right]
|
|
)
|
|
and op in supported_const_comparison_ops.values()
|
|
):
|
|
left = (
|
|
tx.output.get_submodule(left.module_key)
|
|
if isinstance(left, NNModuleVariable)
|
|
else left.as_python_constant()
|
|
)
|
|
right = (
|
|
tx.output.get_submodule(right.module_key)
|
|
if isinstance(right, NNModuleVariable)
|
|
else right.as_python_constant()
|
|
)
|
|
return ConstantVariable.create(op(left, right))
|
|
|
|
if isinstance(left, UserFunctionVariable):
|
|
if op not in supported_const_comparison_ops.values():
|
|
_unimplemented()
|
|
if not isinstance(right, UserFunctionVariable):
|
|
_unimplemented()
|
|
return ConstantVariable.create(op(left.fn, right.fn))
|
|
|
|
# Note, we have a rare BaseListVariable subtype mismatch with valid comparison
|
|
# x = torch.randn([3, 3])
|
|
# x.size() == (3, 3) # True
|
|
# (3, 3) == x.size() # True
|
|
if isinstance(left, (SizeVariable, TupleVariable)) and isinstance(
|
|
right, (TupleVariable, SizeVariable)
|
|
):
|
|
return BaseListVariable.list_compare(tx, op, left, right)
|
|
|
|
if isinstance(left, BaseListVariable):
|
|
if not type(left) == type(right): # Mismatch in BaseListVariable subclasses
|
|
_unimplemented()
|
|
return BaseListVariable.list_compare(tx, op, left, right)
|
|
|
|
if isinstance(left, SetVariable):
|
|
            if not type(left) == type(right):  # right is not also a SetVariable
|
|
_unimplemented()
|
|
return ConstantVariable.create(
|
|
op(left._underlying_items, right._underlying_items)
|
|
)
|
|
|
|
if isinstance(left, TensorVariable):
|
|
from .builder import wrap_fx_proxy_cls
|
|
|
|
if op not in supported_tensor_comparison_ops.values():
|
|
_unimplemented()
|
|
            if (
                isinstance(right, TensorVariable)
                and left.size is not None
                and right.size is not None
                and left.size != right.size
            ):
|
|
try:
|
|
torch.broadcast_shapes(left.size, right.size)
|
|
except RuntimeError:
|
|
# not broadcastable, can't be compared
|
|
_unimplemented()
|
|
return wrap_fx_proxy_cls(
|
|
type(left), # handle Ndarrays and Tensors
|
|
tx,
|
|
op(left.as_proxy(), right.as_proxy()),
|
|
)
|
|
|
|
if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
|
|
if op not in supported_tensor_comparison_ops.values():
|
|
_unimplemented()
|
|
|
|
return SymNodeVariable.create(
|
|
tx,
|
|
op(left.as_proxy(), right.as_proxy()),
|
|
sym_num=None,
|
|
)
|
|
|
|
if isinstance(left, ConstantVariable) and isinstance(right, ConstantVariable):
|
|
return ConstantVariable.create(op(left.value, right.value))
|
|
|
|
if isinstance(left, UserDefinedObjectVariable) and isinstance(
|
|
right, UserDefinedObjectVariable
|
|
):
|
|
return ConstantVariable.create(op(left.value, right.value))
|
|
|
|
if op.__name__ == "is_":
|
|
# If the two objects are of different type, we can safely return False
|
|
if type(left) is not type(right):
|
|
return ConstantVariable.create(False)
|
|
|
|
_unimplemented()
|
|
|
|
# and_ is a constant fold function, so we only get here if constant fold is not valid
|
|
def call_and_(self, tx, a, b):
|
|
if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
|
|
b, (SymNodeVariable, ConstantVariable)
|
|
):
|
|
return SymNodeVariable.create(
|
|
tx,
|
|
tx.output.create_proxy(
|
|
"call_function", operator.and_, *proxy_args_kwargs([a, b], {})
|
|
),
|
|
sym_num=None,
|
|
)
|
|
# None no-ops this handler and lets the driving function proceed
|
|
return None
|
|
|
|
# or_ is a constant fold function, so we only get here if constant fold is not valid
|
|
def call_or_(self, tx, a, b):
|
|
if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
|
|
b, (SymNodeVariable, ConstantVariable)
|
|
):
|
|
return SymNodeVariable.create(
|
|
tx,
|
|
tx.output.create_proxy(
|
|
"call_function", operator.or_, *proxy_args_kwargs([a, b], {})
|
|
),
|
|
sym_num=None,
|
|
)
|
|
# None no-ops this handler and lets the driving function proceed
|
|
return None
|
|
|
|
def call_not_(self, tx, a):
|
|
if isinstance(a, SymNodeVariable):
|
|
return SymNodeVariable.create(
|
|
tx,
|
|
tx.output.create_proxy(
|
|
"call_function", operator.not_, *proxy_args_kwargs([a], {})
|
|
),
|
|
sym_num=None,
|
|
)
|
|
|
|
if isinstance(a, ListVariable):
|
|
return ConstantVariable.create(len(a.items) == 0).add_options(self, a)
|
|
|
|
return None
|
|
|
|
call_eq = _comparison
|
|
call_gt = _comparison
|
|
call_lt = _comparison
|
|
call_ge = _comparison
|
|
call_le = _comparison
|
|
call_ne = _comparison
|
|
call_is_ = _comparison
|
|
call_is_not = _comparison
|
|
|
|
def call_all(self, tx, *args, **kwargs):
|
|
from .builder import SourcelessBuilder
|
|
|
|
return tx.inline_user_function_return(
|
|
SourcelessBuilder()(tx, polyfill.all), args, kwargs
|
|
)
|