Offload set method execution to CPython when possible (#160763)

Reduces CPython `test_set.py` runtime from 63.477s to 40.298s

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160763
Approved by: https://github.com/anijain2305
This commit is contained in:
Guilherme Leobas 2025-09-03 13:56:47 +00:00 committed by PyTorch MergeBot
parent f00445b43e
commit 8076a185c8
5 changed files with 29 additions and 0 deletions

View File

@ -946,6 +946,18 @@ class SetVariable(ConstDictVariable):
codegen.foreach([x.vt for x in self.set_items]) codegen.foreach([x.vt for x in self.set_items])
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
def _fast_set_method(self, tx, fn, args, kwargs):
try:
res = fn(
*[x.as_python_constant() for x in [self, *args]],
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
except Exception as exc:
raise_observed_exception(
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
)
return VariableTracker.build(tx, res)
def call_method( def call_method(
self, self,
tx, tx,
@ -954,6 +966,23 @@ class SetVariable(ConstDictVariable):
kwargs: dict[str, VariableTracker], kwargs: dict[str, VariableTracker],
) -> "VariableTracker": ) -> "VariableTracker":
# We forward the calls to the dictionary model # We forward the calls to the dictionary model
from ..utils import check_constant_args
if (
name
in (
"isdisjoint",
"union",
"intersection",
"difference",
"symmetric_difference",
)
and check_constant_args(args, kwargs)
and self.python_type() is set
):
py_type = self.python_type()
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
if name == "__init__": if name == "__init__":
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs) temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
tx.output.side_effects.mutation(self) tx.output.side_effects.mutation(self)