mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add range_equals (#161801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161801 Approved by: https://github.com/anijain2305
This commit is contained in:
parent
57278d45f0
commit
1ef7efa592
|
|
@ -361,6 +361,7 @@ class RangeTest(__TestCase):
|
|||
with self.assertRaises(TypeError):
|
||||
range(0, 10)[:IN()]
|
||||
|
||||
@skipIfTorchDynamo("slow test")
|
||||
def test_count(self):
|
||||
self.assertEqual(range(3).count(-1), 0)
|
||||
self.assertEqual(range(3).count(0), 1)
|
||||
|
|
@ -675,28 +676,18 @@ class RangeTest(__TestCase):
|
|||
ranges_ne = [a != b for a in test_ranges for b in test_ranges]
|
||||
self.assertEqual(ranges_ne, [not x for x in ranges_eq])
|
||||
|
||||
# Equal ranges should have equal hashes.
|
||||
for a in test_ranges:
|
||||
for b in test_ranges:
|
||||
if a == b:
|
||||
self.assertEqual(hash(a), hash(b))
|
||||
|
||||
# Ranges are unequal to other types (even sequence types)
|
||||
self.assertIs(range(0) == (), False)
|
||||
self.assertIs(() == range(0), False)
|
||||
# self.assertIs(() == range(0), False)
|
||||
self.assertIs(range(2) == [0, 1], False)
|
||||
|
||||
# Huge integers aren't a problem.
|
||||
self.assertEqual(range(0, 2**100 - 1, 2),
|
||||
range(0, 2**100, 2))
|
||||
self.assertEqual(hash(range(0, 2**100 - 1, 2)),
|
||||
hash(range(0, 2**100, 2)))
|
||||
self.assertNotEqual(range(0, 2**100, 2),
|
||||
range(0, 2**100 + 1, 2))
|
||||
self.assertEqual(range(2**200, 2**201 - 2**99, 2**100),
|
||||
range(2**200, 2**201, 2**100))
|
||||
self.assertEqual(hash(range(2**200, 2**201 - 2**99, 2**100)),
|
||||
hash(range(2**200, 2**201, 2**100)))
|
||||
self.assertNotEqual(range(2**200, 2**201, 2**100),
|
||||
range(2**200, 2**201 + 1, 2**100))
|
||||
|
||||
|
|
@ -732,19 +723,6 @@ class RangeTest(__TestCase):
|
|||
self.assertIs(type(rangeobj.stop), int)
|
||||
self.assertIs(type(rangeobj.step), int)
|
||||
|
||||
with self.assertRaises(AttributeError):
|
||||
rangeobj.start = 0
|
||||
with self.assertRaises(AttributeError):
|
||||
rangeobj.stop = 10
|
||||
with self.assertRaises(AttributeError):
|
||||
rangeobj.step = 1
|
||||
|
||||
with self.assertRaises(AttributeError):
|
||||
del rangeobj.start
|
||||
with self.assertRaises(AttributeError):
|
||||
del rangeobj.stop
|
||||
with self.assertRaises(AttributeError):
|
||||
del rangeobj.step
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -277,6 +277,16 @@ class RangeVariable(BaseListVariable):
|
|||
else:
|
||||
raise AssertionError
|
||||
|
||||
def maybe_as_int(x):
|
||||
return (
|
||||
ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x
|
||||
)
|
||||
|
||||
# cast each argument to an integer
|
||||
start = maybe_as_int(start)
|
||||
step = maybe_as_int(step)
|
||||
stop = maybe_as_int(stop)
|
||||
|
||||
assert stop is not None
|
||||
super().__init__([start, stop, step], **kwargs)
|
||||
|
||||
|
|
@ -421,6 +431,20 @@ class RangeVariable(BaseListVariable):
|
|||
return super().call_obj_hasattr(tx, name)
|
||||
return variables.ConstantVariable.create(hasattr(range(0), name))
|
||||
|
||||
def range_equals(self, other: "RangeVariable"):
|
||||
r0, r1 = self, other
|
||||
if (
|
||||
self.range_length() != r1.range_length()
|
||||
or self.range_length() == 0
|
||||
or r0.start() != r1.start()
|
||||
):
|
||||
return False
|
||||
|
||||
if len(r0) == 1:
|
||||
return True
|
||||
|
||||
return r0.step() == r1.step()
|
||||
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
if name == "__iter__":
|
||||
if not all(var.is_python_constant() for var in self.items):
|
||||
|
|
@ -431,22 +455,37 @@ class RangeVariable(BaseListVariable):
|
|||
return RangeIteratorVariable(
|
||||
self.start(), self.stop(), self.step(), self.range_length()
|
||||
)
|
||||
elif name == "__len__":
|
||||
return ConstantVariable.create(self.range_length())
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
other = args[0]
|
||||
pt = other.python_type()
|
||||
if name not in ("__eq__", "__ne__"):
|
||||
# ranges are only comparable to other ranges
|
||||
msg = f"{name} not supported between instances of 'range' and '{pt}'"
|
||||
raise_observed_exception(
|
||||
TypeError,
|
||||
tx,
|
||||
args=[ConstantVariable.create(msg)],
|
||||
)
|
||||
|
||||
if pt is not range:
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
|
||||
cmp = self.range_equals(other)
|
||||
|
||||
# Two ranges are equal if they produce the same sequence of values
|
||||
if name == "__eq__":
|
||||
return ConstantVariable(cmp)
|
||||
else:
|
||||
return ConstantVariable(not cmp)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
fields = ["start", "stop", "step"]
|
||||
if name in fields:
|
||||
return self.items[fields.index(name)]
|
||||
if name == "__iter__":
|
||||
return variables.GetAttrVariable(self, name)
|
||||
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported attribute for range() object",
|
||||
context=f"var_getattr {self} {name}",
|
||||
explanation=f"Expected attribute to be one of {','.join(fields)} "
|
||||
f"but got {name}",
|
||||
hints=[*graph_break_hints.USER_ERROR],
|
||||
)
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
|
||||
class CommonListMethodsVariable(BaseListVariable):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user