mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Offload set method execution to CPython when possible (#160763)
Reduces CPython `test_set.py` runtime from 63.477s to 40.298s Pull Request resolved: https://github.com/pytorch/pytorch/pull/160763 Approved by: https://github.com/anijain2305
This commit is contained in:
parent
f00445b43e
commit
8076a185c8
|
|
@ -946,6 +946,18 @@ class SetVariable(ConstDictVariable):
|
|||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
|
||||
|
||||
def _fast_set_method(self, tx, fn, args, kwargs):
|
||||
try:
|
||||
res = fn(
|
||||
*[x.as_python_constant() for x in [self, *args]],
|
||||
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||||
)
|
||||
except Exception as exc:
|
||||
raise_observed_exception(
|
||||
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
|
||||
)
|
||||
return VariableTracker.build(tx, res)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
|
|
@ -954,6 +966,23 @@ class SetVariable(ConstDictVariable):
|
|||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
# We forward the calls to the dictionary model
|
||||
from ..utils import check_constant_args
|
||||
|
||||
if (
|
||||
name
|
||||
in (
|
||||
"isdisjoint",
|
||||
"union",
|
||||
"intersection",
|
||||
"difference",
|
||||
"symmetric_difference",
|
||||
)
|
||||
and check_constant_args(args, kwargs)
|
||||
and self.python_type() is set
|
||||
):
|
||||
py_type = self.python_type()
|
||||
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
|
||||
|
||||
if name == "__init__":
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
|
||||
tx.output.side_effects.mutation(self)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user