mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[Set] Raise TypeError if set is called with the wrong number of arguments (#152990)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152990 Approved by: https://github.com/anijain2305 ghstack dependencies: #150792, #152987, #152988, #152904, #152901, #152902, #152903, #152905, #152906, #152989, #152907, #152908
This commit is contained in:
parent
5a0ca65555
commit
f66a159db5
|
|
@ -1687,7 +1687,16 @@ class BuiltinVariable(VariableTracker):
|
||||||
assert not kwargs
|
assert not kwargs
|
||||||
if not args:
|
if not args:
|
||||||
return SetVariable([], mutation_type=ValueMutationNew())
|
return SetVariable([], mutation_type=ValueMutationNew())
|
||||||
assert len(args) == 1
|
if len(args) != 1:
|
||||||
|
raise_observed_exception(
|
||||||
|
TypeError,
|
||||||
|
tx,
|
||||||
|
args=[
|
||||||
|
ConstantVariable.create(
|
||||||
|
f"set() takes 1 positional argument but {len(args)} were given"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
arg = args[0]
|
arg = args[0]
|
||||||
if isinstance(arg, variables.SetVariable):
|
if isinstance(arg, variables.SetVariable):
|
||||||
return arg.clone(mutation_type=ValueMutationNew())
|
return arg.clone(mutation_type=ValueMutationNew())
|
||||||
|
|
@ -1703,35 +1712,36 @@ class BuiltinVariable(VariableTracker):
|
||||||
if isinstance(out, SetVariable):
|
if isinstance(out, SetVariable):
|
||||||
return out
|
return out
|
||||||
return BuiltinVariable(set).call_set(tx, out)
|
return BuiltinVariable(set).call_set(tx, out)
|
||||||
unimplemented_v2(
|
raise_observed_exception(
|
||||||
gb_type="failed to construct builtin set()",
|
TypeError,
|
||||||
context=f"set(): {args} {kwargs}",
|
tx,
|
||||||
explanation="Unable to call builtin set() with provided arguments.",
|
args=[ConstantVariable.create("failed to construct builtin set()")],
|
||||||
hints=[
|
|
||||||
*graph_break_hints.USER_ERROR,
|
|
||||||
*graph_break_hints.SUPPORTABLE,
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_frozenset(self, tx: "InstructionTranslator", *args, **kwargs):
|
def call_frozenset(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||||
assert not kwargs
|
assert not kwargs
|
||||||
if not args:
|
if not args:
|
||||||
return FrozensetVariable([])
|
return FrozensetVariable([])
|
||||||
assert len(args) == 1
|
if len(args) != 1:
|
||||||
|
raise_observed_exception(
|
||||||
|
TypeError,
|
||||||
|
tx,
|
||||||
|
args=[
|
||||||
|
ConstantVariable.create(
|
||||||
|
f"frozenset() takes 1 positional argument but {len(args)} were given"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
arg = args[0]
|
arg = args[0]
|
||||||
if isinstance(arg, variables.FrozensetVariable):
|
if isinstance(arg, variables.FrozensetVariable):
|
||||||
return FrozensetVariable([x.vt for x in arg.set_items])
|
return FrozensetVariable([x.vt for x in arg.set_items])
|
||||||
elif arg.has_unpack_var_sequence(tx):
|
elif arg.has_unpack_var_sequence(tx):
|
||||||
items = arg.unpack_var_sequence(tx)
|
items = arg.unpack_var_sequence(tx)
|
||||||
return FrozensetVariable(items)
|
return FrozensetVariable(items)
|
||||||
unimplemented_v2(
|
raise_observed_exception(
|
||||||
gb_type="failed to construct builtin frozenset()",
|
TypeError,
|
||||||
context=f"frozenset(): {args} {kwargs}",
|
tx,
|
||||||
explanation="Unable to call builtin frozenset() with provided arguments.",
|
args=[ConstantVariable.create("failed to construct builtin frozenset()")],
|
||||||
hints=[
|
|
||||||
*graph_break_hints.USER_ERROR,
|
|
||||||
*graph_break_hints.SUPPORTABLE,
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
|
def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user