mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] reland map/zip iterator related changes (#135074)
Differential Revision: [D62211019](https://our.internmc.facebook.com/intern/diff/D62211019) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135074 Approved by: https://github.com/jansel, https://github.com/anijain2305, https://github.com/mlazos
This commit is contained in:
parent
22e1fb6faa
commit
a4030e37be
|
|
@ -239,6 +239,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||||
v = v + x
|
v = v + x
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
def test_itertools_reconstruct(self):
|
||||||
|
def fn(a):
|
||||||
|
it1 = itertools.repeat(1)
|
||||||
|
it2 = itertools.count(2)
|
||||||
|
for _ in range(3):
|
||||||
|
a += next(it1)
|
||||||
|
a += next(it2)
|
||||||
|
return it1, it2, a
|
||||||
|
|
||||||
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||||
|
i1, i2, a = fn(torch.ones(3, 3))
|
||||||
|
it1, it2, b = opt_fn(torch.ones(3, 3))
|
||||||
|
self.assertEqual(next(i1), next(it1))
|
||||||
|
self.assertEqual(next(i2), next(it2))
|
||||||
|
self.assertEqual(a, b)
|
||||||
|
|
||||||
@make_test
|
@make_test
|
||||||
def test_obj_eq(a, b):
|
def test_obj_eq(a, b):
|
||||||
v = a + b
|
v = a + b
|
||||||
|
|
@ -507,8 +523,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||||
empty = collections.deque()
|
empty = collections.deque()
|
||||||
d.extend(empty)
|
d.extend(empty)
|
||||||
|
|
||||||
# dynamo same() util doesn't support deque so just return a list
|
return d
|
||||||
return list(d)
|
|
||||||
|
|
||||||
@make_test
|
@make_test
|
||||||
def test_slice1(a):
|
def test_slice1(a):
|
||||||
|
|
@ -3115,6 +3130,199 @@ class GraphModule(torch.nn.Module):
|
||||||
fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]])
|
fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_map_return(self):
|
||||||
|
def fn(a, b):
|
||||||
|
return map(lambda x: x + 1, [a, b])
|
||||||
|
|
||||||
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||||
|
m = opt_fn(torch.randn(3, 3), torch.randn(3, 3))
|
||||||
|
self.assertIsInstance(m, map)
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_max(a, b):
|
||||||
|
return max(map(lambda x: x.sum(), [a, b]))
|
||||||
|
|
||||||
|
# max(map(...)) graph breaks
|
||||||
|
@unittest.expectedFailure
|
||||||
|
@make_test
|
||||||
|
def test_map_max_const(a):
|
||||||
|
return max(map(lambda x: x, [1, 2, 3])), a + 1
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_list(a, b):
|
||||||
|
return list(map(lambda x: x + 1, [a, b]))
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_tuple(a, b):
|
||||||
|
return tuple(map(lambda x: x + 1, [a, b]))
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_iter(a, b):
|
||||||
|
it = iter(map(lambda x: x + 1, [a, b]))
|
||||||
|
return next(it)
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_zip_dict(a):
|
||||||
|
d = dict(
|
||||||
|
zip(
|
||||||
|
map(lambda x: x + 1, [0, 1, 2]),
|
||||||
|
[map(lambda x: x - 1, [y]) for y in [3, 4, 5]],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return list(d[3])[0], a + 1 # noqa: RUF015
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_dict_fromkeys(a):
|
||||||
|
return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_set(a):
|
||||||
|
return set(map(lambda x: x + 1, [0, 1])), a + 1
|
||||||
|
|
||||||
|
# test_map_sum defined earlier
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_reduce(a, b):
|
||||||
|
return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b]))
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_sorted(a):
|
||||||
|
return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_list_extend(a, b, c):
|
||||||
|
l = [a]
|
||||||
|
l.extend(map(lambda x: x + 1, [b, c]))
|
||||||
|
return l
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_list_slice_assign(a, b, c, d, e):
|
||||||
|
l = [a, b, c]
|
||||||
|
l[1:2] = map(lambda x: x + 1, [d, e])
|
||||||
|
return l
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_deque_extendleft(a, b, c):
|
||||||
|
d = collections.deque([a])
|
||||||
|
d.extendleft(map(lambda x: x + 1, [b, c]))
|
||||||
|
return d
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_str_join(a):
|
||||||
|
return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1
|
||||||
|
|
||||||
|
def test_map_with_graph_break(self):
|
||||||
|
def f(a):
|
||||||
|
a += 1
|
||||||
|
|
||||||
|
def g(x):
|
||||||
|
nonlocal a
|
||||||
|
a += 1
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
m = map(g, [1, 2, 3, 4, 5])
|
||||||
|
a += next(m) # won't graph break
|
||||||
|
torch._dynamo.graph_break()
|
||||||
|
a += next(m) # will graph break
|
||||||
|
return a
|
||||||
|
|
||||||
|
cnts = torch._dynamo.testing.CompileCounter()
|
||||||
|
opt_f = torch.compile(f, backend=cnts)
|
||||||
|
self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3)))
|
||||||
|
self.assertEqual(cnts.frame_count, 3)
|
||||||
|
|
||||||
|
def test_map_reconstruct(self):
|
||||||
|
def fn(a):
|
||||||
|
return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1
|
||||||
|
|
||||||
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||||
|
m = opt_fn(torch.ones(3, 3))[0]
|
||||||
|
self.assertIsInstance(m, map)
|
||||||
|
self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
|
||||||
|
|
||||||
|
def test_zip_reconstruct(self):
|
||||||
|
def fn(a):
|
||||||
|
return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1
|
||||||
|
|
||||||
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||||
|
m = opt_fn(torch.ones(3, 3))[0]
|
||||||
|
self.assertIsInstance(m, zip)
|
||||||
|
self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_partial_unpack(a, b):
|
||||||
|
y = 1
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
nonlocal y
|
||||||
|
y += 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
l = list(zip([a, b], map(f, [1, 2, 3, 4])))
|
||||||
|
return a + y
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_call_function_ex(a, b):
|
||||||
|
def f(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
return f(*map(lambda x: x + 1, [a, b]))
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_unpack_twice(a, b):
|
||||||
|
m = map(lambda x: x + 1, [a, b])
|
||||||
|
l1 = list(m)
|
||||||
|
l2 = list(m)
|
||||||
|
return l1, l2
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_enumerate(a, b):
|
||||||
|
return list(enumerate([a, b], start=1)), a + 1
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_enumerate(a, b):
|
||||||
|
return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_infinite(a, b):
|
||||||
|
return list(map(lambda x, y: x + y, [a, b], itertools.count(3)))
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_map_unpack_vars(a, b):
|
||||||
|
x, y = map(lambda x: x + 1, [a, b])
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
def test_enumerate_custom(self):
|
||||||
|
class MyClass:
|
||||||
|
def __iter__(self):
|
||||||
|
self.a = 1
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.a > 3:
|
||||||
|
raise StopIteration
|
||||||
|
self.a += 1
|
||||||
|
return self.a
|
||||||
|
|
||||||
|
def fn(x):
|
||||||
|
for i, it in enumerate(MyClass()):
|
||||||
|
x += i + it
|
||||||
|
return x
|
||||||
|
|
||||||
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||||
|
self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3)))
|
||||||
|
|
||||||
|
def test_enumerate_reconstruct(self):
|
||||||
|
def fn(a, b):
|
||||||
|
return enumerate([a, b], start=1)
|
||||||
|
|
||||||
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||||
|
inps = (torch.randn(3, 3), torch.randn(3, 3))
|
||||||
|
it1 = fn(*inps)
|
||||||
|
it2 = opt_fn(*inps)
|
||||||
|
self.assertIsInstance(it2, enumerate)
|
||||||
|
self.assertEqual(list(it1), list(it2))
|
||||||
|
|
||||||
|
|
||||||
def udf_mul(x, y):
|
def udf_mul(x, y):
|
||||||
return x * y
|
return x * y
|
||||||
|
|
@ -3670,10 +3878,16 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
||||||
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
|
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
|
||||||
nopython_fn(x, ys[:1], zs)
|
nopython_fn(x, ys[:1], zs)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
|
||||||
|
nopython_fn(x, ys, zs[:1])
|
||||||
|
|
||||||
# Should cause fallback if allow graph break
|
# Should cause fallback if allow graph break
|
||||||
with self.assertRaisesRegex(ValueError, "zip()"):
|
with self.assertRaisesRegex(ValueError, "zip()"):
|
||||||
opt_fn(x, ys[:1], zs)
|
opt_fn(x, ys[:1], zs)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "zip()"):
|
||||||
|
opt_fn(x, ys, zs[:1])
|
||||||
|
|
||||||
def test_fn_with_attr(self):
|
def test_fn_with_attr(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
if fn.pred:
|
if fn.pred:
|
||||||
|
|
|
||||||
|
|
@ -5476,15 +5476,17 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
def g(x, y):
|
def g(x, y):
|
||||||
return tuple(map(f, x, y))
|
return map(f, x, y)
|
||||||
|
|
||||||
opt_g = torch.compile(g, fullgraph=True, backend="eager")
|
opt_g = torch.compile(g, fullgraph=True, backend="eager")
|
||||||
|
|
||||||
inps = gen_inps(3, 3)
|
inps = gen_inps(3, 3)
|
||||||
self.assertEqual(g(*inps), opt_g(*inps))
|
self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
|
||||||
|
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
|
||||||
|
|
||||||
inps = gen_inps(3, 5)
|
inps = gen_inps(3, 5)
|
||||||
self.assertEqual(g(*inps), opt_g(*inps))
|
self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
|
||||||
|
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
|
||||||
|
|
||||||
def test_staticmethod_allow_in_graph(self):
|
def test_staticmethod_allow_in_graph(self):
|
||||||
class MyClass:
|
class MyClass:
|
||||||
|
|
|
||||||
|
|
@ -1663,8 +1663,8 @@ class InstructionTranslatorBase(
|
||||||
|
|
||||||
if not isinstance(
|
if not isinstance(
|
||||||
argsvars, BaseListVariable
|
argsvars, BaseListVariable
|
||||||
) and argsvars.has_unpack_var_sequence(self):
|
) and argsvars.has_force_unpack_var_sequence(self):
|
||||||
argsvars = TupleVariable(argsvars.unpack_var_sequence(self))
|
argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))
|
||||||
|
|
||||||
# Unpack for cases like fn(**obj) where obj is a map
|
# Unpack for cases like fn(**obj) where obj is a map
|
||||||
if isinstance(kwargsvars, UserDefinedObjectVariable):
|
if isinstance(kwargsvars, UserDefinedObjectVariable):
|
||||||
|
|
@ -1833,7 +1833,7 @@ class InstructionTranslatorBase(
|
||||||
items = []
|
items = []
|
||||||
for seq in seqs:
|
for seq in seqs:
|
||||||
try:
|
try:
|
||||||
items.extend(seq.unpack_var_sequence(self))
|
items.extend(seq.force_unpack_var_sequence(self))
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
unimplemented(f"BUILD_LIST_UNPACK {seq}")
|
unimplemented(f"BUILD_LIST_UNPACK {seq}")
|
||||||
self.push(cls(items, mutable_local=MutableLocal()))
|
self.push(cls(items, mutable_local=MutableLocal()))
|
||||||
|
|
@ -1871,7 +1871,7 @@ class InstructionTranslatorBase(
|
||||||
assert isinstance(keys, TupleVariable)
|
assert isinstance(keys, TupleVariable)
|
||||||
assert keys.is_python_constant()
|
assert keys.is_python_constant()
|
||||||
|
|
||||||
keys = keys.unpack_var_sequence(self)
|
keys = keys.force_unpack_var_sequence(self)
|
||||||
assert len(keys) == len(values)
|
assert len(keys) == len(values)
|
||||||
|
|
||||||
self.push(
|
self.push(
|
||||||
|
|
@ -1961,8 +1961,8 @@ class InstructionTranslatorBase(
|
||||||
# x, y = a.shape
|
# x, y = a.shape
|
||||||
proxy = getattr(seq.obj.as_proxy(), seq.name)
|
proxy = getattr(seq.obj.as_proxy(), seq.name)
|
||||||
val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)]
|
val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)]
|
||||||
elif seq.has_unpack_var_sequence(self):
|
elif seq.has_force_unpack_var_sequence(self):
|
||||||
val = seq.unpack_var_sequence(self)
|
val = seq.force_unpack_var_sequence(self)
|
||||||
else:
|
else:
|
||||||
unimplemented(f"UNPACK_SEQUENCE {seq}")
|
unimplemented(f"UNPACK_SEQUENCE {seq}")
|
||||||
if len(val) != inst.argval:
|
if len(val) != inst.argval:
|
||||||
|
|
@ -1975,8 +1975,8 @@ class InstructionTranslatorBase(
|
||||||
prefix = inst.argval & 0xFF # low byte
|
prefix = inst.argval & 0xFF # low byte
|
||||||
suffix = inst.argval >> 8 # high byte
|
suffix = inst.argval >> 8 # high byte
|
||||||
seq = self.pop()
|
seq = self.pop()
|
||||||
if seq.has_unpack_var_sequence(self):
|
if seq.has_force_unpack_var_sequence(self):
|
||||||
vals = list(seq.unpack_var_sequence(self))
|
vals = list(seq.force_unpack_var_sequence(self))
|
||||||
assert len(vals) >= prefix + suffix
|
assert len(vals) >= prefix + suffix
|
||||||
vals_prefix = vals[:prefix]
|
vals_prefix = vals[:prefix]
|
||||||
vals_list = vals[prefix : len(vals) - suffix]
|
vals_list = vals[prefix : len(vals) - suffix]
|
||||||
|
|
@ -2400,7 +2400,7 @@ class InstructionTranslatorBase(
|
||||||
self.UNARY_POSITIVE(inst)
|
self.UNARY_POSITIVE(inst)
|
||||||
elif inst.argval == 6:
|
elif inst.argval == 6:
|
||||||
# INTRINSIC_LIST_TO_TUPLE
|
# INTRINSIC_LIST_TO_TUPLE
|
||||||
self.push(TupleVariable(self.pop().unpack_var_sequence(self)))
|
self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
|
||||||
else:
|
else:
|
||||||
unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}")
|
unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1608,8 +1608,12 @@ def same(
|
||||||
"""Check correctness to see if ref and res match"""
|
"""Check correctness to see if ref and res match"""
|
||||||
if fp64_ref is None:
|
if fp64_ref is None:
|
||||||
fp64_ref = ref
|
fp64_ref = ref
|
||||||
if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
|
if isinstance(
|
||||||
assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}"
|
ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size)
|
||||||
|
):
|
||||||
|
assert isinstance(
|
||||||
|
res, (list, tuple, collections.deque)
|
||||||
|
), f"type mismatch {type(ref)} {type(res)}"
|
||||||
if len(ref) != len(res):
|
if len(ref) != len(res):
|
||||||
log_error("Length mismatch")
|
log_error("Length mismatch")
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,9 @@ from .iter import (
|
||||||
CycleIteratorVariable,
|
CycleIteratorVariable,
|
||||||
IteratorVariable,
|
IteratorVariable,
|
||||||
ItertoolsVariable,
|
ItertoolsVariable,
|
||||||
|
MapVariable,
|
||||||
RepeatIteratorVariable,
|
RepeatIteratorVariable,
|
||||||
|
ZipVariable,
|
||||||
)
|
)
|
||||||
from .lazy import LazyVariableTracker
|
from .lazy import LazyVariableTracker
|
||||||
from .lists import (
|
from .lists import (
|
||||||
|
|
|
||||||
|
|
@ -289,6 +289,15 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||||
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
|
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]:
|
||||||
|
# like unpack_var_sequence, but should only be used when it is
|
||||||
|
# safe to eagerly (vs. lazily) unpack this variable.
|
||||||
|
# e.g. map(f, x) is normally evaluated lazily but sometimes
|
||||||
|
# we want to force eager unpacking, e.g. when converting to a list.
|
||||||
|
# NOTE: this method is allowed to mutate the VariableTracker, so
|
||||||
|
# it should only be called once.
|
||||||
|
return self.unpack_var_sequence(tx)
|
||||||
|
|
||||||
def has_unpack_var_sequence(self, tx) -> bool:
|
def has_unpack_var_sequence(self, tx) -> bool:
|
||||||
try:
|
try:
|
||||||
self.unpack_var_sequence(tx)
|
self.unpack_var_sequence(tx)
|
||||||
|
|
@ -296,6 +305,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# NB: don't call force_unpack_var_sequence, especially if it mutates!
|
||||||
|
def has_force_unpack_var_sequence(self, tx) -> bool:
|
||||||
|
return self.has_unpack_var_sequence(tx)
|
||||||
|
|
||||||
def inspect_parameter_names(self) -> List[str]:
|
def inspect_parameter_names(self) -> List[str]:
|
||||||
unimplemented(f"inspect_parameter_names: {self}")
|
unimplemented(f"inspect_parameter_names: {self}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1058,9 +1058,8 @@ class BuiltinVariable(VariableTracker):
|
||||||
return tx.inline_user_function_return(user_func_variable, [arg], {})
|
return tx.inline_user_function_return(user_func_variable, [arg], {})
|
||||||
|
|
||||||
def _call_min_max(self, tx: "InstructionTranslator", *args):
|
def _call_min_max(self, tx: "InstructionTranslator", *args):
|
||||||
if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
|
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
|
||||||
# expand iterable
|
items = args[0].force_unpack_var_sequence(tx)
|
||||||
items = args[0].unpack_var_sequence(tx)
|
|
||||||
return self._call_min_max_seq(tx, items)
|
return self._call_min_max_seq(tx, items)
|
||||||
elif len(args) == 2:
|
elif len(args) == 2:
|
||||||
return self._call_min_max_binary(tx, args[0], args[1])
|
return self._call_min_max_binary(tx, args[0], args[1])
|
||||||
|
|
@ -1075,6 +1074,10 @@ class BuiltinVariable(VariableTracker):
|
||||||
return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)
|
return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)
|
||||||
|
|
||||||
def _call_min_max_binary(self, tx: "InstructionTranslator", a, b):
|
def _call_min_max_binary(self, tx: "InstructionTranslator", a, b):
|
||||||
|
if a is None or b is None:
|
||||||
|
# a or b could be none if we reduce and _call_min_max_binary failed
|
||||||
|
# to return something
|
||||||
|
return
|
||||||
if self.tensor_args(a, b):
|
if self.tensor_args(a, b):
|
||||||
if not isinstance(a, variables.TensorVariable):
|
if not isinstance(a, variables.TensorVariable):
|
||||||
a, b = b, a
|
a, b = b, a
|
||||||
|
|
@ -1223,17 +1226,15 @@ class BuiltinVariable(VariableTracker):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# NOTE must handle IteratorVariable separately!
|
||||||
def _call_iter_tuple_list(
|
def _call_iter_tuple_list(
|
||||||
self, tx: "InstructionTranslator", obj=None, *args, **kwargs
|
self, tx: "InstructionTranslator", obj=None, *args, **kwargs
|
||||||
):
|
):
|
||||||
|
assert not isinstance(obj, variables.IteratorVariable)
|
||||||
|
|
||||||
if self._dynamic_args(*args, **kwargs):
|
if self._dynamic_args(*args, **kwargs):
|
||||||
return self._dyn_proxy(tx, *args, **kwargs)
|
return self._dyn_proxy(tx, *args, **kwargs)
|
||||||
|
|
||||||
if isinstance(obj, variables.IteratorVariable):
|
|
||||||
# For non-list iterators, we will guard on vars that
|
|
||||||
# determine the control flow
|
|
||||||
return obj
|
|
||||||
|
|
||||||
cls = variables.BaseListVariable.cls_for(self.fn)
|
cls = variables.BaseListVariable.cls_for(self.fn)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
return cls(
|
return cls(
|
||||||
|
|
@ -1261,7 +1262,20 @@ class BuiltinVariable(VariableTracker):
|
||||||
mutable_local=MutableLocal(),
|
mutable_local=MutableLocal(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _call_tuple_list(self, tx, obj=None, *args, **kwargs):
|
||||||
|
if isinstance(obj, variables.IteratorVariable):
|
||||||
|
cls = variables.BaseListVariable.cls_for(self.fn)
|
||||||
|
return cls(
|
||||||
|
list(obj.force_unpack_var_sequence(tx)),
|
||||||
|
mutable_local=MutableLocal(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
||||||
|
|
||||||
def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs):
|
def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs):
|
||||||
|
if isinstance(obj, variables.IteratorVariable):
|
||||||
|
ret = obj
|
||||||
|
else:
|
||||||
# Handle the case where we are iterating over a tuple, list or iterator
|
# Handle the case where we are iterating over a tuple, list or iterator
|
||||||
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
||||||
|
|
||||||
|
|
@ -1272,8 +1286,8 @@ class BuiltinVariable(VariableTracker):
|
||||||
return obj.call_method(tx, "__iter__", args, kwargs)
|
return obj.call_method(tx, "__iter__", args, kwargs)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
call_tuple = _call_iter_tuple_list
|
call_tuple = _call_tuple_list
|
||||||
call_list = _call_iter_tuple_list
|
call_list = _call_tuple_list
|
||||||
|
|
||||||
def call_callable(self, tx: "InstructionTranslator", arg):
|
def call_callable(self, tx: "InstructionTranslator", arg):
|
||||||
from .functions import BaseUserFunctionVariable
|
from .functions import BaseUserFunctionVariable
|
||||||
|
|
@ -1331,10 +1345,12 @@ class BuiltinVariable(VariableTracker):
|
||||||
ListVariable,
|
ListVariable,
|
||||||
TupleVariable,
|
TupleVariable,
|
||||||
ListIteratorVariable,
|
ListIteratorVariable,
|
||||||
|
variables.IteratorVariable,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
items = dict(
|
items = dict(
|
||||||
x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx)
|
x.force_unpack_var_sequence(tx)
|
||||||
|
for x in arg.force_unpack_var_sequence(tx)
|
||||||
)
|
)
|
||||||
return ConstDictVariable(items, user_cls, mutable_local=MutableLocal())
|
return ConstDictVariable(items, user_cls, mutable_local=MutableLocal())
|
||||||
elif isinstance(arg, variables.MutableMappingVariable):
|
elif isinstance(arg, variables.MutableMappingVariable):
|
||||||
|
|
@ -1391,10 +1407,9 @@ class BuiltinVariable(VariableTracker):
|
||||||
return DictVariableType(
|
return DictVariableType(
|
||||||
dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal()
|
dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal()
|
||||||
)
|
)
|
||||||
elif arg.has_unpack_var_sequence(tx) and all(
|
elif arg.has_force_unpack_var_sequence(tx):
|
||||||
is_hashable(v) for v in arg.unpack_var_sequence(tx)
|
keys = arg.force_unpack_var_sequence(tx)
|
||||||
):
|
if all(is_hashable(v) for v in keys):
|
||||||
keys = arg.unpack_var_sequence(tx)
|
|
||||||
return DictVariableType(
|
return DictVariableType(
|
||||||
dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
|
dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
|
||||||
)
|
)
|
||||||
|
|
@ -1409,8 +1424,8 @@ class BuiltinVariable(VariableTracker):
|
||||||
arg = args[0]
|
arg = args[0]
|
||||||
if isinstance(arg, variables.SetVariable):
|
if isinstance(arg, variables.SetVariable):
|
||||||
return arg.clone(mutable_local=MutableLocal())
|
return arg.clone(mutable_local=MutableLocal())
|
||||||
elif arg.has_unpack_var_sequence(tx):
|
elif arg.has_force_unpack_var_sequence(tx):
|
||||||
items = arg.unpack_var_sequence(tx)
|
items = arg.force_unpack_var_sequence(tx)
|
||||||
return SetVariable(items, mutable_local=MutableLocal())
|
return SetVariable(items, mutable_local=MutableLocal())
|
||||||
elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
|
elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
|
||||||
arg.value, KeysView
|
arg.value, KeysView
|
||||||
|
|
@ -1443,16 +1458,12 @@ class BuiltinVariable(VariableTracker):
|
||||||
def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
|
def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||||
if kwargs:
|
if kwargs:
|
||||||
assert len(kwargs) == 1 and "strict" in kwargs
|
assert len(kwargs) == 1 and "strict" in kwargs
|
||||||
if all(x.has_unpack_var_sequence(tx) for x in args):
|
strict = kwargs.pop("strict", False)
|
||||||
unpacked = [arg.unpack_var_sequence(tx) for arg in args]
|
args = [
|
||||||
if kwargs.pop("strict", False) and len(unpacked) > 0:
|
arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg
|
||||||
if not all(len(u) == len(unpacked[0]) for u in unpacked):
|
for arg in args
|
||||||
raise UserError(
|
]
|
||||||
ValueError,
|
return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal())
|
||||||
"zip() has one argument of len differing from others",
|
|
||||||
)
|
|
||||||
items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)]
|
|
||||||
return variables.TupleVariable(items)
|
|
||||||
|
|
||||||
def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
|
def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||||
return args[0].call_method(tx, "__len__", args[1:], kwargs)
|
return args[0].call_method(tx, "__len__", args[1:], kwargs)
|
||||||
|
|
@ -1553,10 +1564,11 @@ class BuiltinVariable(VariableTracker):
|
||||||
return obj.call_hasattr(tx, name)
|
return obj.call_hasattr(tx, name)
|
||||||
|
|
||||||
def call_map(self, tx: "InstructionTranslator", fn, *seqs):
|
def call_map(self, tx: "InstructionTranslator", fn, *seqs):
|
||||||
if all(seq.has_unpack_var_sequence(tx) for seq in seqs):
|
seqs = [
|
||||||
unpacked = [seq.unpack_var_sequence(tx) for seq in seqs]
|
seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
|
||||||
items = [fn.call_function(tx, list(args), {}) for args in zip(*unpacked)]
|
for seq in seqs
|
||||||
return variables.TupleVariable(items)
|
]
|
||||||
|
return variables.MapVariable(fn, seqs, mutable_local=MutableLocal())
|
||||||
|
|
||||||
def call_filter(self, tx: "InstructionTranslator", fn, seq):
|
def call_filter(self, tx: "InstructionTranslator", fn, seq):
|
||||||
if seq.has_unpack_var_sequence(tx):
|
if seq.has_unpack_var_sequence(tx):
|
||||||
|
|
@ -1589,10 +1601,10 @@ class BuiltinVariable(VariableTracker):
|
||||||
return variables.ConstantVariable.create(
|
return variables.ConstantVariable.create(
|
||||||
sum((x.value for x in seq.items), start=start.value),
|
sum((x.value for x in seq.items), start=start.value),
|
||||||
)
|
)
|
||||||
if seq.has_unpack_var_sequence(tx):
|
if seq.has_force_unpack_var_sequence(tx):
|
||||||
if start is self._SENTINEL:
|
if start is self._SENTINEL:
|
||||||
start = variables.ConstantVariable.create(0)
|
start = variables.ConstantVariable.create(0)
|
||||||
items = seq.unpack_var_sequence(tx)
|
items = seq.force_unpack_var_sequence(tx)
|
||||||
return BuiltinVariable(functools.reduce).call_function(
|
return BuiltinVariable(functools.reduce).call_function(
|
||||||
tx,
|
tx,
|
||||||
[
|
[
|
||||||
|
|
@ -1606,8 +1618,8 @@ class BuiltinVariable(VariableTracker):
|
||||||
def call_reduce(
|
def call_reduce(
|
||||||
self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL
|
self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL
|
||||||
):
|
):
|
||||||
if iterable.has_unpack_var_sequence(tx):
|
if iterable.has_force_unpack_var_sequence(tx):
|
||||||
items = iterable.unpack_var_sequence(tx)
|
items = iterable.force_unpack_var_sequence(tx)
|
||||||
if initial is self._SENTINEL:
|
if initial is self._SENTINEL:
|
||||||
value, items = items[0], items[1:]
|
value, items = items[0], items[1:]
|
||||||
else:
|
else:
|
||||||
|
|
@ -1903,11 +1915,12 @@ class BuiltinVariable(VariableTracker):
|
||||||
return variables.TupleVariable(items)
|
return variables.TupleVariable(items)
|
||||||
|
|
||||||
def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs):
|
def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs):
|
||||||
if (
|
if obj.has_force_unpack_var_sequence(tx) and not isinstance(
|
||||||
obj.has_unpack_var_sequence(tx)
|
obj, variables.TensorVariable
|
||||||
and not isinstance(obj, variables.TensorVariable)
|
|
||||||
and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx))
|
|
||||||
):
|
):
|
||||||
|
unpacked = obj.force_unpack_var_sequence(tx)
|
||||||
|
if not all(x.is_python_constant() for x in unpacked):
|
||||||
|
return
|
||||||
function = kwargs.pop("key", None)
|
function = kwargs.pop("key", None)
|
||||||
reverse = kwargs.pop(
|
reverse = kwargs.pop(
|
||||||
"reverse", ConstantVariable.create(False)
|
"reverse", ConstantVariable.create(False)
|
||||||
|
|
@ -1915,7 +1928,7 @@ class BuiltinVariable(VariableTracker):
|
||||||
assert len(kwargs) == 0
|
assert len(kwargs) == 0
|
||||||
if function:
|
if function:
|
||||||
items = sorted(
|
items = sorted(
|
||||||
obj.unpack_var_sequence(tx),
|
unpacked,
|
||||||
key=lambda x: function.call_function(
|
key=lambda x: function.call_function(
|
||||||
tx, [x], {}
|
tx, [x], {}
|
||||||
).as_python_constant(),
|
).as_python_constant(),
|
||||||
|
|
@ -1923,7 +1936,7 @@ class BuiltinVariable(VariableTracker):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
items = sorted(
|
items = sorted(
|
||||||
obj.unpack_var_sequence(tx),
|
unpacked,
|
||||||
key=lambda x: x.as_python_constant(),
|
key=lambda x: x.as_python_constant(),
|
||||||
reverse=reverse,
|
reverse=reverse,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -145,6 +145,14 @@ class ConstantVariable(VariableTracker):
|
||||||
return variables.BuiltinVariable(str.format).call_function(
|
return variables.BuiltinVariable(str.format).call_function(
|
||||||
tx, [self, *args], kwargs
|
tx, [self, *args], kwargs
|
||||||
)
|
)
|
||||||
|
elif name == "join" and istype(self.value, str):
|
||||||
|
assert len(args) == 1 and len(kwargs) == 0
|
||||||
|
arg_unpacked = args[0].force_unpack_var_sequence(tx)
|
||||||
|
try:
|
||||||
|
arg_const = [x.as_python_constant() for x in arg_unpacked]
|
||||||
|
return ConstantVariable.create(self.value.join(arg_const))
|
||||||
|
except NotImplementedError:
|
||||||
|
return super().call_method(tx, name, args, kwargs)
|
||||||
|
|
||||||
if any(isinstance(x, SymNodeVariable) for x in args):
|
if any(isinstance(x, SymNodeVariable) for x in args):
|
||||||
# Promote to SymNodeVariable for operations involving dynamic shapes.
|
# Promote to SymNodeVariable for operations involving dynamic shapes.
|
||||||
|
|
|
||||||
|
|
@ -314,6 +314,7 @@ class ConstDictVariable(VariableTracker):
|
||||||
ListVariable,
|
ListVariable,
|
||||||
TupleVariable,
|
TupleVariable,
|
||||||
ListIteratorVariable,
|
ListIteratorVariable,
|
||||||
|
variables.IteratorVariable,
|
||||||
UserDefinedObjectVariable,
|
UserDefinedObjectVariable,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,14 +2,17 @@
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import operator
|
import operator
|
||||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
import sys
|
||||||
|
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
from .. import polyfills, variables
|
from .. import polyfills, variables
|
||||||
|
from ..bytecode_transformation import create_call_function, create_instruction
|
||||||
from ..exc import (
|
from ..exc import (
|
||||||
handle_observed_exception,
|
handle_observed_exception,
|
||||||
ObservedUserStopIteration,
|
ObservedUserStopIteration,
|
||||||
raise_observed_exception,
|
raise_observed_exception,
|
||||||
unimplemented,
|
unimplemented,
|
||||||
|
UserError,
|
||||||
)
|
)
|
||||||
from .base import MutableLocal, VariableTracker
|
from .base import MutableLocal, VariableTracker
|
||||||
from .constant import ConstantVariable
|
from .constant import ConstantVariable
|
||||||
|
|
@ -197,6 +200,25 @@ class IteratorVariable(VariableTracker):
|
||||||
def next_variable(self, tx):
|
def next_variable(self, tx):
|
||||||
unimplemented("abstract method, must implement")
|
unimplemented("abstract method, must implement")
|
||||||
|
|
||||||
|
# NOTE: only call when unpacking this iterator safely done eagerly!
|
||||||
|
# Normally, iterators are accessed lazily.
|
||||||
|
# Example of safe eager unpacking: list(map(f, seq))
|
||||||
|
# Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
|
||||||
|
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
|
||||||
|
result = []
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
result.append(self.next_variable(tx))
|
||||||
|
except ObservedUserStopIteration:
|
||||||
|
handle_observed_exception(tx)
|
||||||
|
break
|
||||||
|
return result
|
||||||
|
|
||||||
|
# don't call force_unpack_var_sequence since it can mutate
|
||||||
|
# IteratorVariable state!
|
||||||
|
def has_force_unpack_var_sequence(self, tx) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class RepeatIteratorVariable(IteratorVariable):
|
class RepeatIteratorVariable(IteratorVariable):
|
||||||
def __init__(self, item: VariableTracker, **kwargs) -> None:
|
def __init__(self, item: VariableTracker, **kwargs) -> None:
|
||||||
|
|
@ -207,6 +229,18 @@ class RepeatIteratorVariable(IteratorVariable):
|
||||||
def next_variable(self, tx):
|
def next_variable(self, tx):
|
||||||
return self.item
|
return self.item
|
||||||
|
|
||||||
|
def reconstruct(self, codegen):
|
||||||
|
codegen.add_push_null(
|
||||||
|
lambda: codegen.extend_output(
|
||||||
|
[
|
||||||
|
codegen.create_load_python_module(itertools),
|
||||||
|
codegen.create_load_attr("repeat"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
codegen(self.item)
|
||||||
|
codegen.extend_output(create_call_function(1, False))
|
||||||
|
|
||||||
|
|
||||||
class CountIteratorVariable(IteratorVariable):
|
class CountIteratorVariable(IteratorVariable):
|
||||||
def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
|
def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
|
||||||
|
|
@ -220,10 +254,23 @@ class CountIteratorVariable(IteratorVariable):
|
||||||
|
|
||||||
def next_variable(self, tx):
|
def next_variable(self, tx):
|
||||||
assert self.mutable_local
|
assert self.mutable_local
|
||||||
|
old_item = self.item
|
||||||
tx.output.side_effects.mutation(self)
|
tx.output.side_effects.mutation(self)
|
||||||
next_item = self.item.call_method(tx, "__add__", [self.step], {})
|
self.item = self.item.call_method(tx, "__add__", [self.step], {})
|
||||||
self.item = next_item
|
return old_item
|
||||||
return self.item
|
|
||||||
|
def reconstruct(self, codegen):
|
||||||
|
codegen.add_push_null(
|
||||||
|
lambda: codegen.extend_output(
|
||||||
|
[
|
||||||
|
codegen.create_load_python_module(itertools),
|
||||||
|
codegen.create_load_attr("count"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
codegen(self.item)
|
||||||
|
codegen(self.step)
|
||||||
|
codegen.extend_output(create_call_function(2, False))
|
||||||
|
|
||||||
|
|
||||||
class CycleIteratorVariable(IteratorVariable):
|
class CycleIteratorVariable(IteratorVariable):
|
||||||
|
|
@ -269,3 +316,160 @@ class CycleIteratorVariable(IteratorVariable):
|
||||||
return self.item
|
return self.item
|
||||||
else:
|
else:
|
||||||
raise_observed_exception(StopIteration, tx, self)
|
raise_observed_exception(StopIteration, tx, self)
|
||||||
|
|
||||||
|
|
||||||
|
class ZipVariable(IteratorVariable):
|
||||||
|
"""
|
||||||
|
Represents zip(*iterables)
|
||||||
|
"""
|
||||||
|
|
||||||
|
_nonvar_fields = {
|
||||||
|
"index",
|
||||||
|
"strict",
|
||||||
|
*IteratorVariable._nonvar_fields,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
iterables: List[Union[List[VariableTracker], VariableTracker]],
|
||||||
|
strict: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
assert isinstance(iterables, list)
|
||||||
|
# can be list[Variable] or VariableTracker (with next_variable implemented)
|
||||||
|
self.iterables = iterables
|
||||||
|
self.index = 0
|
||||||
|
self.strict = strict
|
||||||
|
|
||||||
|
def python_type(self):
|
||||||
|
return zip
|
||||||
|
|
||||||
|
def has_unpack_var_sequence(self, tx) -> bool:
|
||||||
|
return all(
|
||||||
|
isinstance(it, list) or it.has_unpack_var_sequence(tx)
|
||||||
|
for it in self.iterables
|
||||||
|
)
|
||||||
|
|
||||||
|
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
|
||||||
|
assert self.has_unpack_var_sequence(tx)
|
||||||
|
iterables = []
|
||||||
|
for it in self.iterables:
|
||||||
|
if isinstance(it, list):
|
||||||
|
iterables.append(it[self.index :])
|
||||||
|
else:
|
||||||
|
iterables.append(it.unpack_var_sequence(tx))
|
||||||
|
kwargs = {"strict": self.strict} if self.strict else {}
|
||||||
|
zipped = zip(*iterables, **kwargs)
|
||||||
|
return [variables.TupleVariable(list(var)) for var in zipped]
|
||||||
|
|
||||||
|
def next_variable(self, tx):
|
||||||
|
assert self.mutable_local
|
||||||
|
old_index = self.index
|
||||||
|
args = []
|
||||||
|
|
||||||
|
def get_item(it):
|
||||||
|
if isinstance(it, list):
|
||||||
|
if old_index >= len(it):
|
||||||
|
raise_observed_exception(StopIteration, tx, self)
|
||||||
|
return it[old_index]
|
||||||
|
else:
|
||||||
|
return it.next_variable(tx)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for idx, it in enumerate(self.iterables):
|
||||||
|
args.append(get_item(it))
|
||||||
|
except ObservedUserStopIteration:
|
||||||
|
if self.strict:
|
||||||
|
if idx == 0:
|
||||||
|
# all other iterables should be exhausted
|
||||||
|
for it in self.iterables:
|
||||||
|
try:
|
||||||
|
get_item(it)
|
||||||
|
except ObservedUserStopIteration:
|
||||||
|
handle_observed_exception(tx)
|
||||||
|
continue
|
||||||
|
# no ObservedUserStopIteration - fall through to UserError
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# all iterables exhausted, raise original error
|
||||||
|
raise
|
||||||
|
handle_observed_exception(tx)
|
||||||
|
raise UserError(
|
||||||
|
ValueError,
|
||||||
|
"zip() has one argument of len differing from others",
|
||||||
|
) from None
|
||||||
|
raise
|
||||||
|
|
||||||
|
tx.output.side_effects.mutation(self)
|
||||||
|
self.index += 1
|
||||||
|
return variables.TupleVariable(args)
|
||||||
|
|
||||||
|
def reconstruct_items(self, codegen):
|
||||||
|
for it in self.iterables:
|
||||||
|
if isinstance(it, list):
|
||||||
|
remaining_items = it[self.index :]
|
||||||
|
codegen.foreach(remaining_items)
|
||||||
|
codegen.append_output(
|
||||||
|
create_instruction("BUILD_TUPLE", arg=len(remaining_items))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
codegen(it)
|
||||||
|
|
||||||
|
def reconstruct(self, codegen):
|
||||||
|
codegen.add_push_null(
|
||||||
|
lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
|
||||||
|
)
|
||||||
|
self.reconstruct_items(codegen)
|
||||||
|
codegen.append_output(
|
||||||
|
create_instruction("BUILD_TUPLE", arg=len(self.iterables))
|
||||||
|
)
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
codegen.extend_output(
|
||||||
|
[
|
||||||
|
codegen.create_load_const("strict"),
|
||||||
|
codegen.create_load_const(self.strict),
|
||||||
|
create_instruction("BUILD_MAP", arg=1),
|
||||||
|
create_instruction("CALL_FUNCTION_EX", arg=1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))
|
||||||
|
|
||||||
|
|
||||||
|
class MapVariable(ZipVariable):
|
||||||
|
"""
|
||||||
|
Represents map(fn, *iterables)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fn: VariableTracker,
|
||||||
|
iterables: List[Union[List[VariableTracker], VariableTracker]],
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(iterables, **kwargs)
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def python_type(self):
|
||||||
|
return map
|
||||||
|
|
||||||
|
def has_unpack_var_sequence(self, tx) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def next_variable(self, tx):
|
||||||
|
args = super().next_variable(tx)
|
||||||
|
return self.fn.call_function(tx, args.items, {})
|
||||||
|
|
||||||
|
def reconstruct(self, codegen):
|
||||||
|
codegen.add_push_null(
|
||||||
|
lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
|
||||||
|
)
|
||||||
|
codegen(self.fn)
|
||||||
|
self.reconstruct_items(codegen)
|
||||||
|
codegen.extend_output(
|
||||||
|
[
|
||||||
|
create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1),
|
||||||
|
create_instruction("CALL_FUNCTION_EX", arg=0),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from ..utils import (
|
||||||
from .base import MutableLocal, VariableTracker
|
from .base import MutableLocal, VariableTracker
|
||||||
from .constant import ConstantVariable
|
from .constant import ConstantVariable
|
||||||
from .functions import UserFunctionVariable, UserMethodVariable
|
from .functions import UserFunctionVariable, UserMethodVariable
|
||||||
|
from .iter import IteratorVariable
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -334,11 +335,11 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||||
name == "extend"
|
name == "extend"
|
||||||
and self.mutable_local
|
and self.mutable_local
|
||||||
and args
|
and args
|
||||||
and args[0].has_unpack_var_sequence(tx)
|
and args[0].has_force_unpack_var_sequence(tx)
|
||||||
):
|
):
|
||||||
assert not kwargs
|
assert not kwargs
|
||||||
(arg,) = args
|
(arg,) = args
|
||||||
seq = arg.unpack_var_sequence(tx)
|
seq = arg.force_unpack_var_sequence(tx)
|
||||||
tx.output.side_effects.mutation(self)
|
tx.output.side_effects.mutation(self)
|
||||||
self.items.extend(seq)
|
self.items.extend(seq)
|
||||||
return ConstantVariable.create(None)
|
return ConstantVariable.create(None)
|
||||||
|
|
@ -422,11 +423,13 @@ class ListVariable(CommonListMethodsVariable):
|
||||||
key, value = args
|
key, value = args
|
||||||
tx.output.side_effects.mutation(self)
|
tx.output.side_effects.mutation(self)
|
||||||
if isinstance(key, SliceVariable):
|
if isinstance(key, SliceVariable):
|
||||||
if not value.has_unpack_var_sequence(tx):
|
if not value.has_force_unpack_var_sequence(tx):
|
||||||
unimplemented(
|
unimplemented(
|
||||||
f"Missing dynamo support for expanding {value} into a list for slice assignment."
|
f"Missing dynamo support for expanding {value} into a list for slice assignment."
|
||||||
)
|
)
|
||||||
self.items[key.as_python_constant()] = value.unpack_var_sequence(tx)
|
self.items[key.as_python_constant()] = value.force_unpack_var_sequence(
|
||||||
|
tx
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.items[key.as_python_constant()] = value
|
self.items[key.as_python_constant()] = value
|
||||||
return ConstantVariable.create(None)
|
return ConstantVariable.create(None)
|
||||||
|
|
@ -464,7 +467,12 @@ class DequeVariable(CommonListMethodsVariable):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
codegen.foreach(self.items)
|
codegen.foreach(self.items)
|
||||||
codegen.extend_output(create_call_function(len(self.items), False))
|
codegen.extend_output(
|
||||||
|
[
|
||||||
|
create_instruction("BUILD_LIST", arg=len(self.items)),
|
||||||
|
*create_call_function(1, False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def call_method(
|
def call_method(
|
||||||
self,
|
self,
|
||||||
|
|
@ -487,11 +495,15 @@ class DequeVariable(CommonListMethodsVariable):
|
||||||
tx.output.side_effects.mutation(self)
|
tx.output.side_effects.mutation(self)
|
||||||
self.items[key.as_python_constant()] = value
|
self.items[key.as_python_constant()] = value
|
||||||
return ConstantVariable.create(None)
|
return ConstantVariable.create(None)
|
||||||
elif name == "extendleft" and self.mutable_local:
|
elif (
|
||||||
|
name == "extendleft"
|
||||||
|
and self.mutable_local
|
||||||
|
and args[0].has_force_unpack_var_sequence(tx)
|
||||||
|
):
|
||||||
assert not kwargs
|
assert not kwargs
|
||||||
|
|
||||||
(arg,) = args
|
(arg,) = args
|
||||||
prefix = arg.unpack_var_sequence(tx)
|
prefix = arg.force_unpack_var_sequence(tx)
|
||||||
prefix.reverse()
|
prefix.reverse()
|
||||||
tx.output.side_effects.mutation(self)
|
tx.output.side_effects.mutation(self)
|
||||||
self.items = prefix + list(self.items)
|
self.items = prefix + list(self.items)
|
||||||
|
|
@ -802,10 +814,10 @@ class SliceVariable(BaseListVariable):
|
||||||
return self.items[fields.index(name)]
|
return self.items[fields.index(name)]
|
||||||
|
|
||||||
|
|
||||||
class ListIteratorVariable(VariableTracker):
|
class ListIteratorVariable(IteratorVariable):
|
||||||
_nonvar_fields = {
|
_nonvar_fields = {
|
||||||
"index",
|
"index",
|
||||||
*VariableTracker._nonvar_fields,
|
*IteratorVariable._nonvar_fields,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, items, index: int = 0, **kwargs) -> None:
|
def __init__(self, items, index: int = 0, **kwargs) -> None:
|
||||||
|
|
@ -856,6 +868,9 @@ class ListIteratorVariable(VariableTracker):
|
||||||
def unpack_var_sequence(self, tx):
|
def unpack_var_sequence(self, tx):
|
||||||
return list(self.items[self.index :])
|
return list(self.items[self.index :])
|
||||||
|
|
||||||
|
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
|
||||||
|
return self.unpack_var_sequence(tx)
|
||||||
|
|
||||||
def reconstruct(self, codegen):
|
def reconstruct(self, codegen):
|
||||||
remaining_items = self.items[self.index :]
|
remaining_items = self.items[self.index :]
|
||||||
codegen.foreach(remaining_items)
|
codegen.foreach(remaining_items)
|
||||||
|
|
|
||||||
|
|
@ -379,8 +379,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||||
elif self.value is collections.deque and not kwargs:
|
elif self.value is collections.deque and not kwargs:
|
||||||
if len(args) == 0:
|
if len(args) == 0:
|
||||||
items = []
|
items = []
|
||||||
elif len(args) == 1 and args[0].has_unpack_var_sequence(tx):
|
elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
|
||||||
items = args[0].unpack_var_sequence(tx)
|
items = args[0].force_unpack_var_sequence(tx)
|
||||||
else:
|
else:
|
||||||
unimplemented("deque() with more than 1 arg not supported")
|
unimplemented("deque() with more than 1 arg not supported")
|
||||||
return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
|
return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
|
||||||
|
|
@ -749,7 +749,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||||
assert not (args or kwargs)
|
assert not (args or kwargs)
|
||||||
items = []
|
items = []
|
||||||
keys = self.call_method(tx, "keys", [], {})
|
keys = self.call_method(tx, "keys", [], {})
|
||||||
for key in keys.unpack_var_sequence(tx):
|
for key in keys.force_unpack_var_sequence(tx):
|
||||||
items.append(
|
items.append(
|
||||||
TupleVariable(
|
TupleVariable(
|
||||||
[key, self.odict_getitem(tx, key)],
|
[key, self.odict_getitem(tx, key)],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user