[dynamo] reland map/zip iterator related changes (#135074)

Differential Revision: [D62211019](https://our.internmc.facebook.com/intern/diff/D62211019)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135074
Approved by: https://github.com/jansel, https://github.com/anijain2305, https://github.com/mlazos
This commit is contained in:
William Wen 2024-09-03 16:54:04 -07:00 committed by PyTorch MergeBot
parent 22e1fb6faa
commit a4030e37be
12 changed files with 554 additions and 78 deletions

View File

@ -239,6 +239,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
v = v + x v = v + x
return v return v
def test_itertools_reconstruct(self):
def fn(a):
it1 = itertools.repeat(1)
it2 = itertools.count(2)
for _ in range(3):
a += next(it1)
a += next(it2)
return it1, it2, a
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
i1, i2, a = fn(torch.ones(3, 3))
it1, it2, b = opt_fn(torch.ones(3, 3))
self.assertEqual(next(i1), next(it1))
self.assertEqual(next(i2), next(it2))
self.assertEqual(a, b)
@make_test @make_test
def test_obj_eq(a, b): def test_obj_eq(a, b):
v = a + b v = a + b
@ -507,8 +523,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
empty = collections.deque() empty = collections.deque()
d.extend(empty) d.extend(empty)
# dynamo same() util doesn't support deque so just return a list return d
return list(d)
@make_test @make_test
def test_slice1(a): def test_slice1(a):
@ -3115,6 +3130,199 @@ class GraphModule(torch.nn.Module):
fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]]) fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]])
) )
def test_map_return(self):
def fn(a, b):
return map(lambda x: x + 1, [a, b])
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
m = opt_fn(torch.randn(3, 3), torch.randn(3, 3))
self.assertIsInstance(m, map)
@make_test
def test_map_max(a, b):
return max(map(lambda x: x.sum(), [a, b]))
# max(map(...)) graph breaks
@unittest.expectedFailure
@make_test
def test_map_max_const(a):
return max(map(lambda x: x, [1, 2, 3])), a + 1
@make_test
def test_map_list(a, b):
return list(map(lambda x: x + 1, [a, b]))
@make_test
def test_map_tuple(a, b):
return tuple(map(lambda x: x + 1, [a, b]))
@make_test
def test_map_iter(a, b):
it = iter(map(lambda x: x + 1, [a, b]))
return next(it)
@make_test
def test_map_zip_dict(a):
d = dict(
zip(
map(lambda x: x + 1, [0, 1, 2]),
[map(lambda x: x - 1, [y]) for y in [3, 4, 5]],
)
)
return list(d[3])[0], a + 1 # noqa: RUF015
@make_test
def test_map_dict_fromkeys(a):
return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1
@make_test
def test_map_set(a):
return set(map(lambda x: x + 1, [0, 1])), a + 1
# test_map_sum defined earlier
@make_test
def test_map_reduce(a, b):
return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b]))
@make_test
def test_map_sorted(a):
return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1
@make_test
def test_map_list_extend(a, b, c):
l = [a]
l.extend(map(lambda x: x + 1, [b, c]))
return l
@make_test
def test_map_list_slice_assign(a, b, c, d, e):
l = [a, b, c]
l[1:2] = map(lambda x: x + 1, [d, e])
return l
@make_test
def test_map_deque_extendleft(a, b, c):
d = collections.deque([a])
d.extendleft(map(lambda x: x + 1, [b, c]))
return d
@make_test
def test_map_str_join(a):
return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1
def test_map_with_graph_break(self):
def f(a):
a += 1
def g(x):
nonlocal a
a += 1
return x + 1
m = map(g, [1, 2, 3, 4, 5])
a += next(m) # won't graph break
torch._dynamo.graph_break()
a += next(m) # will graph break
return a
cnts = torch._dynamo.testing.CompileCounter()
opt_f = torch.compile(f, backend=cnts)
self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3)))
self.assertEqual(cnts.frame_count, 3)
def test_map_reconstruct(self):
def fn(a):
return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
m = opt_fn(torch.ones(3, 3))[0]
self.assertIsInstance(m, map)
self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
def test_zip_reconstruct(self):
def fn(a):
return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
m = opt_fn(torch.ones(3, 3))[0]
self.assertIsInstance(m, zip)
self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
@make_test
def test_map_partial_unpack(a, b):
y = 1
def f(x):
nonlocal y
y += 1
return x
l = list(zip([a, b], map(f, [1, 2, 3, 4])))
return a + y
@make_test
def test_map_call_function_ex(a, b):
def f(x, y):
return x + y
return f(*map(lambda x: x + 1, [a, b]))
@make_test
def test_map_unpack_twice(a, b):
m = map(lambda x: x + 1, [a, b])
l1 = list(m)
l2 = list(m)
return l1, l2
@make_test
def test_enumerate(a, b):
return list(enumerate([a, b], start=1)), a + 1
@make_test
def test_map_enumerate(a, b):
return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1
@make_test
def test_map_infinite(a, b):
return list(map(lambda x, y: x + y, [a, b], itertools.count(3)))
@make_test
def test_map_unpack_vars(a, b):
x, y = map(lambda x: x + 1, [a, b])
return x + y
def test_enumerate_custom(self):
class MyClass:
def __iter__(self):
self.a = 1
return self
def __next__(self):
if self.a > 3:
raise StopIteration
self.a += 1
return self.a
def fn(x):
for i, it in enumerate(MyClass()):
x += i + it
return x
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3)))
def test_enumerate_reconstruct(self):
def fn(a, b):
return enumerate([a, b], start=1)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
inps = (torch.randn(3, 3), torch.randn(3, 3))
it1 = fn(*inps)
it2 = opt_fn(*inps)
self.assertIsInstance(it2, enumerate)
self.assertEqual(list(it1), list(it2))
def udf_mul(x, y): def udf_mul(x, y):
return x * y return x * y
@ -3670,10 +3878,16 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
nopython_fn(x, ys[:1], zs) nopython_fn(x, ys[:1], zs)
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
nopython_fn(x, ys, zs[:1])
# Should cause fallback if allow graph break # Should cause fallback if allow graph break
with self.assertRaisesRegex(ValueError, "zip()"): with self.assertRaisesRegex(ValueError, "zip()"):
opt_fn(x, ys[:1], zs) opt_fn(x, ys[:1], zs)
with self.assertRaisesRegex(ValueError, "zip()"):
opt_fn(x, ys, zs[:1])
def test_fn_with_attr(self): def test_fn_with_attr(self):
def fn(x): def fn(x):
if fn.pred: if fn.pred:

View File

@ -5476,15 +5476,17 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
return x, y return x, y
def g(x, y): def g(x, y):
return tuple(map(f, x, y)) return map(f, x, y)
opt_g = torch.compile(g, fullgraph=True, backend="eager") opt_g = torch.compile(g, fullgraph=True, backend="eager")
inps = gen_inps(3, 3) inps = gen_inps(3, 3)
self.assertEqual(g(*inps), opt_g(*inps)) self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
inps = gen_inps(3, 5) inps = gen_inps(3, 5)
self.assertEqual(g(*inps), opt_g(*inps)) self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
def test_staticmethod_allow_in_graph(self): def test_staticmethod_allow_in_graph(self):
class MyClass: class MyClass:

View File

@ -1663,8 +1663,8 @@ class InstructionTranslatorBase(
if not isinstance( if not isinstance(
argsvars, BaseListVariable argsvars, BaseListVariable
) and argsvars.has_unpack_var_sequence(self): ) and argsvars.has_force_unpack_var_sequence(self):
argsvars = TupleVariable(argsvars.unpack_var_sequence(self)) argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))
# Unpack for cases like fn(**obj) where obj is a map # Unpack for cases like fn(**obj) where obj is a map
if isinstance(kwargsvars, UserDefinedObjectVariable): if isinstance(kwargsvars, UserDefinedObjectVariable):
@ -1833,7 +1833,7 @@ class InstructionTranslatorBase(
items = [] items = []
for seq in seqs: for seq in seqs:
try: try:
items.extend(seq.unpack_var_sequence(self)) items.extend(seq.force_unpack_var_sequence(self))
except NotImplementedError: except NotImplementedError:
unimplemented(f"BUILD_LIST_UNPACK {seq}") unimplemented(f"BUILD_LIST_UNPACK {seq}")
self.push(cls(items, mutable_local=MutableLocal())) self.push(cls(items, mutable_local=MutableLocal()))
@ -1871,7 +1871,7 @@ class InstructionTranslatorBase(
assert isinstance(keys, TupleVariable) assert isinstance(keys, TupleVariable)
assert keys.is_python_constant() assert keys.is_python_constant()
keys = keys.unpack_var_sequence(self) keys = keys.force_unpack_var_sequence(self)
assert len(keys) == len(values) assert len(keys) == len(values)
self.push( self.push(
@ -1961,8 +1961,8 @@ class InstructionTranslatorBase(
# x, y = a.shape # x, y = a.shape
proxy = getattr(seq.obj.as_proxy(), seq.name) proxy = getattr(seq.obj.as_proxy(), seq.name)
val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)] val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)]
elif seq.has_unpack_var_sequence(self): elif seq.has_force_unpack_var_sequence(self):
val = seq.unpack_var_sequence(self) val = seq.force_unpack_var_sequence(self)
else: else:
unimplemented(f"UNPACK_SEQUENCE {seq}") unimplemented(f"UNPACK_SEQUENCE {seq}")
if len(val) != inst.argval: if len(val) != inst.argval:
@ -1975,8 +1975,8 @@ class InstructionTranslatorBase(
prefix = inst.argval & 0xFF # low byte prefix = inst.argval & 0xFF # low byte
suffix = inst.argval >> 8 # high byte suffix = inst.argval >> 8 # high byte
seq = self.pop() seq = self.pop()
if seq.has_unpack_var_sequence(self): if seq.has_force_unpack_var_sequence(self):
vals = list(seq.unpack_var_sequence(self)) vals = list(seq.force_unpack_var_sequence(self))
assert len(vals) >= prefix + suffix assert len(vals) >= prefix + suffix
vals_prefix = vals[:prefix] vals_prefix = vals[:prefix]
vals_list = vals[prefix : len(vals) - suffix] vals_list = vals[prefix : len(vals) - suffix]
@ -2400,7 +2400,7 @@ class InstructionTranslatorBase(
self.UNARY_POSITIVE(inst) self.UNARY_POSITIVE(inst)
elif inst.argval == 6: elif inst.argval == 6:
# INTRINSIC_LIST_TO_TUPLE # INTRINSIC_LIST_TO_TUPLE
self.push(TupleVariable(self.pop().unpack_var_sequence(self))) self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
else: else:
unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}") unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}")

View File

@ -1608,8 +1608,12 @@ def same(
"""Check correctness to see if ref and res match""" """Check correctness to see if ref and res match"""
if fp64_ref is None: if fp64_ref is None:
fp64_ref = ref fp64_ref = ref
if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)): if isinstance(
assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}" ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size)
):
assert isinstance(
res, (list, tuple, collections.deque)
), f"type mismatch {type(ref)} {type(res)}"
if len(ref) != len(res): if len(ref) != len(res):
log_error("Length mismatch") log_error("Length mismatch")
return False return False

View File

@ -46,7 +46,9 @@ from .iter import (
CycleIteratorVariable, CycleIteratorVariable,
IteratorVariable, IteratorVariable,
ItertoolsVariable, ItertoolsVariable,
MapVariable,
RepeatIteratorVariable, RepeatIteratorVariable,
ZipVariable,
) )
from .lazy import LazyVariableTracker from .lazy import LazyVariableTracker
from .lists import ( from .lists import (

View File

@ -289,6 +289,15 @@ class VariableTracker(metaclass=VariableTrackerMeta):
def unpack_var_sequence(self, tx) -> List["VariableTracker"]: def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
raise NotImplementedError raise NotImplementedError
def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]:
# like unpack_var_sequence, but should only be used when it is
# safe to eagerly (vs. lazily) unpack this variable.
# e.g. map(f, x) is normally evaluated lazily but sometimes
# we want to force eager unpacking, e.g. when converting to a list.
# NOTE: this method is allowed to mutate the VariableTracker, so
# it should only be called once.
return self.unpack_var_sequence(tx)
def has_unpack_var_sequence(self, tx) -> bool: def has_unpack_var_sequence(self, tx) -> bool:
try: try:
self.unpack_var_sequence(tx) self.unpack_var_sequence(tx)
@ -296,6 +305,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):
except NotImplementedError: except NotImplementedError:
return False return False
# NB: don't call force_unpack_var_sequence, especially if it mutates!
def has_force_unpack_var_sequence(self, tx) -> bool:
return self.has_unpack_var_sequence(tx)
def inspect_parameter_names(self) -> List[str]: def inspect_parameter_names(self) -> List[str]:
unimplemented(f"inspect_parameter_names: {self}") unimplemented(f"inspect_parameter_names: {self}")

View File

@ -1058,9 +1058,8 @@ class BuiltinVariable(VariableTracker):
return tx.inline_user_function_return(user_func_variable, [arg], {}) return tx.inline_user_function_return(user_func_variable, [arg], {})
def _call_min_max(self, tx: "InstructionTranslator", *args): def _call_min_max(self, tx: "InstructionTranslator", *args):
if len(args) == 1 and args[0].has_unpack_var_sequence(tx): if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
# expand iterable items = args[0].force_unpack_var_sequence(tx)
items = args[0].unpack_var_sequence(tx)
return self._call_min_max_seq(tx, items) return self._call_min_max_seq(tx, items)
elif len(args) == 2: elif len(args) == 2:
return self._call_min_max_binary(tx, args[0], args[1]) return self._call_min_max_binary(tx, args[0], args[1])
@ -1075,6 +1074,10 @@ class BuiltinVariable(VariableTracker):
return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)
def _call_min_max_binary(self, tx: "InstructionTranslator", a, b): def _call_min_max_binary(self, tx: "InstructionTranslator", a, b):
if a is None or b is None:
# a or b could be none if we reduce and _call_min_max_binary failed
# to return something
return
if self.tensor_args(a, b): if self.tensor_args(a, b):
if not isinstance(a, variables.TensorVariable): if not isinstance(a, variables.TensorVariable):
a, b = b, a a, b = b, a
@ -1223,17 +1226,15 @@ class BuiltinVariable(VariableTracker):
), ),
) )
# NOTE must handle IteratorVariable separately!
def _call_iter_tuple_list( def _call_iter_tuple_list(
self, tx: "InstructionTranslator", obj=None, *args, **kwargs self, tx: "InstructionTranslator", obj=None, *args, **kwargs
): ):
assert not isinstance(obj, variables.IteratorVariable)
if self._dynamic_args(*args, **kwargs): if self._dynamic_args(*args, **kwargs):
return self._dyn_proxy(tx, *args, **kwargs) return self._dyn_proxy(tx, *args, **kwargs)
if isinstance(obj, variables.IteratorVariable):
# For non-list iterators, we will guard on vars that
# determine the control flow
return obj
cls = variables.BaseListVariable.cls_for(self.fn) cls = variables.BaseListVariable.cls_for(self.fn)
if obj is None: if obj is None:
return cls( return cls(
@ -1261,7 +1262,20 @@ class BuiltinVariable(VariableTracker):
mutable_local=MutableLocal(), mutable_local=MutableLocal(),
) )
def _call_tuple_list(self, tx, obj=None, *args, **kwargs):
if isinstance(obj, variables.IteratorVariable):
cls = variables.BaseListVariable.cls_for(self.fn)
return cls(
list(obj.force_unpack_var_sequence(tx)),
mutable_local=MutableLocal(),
)
else:
return self._call_iter_tuple_list(tx, obj, *args, **kwargs)
def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs):
if isinstance(obj, variables.IteratorVariable):
ret = obj
else:
# Handle the case where we are iterating over a tuple, list or iterator # Handle the case where we are iterating over a tuple, list or iterator
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
@ -1272,8 +1286,8 @@ class BuiltinVariable(VariableTracker):
return obj.call_method(tx, "__iter__", args, kwargs) return obj.call_method(tx, "__iter__", args, kwargs)
return ret return ret
call_tuple = _call_iter_tuple_list call_tuple = _call_tuple_list
call_list = _call_iter_tuple_list call_list = _call_tuple_list
def call_callable(self, tx: "InstructionTranslator", arg): def call_callable(self, tx: "InstructionTranslator", arg):
from .functions import BaseUserFunctionVariable from .functions import BaseUserFunctionVariable
@ -1331,10 +1345,12 @@ class BuiltinVariable(VariableTracker):
ListVariable, ListVariable,
TupleVariable, TupleVariable,
ListIteratorVariable, ListIteratorVariable,
variables.IteratorVariable,
), ),
): ):
items = dict( items = dict(
x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx) x.force_unpack_var_sequence(tx)
for x in arg.force_unpack_var_sequence(tx)
) )
return ConstDictVariable(items, user_cls, mutable_local=MutableLocal()) return ConstDictVariable(items, user_cls, mutable_local=MutableLocal())
elif isinstance(arg, variables.MutableMappingVariable): elif isinstance(arg, variables.MutableMappingVariable):
@ -1391,10 +1407,9 @@ class BuiltinVariable(VariableTracker):
return DictVariableType( return DictVariableType(
dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal() dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal()
) )
elif arg.has_unpack_var_sequence(tx) and all( elif arg.has_force_unpack_var_sequence(tx):
is_hashable(v) for v in arg.unpack_var_sequence(tx) keys = arg.force_unpack_var_sequence(tx)
): if all(is_hashable(v) for v in keys):
keys = arg.unpack_var_sequence(tx)
return DictVariableType( return DictVariableType(
dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
) )
@ -1409,8 +1424,8 @@ class BuiltinVariable(VariableTracker):
arg = args[0] arg = args[0]
if isinstance(arg, variables.SetVariable): if isinstance(arg, variables.SetVariable):
return arg.clone(mutable_local=MutableLocal()) return arg.clone(mutable_local=MutableLocal())
elif arg.has_unpack_var_sequence(tx): elif arg.has_force_unpack_var_sequence(tx):
items = arg.unpack_var_sequence(tx) items = arg.force_unpack_var_sequence(tx)
return SetVariable(items, mutable_local=MutableLocal()) return SetVariable(items, mutable_local=MutableLocal())
elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
arg.value, KeysView arg.value, KeysView
@ -1443,16 +1458,12 @@ class BuiltinVariable(VariableTracker):
def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
if kwargs: if kwargs:
assert len(kwargs) == 1 and "strict" in kwargs assert len(kwargs) == 1 and "strict" in kwargs
if all(x.has_unpack_var_sequence(tx) for x in args): strict = kwargs.pop("strict", False)
unpacked = [arg.unpack_var_sequence(tx) for arg in args] args = [
if kwargs.pop("strict", False) and len(unpacked) > 0: arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg
if not all(len(u) == len(unpacked[0]) for u in unpacked): for arg in args
raise UserError( ]
ValueError, return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal())
"zip() has one argument of len differing from others",
)
items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)]
return variables.TupleVariable(items)
def call_len(self, tx: "InstructionTranslator", *args, **kwargs): def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
return args[0].call_method(tx, "__len__", args[1:], kwargs) return args[0].call_method(tx, "__len__", args[1:], kwargs)
@ -1553,10 +1564,11 @@ class BuiltinVariable(VariableTracker):
return obj.call_hasattr(tx, name) return obj.call_hasattr(tx, name)
def call_map(self, tx: "InstructionTranslator", fn, *seqs): def call_map(self, tx: "InstructionTranslator", fn, *seqs):
if all(seq.has_unpack_var_sequence(tx) for seq in seqs): seqs = [
unpacked = [seq.unpack_var_sequence(tx) for seq in seqs] seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
items = [fn.call_function(tx, list(args), {}) for args in zip(*unpacked)] for seq in seqs
return variables.TupleVariable(items) ]
return variables.MapVariable(fn, seqs, mutable_local=MutableLocal())
def call_filter(self, tx: "InstructionTranslator", fn, seq): def call_filter(self, tx: "InstructionTranslator", fn, seq):
if seq.has_unpack_var_sequence(tx): if seq.has_unpack_var_sequence(tx):
@ -1589,10 +1601,10 @@ class BuiltinVariable(VariableTracker):
return variables.ConstantVariable.create( return variables.ConstantVariable.create(
sum((x.value for x in seq.items), start=start.value), sum((x.value for x in seq.items), start=start.value),
) )
if seq.has_unpack_var_sequence(tx): if seq.has_force_unpack_var_sequence(tx):
if start is self._SENTINEL: if start is self._SENTINEL:
start = variables.ConstantVariable.create(0) start = variables.ConstantVariable.create(0)
items = seq.unpack_var_sequence(tx) items = seq.force_unpack_var_sequence(tx)
return BuiltinVariable(functools.reduce).call_function( return BuiltinVariable(functools.reduce).call_function(
tx, tx,
[ [
@ -1606,8 +1618,8 @@ class BuiltinVariable(VariableTracker):
def call_reduce( def call_reduce(
self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL
): ):
if iterable.has_unpack_var_sequence(tx): if iterable.has_force_unpack_var_sequence(tx):
items = iterable.unpack_var_sequence(tx) items = iterable.force_unpack_var_sequence(tx)
if initial is self._SENTINEL: if initial is self._SENTINEL:
value, items = items[0], items[1:] value, items = items[0], items[1:]
else: else:
@ -1903,11 +1915,12 @@ class BuiltinVariable(VariableTracker):
return variables.TupleVariable(items) return variables.TupleVariable(items)
def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs): def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs):
if ( if obj.has_force_unpack_var_sequence(tx) and not isinstance(
obj.has_unpack_var_sequence(tx) obj, variables.TensorVariable
and not isinstance(obj, variables.TensorVariable)
and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx))
): ):
unpacked = obj.force_unpack_var_sequence(tx)
if not all(x.is_python_constant() for x in unpacked):
return
function = kwargs.pop("key", None) function = kwargs.pop("key", None)
reverse = kwargs.pop( reverse = kwargs.pop(
"reverse", ConstantVariable.create(False) "reverse", ConstantVariable.create(False)
@ -1915,7 +1928,7 @@ class BuiltinVariable(VariableTracker):
assert len(kwargs) == 0 assert len(kwargs) == 0
if function: if function:
items = sorted( items = sorted(
obj.unpack_var_sequence(tx), unpacked,
key=lambda x: function.call_function( key=lambda x: function.call_function(
tx, [x], {} tx, [x], {}
).as_python_constant(), ).as_python_constant(),
@ -1923,7 +1936,7 @@ class BuiltinVariable(VariableTracker):
) )
else: else:
items = sorted( items = sorted(
obj.unpack_var_sequence(tx), unpacked,
key=lambda x: x.as_python_constant(), key=lambda x: x.as_python_constant(),
reverse=reverse, reverse=reverse,
) )

View File

@ -145,6 +145,14 @@ class ConstantVariable(VariableTracker):
return variables.BuiltinVariable(str.format).call_function( return variables.BuiltinVariable(str.format).call_function(
tx, [self, *args], kwargs tx, [self, *args], kwargs
) )
elif name == "join" and istype(self.value, str):
assert len(args) == 1 and len(kwargs) == 0
arg_unpacked = args[0].force_unpack_var_sequence(tx)
try:
arg_const = [x.as_python_constant() for x in arg_unpacked]
return ConstantVariable.create(self.value.join(arg_const))
except NotImplementedError:
return super().call_method(tx, name, args, kwargs)
if any(isinstance(x, SymNodeVariable) for x in args): if any(isinstance(x, SymNodeVariable) for x in args):
# Promote to SymNodeVariable for operations involving dynamic shapes. # Promote to SymNodeVariable for operations involving dynamic shapes.

View File

@ -314,6 +314,7 @@ class ConstDictVariable(VariableTracker):
ListVariable, ListVariable,
TupleVariable, TupleVariable,
ListIteratorVariable, ListIteratorVariable,
variables.IteratorVariable,
UserDefinedObjectVariable, UserDefinedObjectVariable,
), ),
) )

View File

@ -2,14 +2,17 @@
import itertools import itertools
import operator import operator
from typing import Dict, List, Optional, TYPE_CHECKING import sys
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from .. import polyfills, variables from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import ( from ..exc import (
handle_observed_exception, handle_observed_exception,
ObservedUserStopIteration, ObservedUserStopIteration,
raise_observed_exception, raise_observed_exception,
unimplemented, unimplemented,
UserError,
) )
from .base import MutableLocal, VariableTracker from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable from .constant import ConstantVariable
@ -197,6 +200,25 @@ class IteratorVariable(VariableTracker):
def next_variable(self, tx): def next_variable(self, tx):
unimplemented("abstract method, must implement") unimplemented("abstract method, must implement")
# NOTE: only call when unpacking this iterator safely done eagerly!
# Normally, iterators are accessed lazily.
# Example of safe eager unpacking: list(map(f, seq))
# Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
result = []
while True:
try:
result.append(self.next_variable(tx))
except ObservedUserStopIteration:
handle_observed_exception(tx)
break
return result
# don't call force_unpack_var_sequence since it can mutate
# IteratorVariable state!
def has_force_unpack_var_sequence(self, tx) -> bool:
return True
class RepeatIteratorVariable(IteratorVariable): class RepeatIteratorVariable(IteratorVariable):
def __init__(self, item: VariableTracker, **kwargs) -> None: def __init__(self, item: VariableTracker, **kwargs) -> None:
@ -207,6 +229,18 @@ class RepeatIteratorVariable(IteratorVariable):
def next_variable(self, tx): def next_variable(self, tx):
return self.item return self.item
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.extend_output(
[
codegen.create_load_python_module(itertools),
codegen.create_load_attr("repeat"),
]
)
)
codegen(self.item)
codegen.extend_output(create_call_function(1, False))
class CountIteratorVariable(IteratorVariable): class CountIteratorVariable(IteratorVariable):
def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
@ -220,10 +254,23 @@ class CountIteratorVariable(IteratorVariable):
def next_variable(self, tx): def next_variable(self, tx):
assert self.mutable_local assert self.mutable_local
old_item = self.item
tx.output.side_effects.mutation(self) tx.output.side_effects.mutation(self)
next_item = self.item.call_method(tx, "__add__", [self.step], {}) self.item = self.item.call_method(tx, "__add__", [self.step], {})
self.item = next_item return old_item
return self.item
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.extend_output(
[
codegen.create_load_python_module(itertools),
codegen.create_load_attr("count"),
]
)
)
codegen(self.item)
codegen(self.step)
codegen.extend_output(create_call_function(2, False))
class CycleIteratorVariable(IteratorVariable): class CycleIteratorVariable(IteratorVariable):
@ -269,3 +316,160 @@ class CycleIteratorVariable(IteratorVariable):
return self.item return self.item
else: else:
raise_observed_exception(StopIteration, tx, self) raise_observed_exception(StopIteration, tx, self)
class ZipVariable(IteratorVariable):
"""
Represents zip(*iterables)
"""
_nonvar_fields = {
"index",
"strict",
*IteratorVariable._nonvar_fields,
}
def __init__(
self,
iterables: List[Union[List[VariableTracker], VariableTracker]],
strict: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
assert isinstance(iterables, list)
# can be list[Variable] or VariableTracker (with next_variable implemented)
self.iterables = iterables
self.index = 0
self.strict = strict
def python_type(self):
return zip
def has_unpack_var_sequence(self, tx) -> bool:
return all(
isinstance(it, list) or it.has_unpack_var_sequence(tx)
for it in self.iterables
)
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
assert self.has_unpack_var_sequence(tx)
iterables = []
for it in self.iterables:
if isinstance(it, list):
iterables.append(it[self.index :])
else:
iterables.append(it.unpack_var_sequence(tx))
kwargs = {"strict": self.strict} if self.strict else {}
zipped = zip(*iterables, **kwargs)
return [variables.TupleVariable(list(var)) for var in zipped]
def next_variable(self, tx):
assert self.mutable_local
old_index = self.index
args = []
def get_item(it):
if isinstance(it, list):
if old_index >= len(it):
raise_observed_exception(StopIteration, tx, self)
return it[old_index]
else:
return it.next_variable(tx)
try:
for idx, it in enumerate(self.iterables):
args.append(get_item(it))
except ObservedUserStopIteration:
if self.strict:
if idx == 0:
# all other iterables should be exhausted
for it in self.iterables:
try:
get_item(it)
except ObservedUserStopIteration:
handle_observed_exception(tx)
continue
# no ObservedUserStopIteration - fall through to UserError
break
else:
# all iterables exhausted, raise original error
raise
handle_observed_exception(tx)
raise UserError(
ValueError,
"zip() has one argument of len differing from others",
) from None
raise
tx.output.side_effects.mutation(self)
self.index += 1
return variables.TupleVariable(args)
def reconstruct_items(self, codegen):
for it in self.iterables:
if isinstance(it, list):
remaining_items = it[self.index :]
codegen.foreach(remaining_items)
codegen.append_output(
create_instruction("BUILD_TUPLE", arg=len(remaining_items))
)
else:
codegen(it)
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
)
self.reconstruct_items(codegen)
codegen.append_output(
create_instruction("BUILD_TUPLE", arg=len(self.iterables))
)
if sys.version_info >= (3, 10):
codegen.extend_output(
[
codegen.create_load_const("strict"),
codegen.create_load_const(self.strict),
create_instruction("BUILD_MAP", arg=1),
create_instruction("CALL_FUNCTION_EX", arg=1),
]
)
else:
codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))
class MapVariable(ZipVariable):
"""
Represents map(fn, *iterables)
"""
def __init__(
self,
fn: VariableTracker,
iterables: List[Union[List[VariableTracker], VariableTracker]],
**kwargs,
) -> None:
super().__init__(iterables, **kwargs)
self.fn = fn
def python_type(self):
return map
def has_unpack_var_sequence(self, tx) -> bool:
return False
def next_variable(self, tx):
args = super().next_variable(tx)
return self.fn.call_function(tx, args.items, {})
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
)
codegen(self.fn)
self.reconstruct_items(codegen)
codegen.extend_output(
[
create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1),
create_instruction("CALL_FUNCTION_EX", arg=0),
]
)

View File

@ -29,6 +29,7 @@ from ..utils import (
from .base import MutableLocal, VariableTracker from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable from .constant import ConstantVariable
from .functions import UserFunctionVariable, UserMethodVariable from .functions import UserFunctionVariable, UserMethodVariable
from .iter import IteratorVariable
if TYPE_CHECKING: if TYPE_CHECKING:
@ -334,11 +335,11 @@ class CommonListMethodsVariable(BaseListVariable):
name == "extend" name == "extend"
and self.mutable_local and self.mutable_local
and args and args
and args[0].has_unpack_var_sequence(tx) and args[0].has_force_unpack_var_sequence(tx)
): ):
assert not kwargs assert not kwargs
(arg,) = args (arg,) = args
seq = arg.unpack_var_sequence(tx) seq = arg.force_unpack_var_sequence(tx)
tx.output.side_effects.mutation(self) tx.output.side_effects.mutation(self)
self.items.extend(seq) self.items.extend(seq)
return ConstantVariable.create(None) return ConstantVariable.create(None)
@ -422,11 +423,13 @@ class ListVariable(CommonListMethodsVariable):
key, value = args key, value = args
tx.output.side_effects.mutation(self) tx.output.side_effects.mutation(self)
if isinstance(key, SliceVariable): if isinstance(key, SliceVariable):
if not value.has_unpack_var_sequence(tx): if not value.has_force_unpack_var_sequence(tx):
unimplemented( unimplemented(
f"Missing dynamo support for expanding {value} into a list for slice assignment." f"Missing dynamo support for expanding {value} into a list for slice assignment."
) )
self.items[key.as_python_constant()] = value.unpack_var_sequence(tx) self.items[key.as_python_constant()] = value.force_unpack_var_sequence(
tx
)
else: else:
self.items[key.as_python_constant()] = value self.items[key.as_python_constant()] = value
return ConstantVariable.create(None) return ConstantVariable.create(None)
@ -464,7 +467,12 @@ class DequeVariable(CommonListMethodsVariable):
) )
) )
codegen.foreach(self.items) codegen.foreach(self.items)
codegen.extend_output(create_call_function(len(self.items), False)) codegen.extend_output(
[
create_instruction("BUILD_LIST", arg=len(self.items)),
*create_call_function(1, False),
]
)
def call_method( def call_method(
self, self,
@ -487,11 +495,15 @@ class DequeVariable(CommonListMethodsVariable):
tx.output.side_effects.mutation(self) tx.output.side_effects.mutation(self)
self.items[key.as_python_constant()] = value self.items[key.as_python_constant()] = value
return ConstantVariable.create(None) return ConstantVariable.create(None)
elif name == "extendleft" and self.mutable_local: elif (
name == "extendleft"
and self.mutable_local
and args[0].has_force_unpack_var_sequence(tx)
):
assert not kwargs assert not kwargs
(arg,) = args (arg,) = args
prefix = arg.unpack_var_sequence(tx) prefix = arg.force_unpack_var_sequence(tx)
prefix.reverse() prefix.reverse()
tx.output.side_effects.mutation(self) tx.output.side_effects.mutation(self)
self.items = prefix + list(self.items) self.items = prefix + list(self.items)
@ -802,10 +814,10 @@ class SliceVariable(BaseListVariable):
return self.items[fields.index(name)] return self.items[fields.index(name)]
class ListIteratorVariable(VariableTracker): class ListIteratorVariable(IteratorVariable):
_nonvar_fields = { _nonvar_fields = {
"index", "index",
*VariableTracker._nonvar_fields, *IteratorVariable._nonvar_fields,
} }
def __init__(self, items, index: int = 0, **kwargs) -> None: def __init__(self, items, index: int = 0, **kwargs) -> None:
@ -856,6 +868,9 @@ class ListIteratorVariable(VariableTracker):
def unpack_var_sequence(self, tx): def unpack_var_sequence(self, tx):
return list(self.items[self.index :]) return list(self.items[self.index :])
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
return self.unpack_var_sequence(tx)
def reconstruct(self, codegen): def reconstruct(self, codegen):
remaining_items = self.items[self.index :] remaining_items = self.items[self.index :]
codegen.foreach(remaining_items) codegen.foreach(remaining_items)

View File

@ -379,8 +379,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
elif self.value is collections.deque and not kwargs: elif self.value is collections.deque and not kwargs:
if len(args) == 0: if len(args) == 0:
items = [] items = []
elif len(args) == 1 and args[0].has_unpack_var_sequence(tx): elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
items = args[0].unpack_var_sequence(tx) items = args[0].force_unpack_var_sequence(tx)
else: else:
unimplemented("deque() with more than 1 arg not supported") unimplemented("deque() with more than 1 arg not supported")
return variables.lists.DequeVariable(items, mutable_local=MutableLocal()) return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
@ -749,7 +749,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
assert not (args or kwargs) assert not (args or kwargs)
items = [] items = []
keys = self.call_method(tx, "keys", [], {}) keys = self.call_method(tx, "keys", [], {})
for key in keys.unpack_var_sequence(tx): for key in keys.force_unpack_var_sequence(tx):
items.append( items.append(
TupleVariable( TupleVariable(
[key, self.odict_getitem(tx, key)], [key, self.odict_getitem(tx, key)],