mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Offload set method execution to CPython when possible (#160763)
Reduces CPython `test_set.py` runtime from 63.477s to 40.298s Pull Request resolved: https://github.com/pytorch/pytorch/pull/160763 Approved by: https://github.com/anijain2305
This commit is contained in:
parent
f00445b43e
commit
8076a185c8
|
|
@ -946,6 +946,18 @@ class SetVariable(ConstDictVariable):
|
||||||
codegen.foreach([x.vt for x in self.set_items])
|
codegen.foreach([x.vt for x in self.set_items])
|
||||||
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
|
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
|
||||||
|
|
||||||
|
def _fast_set_method(self, tx, fn, args, kwargs):
|
||||||
|
try:
|
||||||
|
res = fn(
|
||||||
|
*[x.as_python_constant() for x in [self, *args]],
|
||||||
|
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise_observed_exception(
|
||||||
|
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
|
||||||
|
)
|
||||||
|
return VariableTracker.build(tx, res)
|
||||||
|
|
||||||
def call_method(
|
def call_method(
|
||||||
self,
|
self,
|
||||||
tx,
|
tx,
|
||||||
|
|
@ -954,6 +966,23 @@ class SetVariable(ConstDictVariable):
|
||||||
kwargs: dict[str, VariableTracker],
|
kwargs: dict[str, VariableTracker],
|
||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
# We forward the calls to the dictionary model
|
# We forward the calls to the dictionary model
|
||||||
|
from ..utils import check_constant_args
|
||||||
|
|
||||||
|
if (
|
||||||
|
name
|
||||||
|
in (
|
||||||
|
"isdisjoint",
|
||||||
|
"union",
|
||||||
|
"intersection",
|
||||||
|
"difference",
|
||||||
|
"symmetric_difference",
|
||||||
|
)
|
||||||
|
and check_constant_args(args, kwargs)
|
||||||
|
and self.python_type() is set
|
||||||
|
):
|
||||||
|
py_type = self.python_type()
|
||||||
|
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
|
||||||
|
|
||||||
if name == "__init__":
|
if name == "__init__":
|
||||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
|
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
|
||||||
tx.output.side_effects.mutation(self)
|
tx.output.side_effects.mutation(self)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user