mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[dynamo] Use polyfill to implement comparison operators (#144485)"
This reverts commit d1f82de2bf.
Reverted https://github.com/pytorch/pytorch/pull/144485 on behalf of https://github.com/huydhn due to This seems to break dynamo tests in trunk after landing ([comment](https://github.com/pytorch/pytorch/pull/144485#issuecomment-2622893294))
This commit is contained in:
parent
953e80936e
commit
1185b81c51
|
|
@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,fail_accuracy,46
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
detectron2_fcos_r_50_fpn,pass,22
|
detectron2_fcos_r_50_fpn,pass,24
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
|
@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,pass,46
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
detectron2_fcos_r_50_fpn,pass,22
|
detectron2_fcos_r_50_fpn,pass,24
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
|
@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,pass,46
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
detectron2_fcos_r_50_fpn,pass,22
|
detectron2_fcos_r_50_fpn,pass,24
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
|
@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,pass,46
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
detectron2_fcos_r_50_fpn,pass,22
|
detectron2_fcos_r_50_fpn,pass,24
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
|
@ -82,7 +82,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
detectron2_fcos_r_50_fpn,pass,20
|
detectron2_fcos_r_50_fpn,pass,22
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
|
@ -11782,33 +11782,6 @@ fn
|
||||||
_, ne = run(torch.ones(1))
|
_, ne = run(torch.ones(1))
|
||||||
self.assertFalse(ne)
|
self.assertFalse(ne)
|
||||||
|
|
||||||
def test_ne_operator_with_custom_ne(self):
|
|
||||||
class Foo:
|
|
||||||
def __init__(self, x):
|
|
||||||
self.x = x
|
|
||||||
self.ne_called = False
|
|
||||||
|
|
||||||
def __ne__(self, other):
|
|
||||||
# ne_called attr is later checked to ensure that overrideen
|
|
||||||
# `__ne__` is traced
|
|
||||||
self.ne_called = True
|
|
||||||
return not self.__eq__(other)
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
return self.x == other.x
|
|
||||||
|
|
||||||
f1 = Foo(0)
|
|
||||||
f2 = Foo(0)
|
|
||||||
|
|
||||||
@torch.compile(fullgraph=True, backend="eager")
|
|
||||||
def run(x):
|
|
||||||
# `x + 1` prevents Dynamo from skipping this frame.
|
|
||||||
return x + 1, f1 != f2
|
|
||||||
|
|
||||||
_, ne = run(torch.ones(1))
|
|
||||||
self.assertFalse(ne)
|
|
||||||
self.assertTrue(f1.ne_called)
|
|
||||||
|
|
||||||
def test_ne_operator_with_custom_graphbreak_eq(self):
|
def test_ne_operator_with_custom_graphbreak_eq(self):
|
||||||
counters.clear()
|
counters.clear()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6007,19 +6007,6 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||||
self.assertEqual(fn(config, x), opt_fn(config, x))
|
self.assertEqual(fn(config, x), opt_fn(config, x))
|
||||||
self.assertEqual(cloned_config.baz, 4)
|
self.assertEqual(cloned_config.baz, 4)
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_OMEGACONG, "missing omegaconf package")
|
|
||||||
def test_omegaconf_listconfig_contains(self):
|
|
||||||
def fn(cfg, x):
|
|
||||||
if 1 in cfg:
|
|
||||||
return torch.sin(x)
|
|
||||||
return torch.cos(x)
|
|
||||||
|
|
||||||
config = OmegaConf.create([1, 2, 3, {"key": "value"}])
|
|
||||||
|
|
||||||
x = torch.randn(4)
|
|
||||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
||||||
self.assertEqual(fn(config, x), opt_fn(config, x))
|
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/issues/136257
|
# https://github.com/pytorch/pytorch/issues/136257
|
||||||
def test_overwriting_params(self):
|
def test_overwriting_params(self):
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -8,15 +8,12 @@ Python polyfills for common builtins.
|
||||||
|
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
|
|
||||||
import types
|
|
||||||
from collections.abc import MutableMapping, Sequence
|
from collections.abc import MutableMapping, Sequence
|
||||||
from itertools import repeat as _repeat
|
from itertools import repeat as _repeat
|
||||||
from typing import Any, Callable, List, TYPE_CHECKING
|
from typing import Any, Callable, List, TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..utils import dict_keys
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# Load by torch._dynamo.polyfills.loader
|
# Load by torch._dynamo.polyfills.loader
|
||||||
|
|
@ -222,52 +219,14 @@ def predicate(obj: Any) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def cmp_eq(a, b):
|
def object_eq(self, other):
|
||||||
# Note that the commented `is` check should ideally be removed. This is a
|
# Mirrors CPython implementation:
|
||||||
# CPython optimization that skips the __eq__ checks it the obj id's are
|
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L6228-L6233
|
||||||
# same. But, these lines adds many `is` nodes in the Fx graph for
|
return self is other
|
||||||
# SymNodeVariable. For now, we can just skip this check. This is STILL
|
|
||||||
# correct because one of the __eq__ checks will pass later, just could be
|
|
||||||
# slow in some corner cases.
|
|
||||||
# if a is b:
|
|
||||||
# return True
|
|
||||||
result = a.__eq__(b)
|
|
||||||
if result is NotImplemented:
|
|
||||||
result = b.__eq__(a)
|
|
||||||
return result is not NotImplemented and result
|
|
||||||
|
|
||||||
|
|
||||||
def cmp_ne(a, b):
|
def object_ne(self, other):
|
||||||
# Check if __ne__ is overridden
|
# Mirrors CPython implementation:
|
||||||
if isinstance(type(a).__ne__, types.FunctionType):
|
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L6235-L6255
|
||||||
return a.__ne__(b)
|
# Using `==` is important because `self` might have a user-defined `__eq__`.
|
||||||
return not cmp_eq(a, b)
|
return not (self == other)
|
||||||
|
|
||||||
|
|
||||||
def cmp_lt(a, b):
|
|
||||||
result = a.__lt__(b)
|
|
||||||
if result is NotImplemented:
|
|
||||||
raise TypeError(f"{type(a)} does not support the < operator")
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def cmp_le(a, b):
|
|
||||||
# Check if __le__ is overridden
|
|
||||||
if isinstance(type(a).__le__, types.FunctionType):
|
|
||||||
return a.__le__(b)
|
|
||||||
return cmp_eq(a, b) or cmp_lt(a, b)
|
|
||||||
|
|
||||||
|
|
||||||
def cmp_gt(a, b):
|
|
||||||
# Check if __gt__ is overridden
|
|
||||||
if isinstance(type(a).__gt__, types.FunctionType):
|
|
||||||
return a.__gt__(b)
|
|
||||||
# a > b is equivalent to b < a
|
|
||||||
return cmp_lt(b, a)
|
|
||||||
|
|
||||||
|
|
||||||
def cmp_ge(a, b):
|
|
||||||
# Check if __ge__ is overridden
|
|
||||||
if isinstance(type(a).__ge__, types.FunctionType):
|
|
||||||
return a.__ge__(b)
|
|
||||||
return cmp_eq(a, b) or cmp_gt(a, b)
|
|
||||||
|
|
|
||||||
|
|
@ -1008,16 +1008,6 @@ def is_function(value):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
cmp_name_to_op_mapping = {
|
|
||||||
"__eq__": operator.eq,
|
|
||||||
"__ne__": operator.ne,
|
|
||||||
"__lt__": operator.lt,
|
|
||||||
"__le__": operator.le,
|
|
||||||
"__gt__": operator.gt,
|
|
||||||
"__ge__": operator.ge,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def is_wrapper_or_member_descriptor(value):
|
def is_wrapper_or_member_descriptor(value):
|
||||||
return isinstance(
|
return isinstance(
|
||||||
value,
|
value,
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from ..current_scope_id import current_scope_id
|
||||||
from ..exc import unimplemented
|
from ..exc import unimplemented
|
||||||
from ..guards import GuardBuilder, install_guard
|
from ..guards import GuardBuilder, install_guard
|
||||||
from ..source import AttrSource, Source
|
from ..source import AttrSource, Source
|
||||||
from ..utils import cmp_name_to_op_mapping, istype
|
from ..utils import istype
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -410,29 +410,6 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||||
and not kwargs
|
and not kwargs
|
||||||
):
|
):
|
||||||
return self.var_getattr(tx, args[0].as_python_constant())
|
return self.var_getattr(tx, args[0].as_python_constant())
|
||||||
elif (
|
|
||||||
name in cmp_name_to_op_mapping
|
|
||||||
and len(args) == 1
|
|
||||||
and self.is_python_constant()
|
|
||||||
and not tx.output.side_effects.has_pending_mutation(self)
|
|
||||||
and not kwargs
|
|
||||||
):
|
|
||||||
# NB : Checking for mutation is necessary because we compare
|
|
||||||
# constant values
|
|
||||||
other = args[0]
|
|
||||||
if not isinstance(self, type(other)):
|
|
||||||
return variables.ConstantVariable.create(NotImplemented)
|
|
||||||
if (
|
|
||||||
not other.is_python_constant()
|
|
||||||
or tx.output.side_effects.has_pending_mutation(other)
|
|
||||||
):
|
|
||||||
unimplemented(f"call_method {self} {name} {args} {kwargs}")
|
|
||||||
|
|
||||||
return variables.ConstantVariable.create(
|
|
||||||
cmp_name_to_op_mapping[name](
|
|
||||||
self.as_python_constant(), other.as_python_constant()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
unimplemented(f"call_method {self} {name} {args} {kwargs}")
|
unimplemented(f"call_method {self} {name} {args} {kwargs}")
|
||||||
|
|
||||||
def set_name_hint(self, name):
|
def set_name_hint(self, name):
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,6 @@ from ..utils import (
|
||||||
check_numpy_ndarray_args,
|
check_numpy_ndarray_args,
|
||||||
check_unspec_or_constant_args,
|
check_unspec_or_constant_args,
|
||||||
check_unspec_python_args,
|
check_unspec_python_args,
|
||||||
cmp_name_to_op_mapping,
|
|
||||||
dict_methods,
|
dict_methods,
|
||||||
extract_fake_example_value,
|
extract_fake_example_value,
|
||||||
get_fake_value,
|
get_fake_value,
|
||||||
|
|
@ -101,16 +100,6 @@ IN_PLACE_DESUGARING_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
polyfill_fn_mapping = {
|
|
||||||
operator.eq: polyfills.cmp_eq,
|
|
||||||
operator.ne: polyfills.cmp_ne,
|
|
||||||
operator.lt: polyfills.cmp_lt,
|
|
||||||
operator.le: polyfills.cmp_le,
|
|
||||||
operator.gt: polyfills.cmp_gt,
|
|
||||||
operator.ge: polyfills.cmp_ge,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class BuiltinVariable(VariableTracker):
|
class BuiltinVariable(VariableTracker):
|
||||||
_SENTINEL = object()
|
_SENTINEL = object()
|
||||||
_nonvar_fields = {
|
_nonvar_fields = {
|
||||||
|
|
@ -283,6 +272,7 @@ class BuiltinVariable(VariableTracker):
|
||||||
# combinations. Handlers are attempted in order, and will be used if the type checks
|
# combinations. Handlers are attempted in order, and will be used if the type checks
|
||||||
# match. They are expected to have the signature:
|
# match. They are expected to have the signature:
|
||||||
# fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
|
# fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
|
||||||
|
from .dicts import DictKeysVariable, SetVariable
|
||||||
from .functions import BaseUserFunctionVariable, UserFunctionVariable
|
from .functions import BaseUserFunctionVariable, UserFunctionVariable
|
||||||
from .nn_module import NNModuleVariable
|
from .nn_module import NNModuleVariable
|
||||||
from .tensor import supported_const_comparison_ops
|
from .tensor import supported_const_comparison_ops
|
||||||
|
|
@ -468,52 +458,38 @@ class BuiltinVariable(VariableTracker):
|
||||||
]
|
]
|
||||||
op_handlers[operator.mul].extend(list_like_expansion_handlers)
|
op_handlers[operator.mul].extend(list_like_expansion_handlers)
|
||||||
|
|
||||||
|
size_or_tuple = (SizeVariable, TupleVariable)
|
||||||
|
has_set_items = (SetVariable, DictKeysVariable)
|
||||||
|
|
||||||
def create_cmp_op_handlers(op):
|
def create_cmp_op_handlers(op):
|
||||||
def compare_by_value(tx: "InstructionTranslator", a, b):
|
def compare_by_value(tx: "InstructionTranslator", a, b):
|
||||||
return ConstantVariable(op(a.value, b.value))
|
return ConstantVariable(op(a.value, b.value))
|
||||||
|
|
||||||
if op in polyfill_fn_mapping:
|
|
||||||
# For constants, speedup the comparison instead of using
|
|
||||||
# polyfill. Removing this line causes major regression for pr
|
|
||||||
# time benchmark - add_loop_eager.
|
|
||||||
result = [((ConstantVariable, ConstantVariable), compare_by_value)]
|
result = [((ConstantVariable, ConstantVariable), compare_by_value)]
|
||||||
|
|
||||||
op_var = BuiltinVariable(op)
|
if op in supported_const_comparison_ops.values():
|
||||||
# Special handling of SymNode variable
|
|
||||||
result.extend(
|
|
||||||
[
|
|
||||||
(
|
|
||||||
(SymNodeVariable, VariableTracker),
|
|
||||||
op_var._comparison_with_symnode,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
(VariableTracker, SymNodeVariable),
|
|
||||||
op_var._comparison_with_symnode,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def handler(tx, a, b):
|
|
||||||
return tx.inline_user_function_return(
|
|
||||||
VariableTracker.build(tx, polyfill_fn_mapping[op]), [a, b], {}
|
|
||||||
)
|
|
||||||
|
|
||||||
result.append(((VariableTracker, VariableTracker), handler))
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = [((ConstantVariable, ConstantVariable), compare_by_value)]
|
|
||||||
|
|
||||||
if op in supported_const_comparison_ops.values() and op.__name__.startswith(
|
|
||||||
"is_"
|
|
||||||
):
|
|
||||||
# Tensor is None, List is not None, etc
|
# Tensor is None, List is not None, etc
|
||||||
none_result = op(object(), None)
|
none_result = op(object(), None)
|
||||||
|
if op.__name__.startswith("is_"):
|
||||||
|
|
||||||
def never(tx: "InstructionTranslator", a, b):
|
def never(tx: "InstructionTranslator", a, b):
|
||||||
return ConstantVariable(none_result)
|
return ConstantVariable(none_result)
|
||||||
|
|
||||||
obj_op_none = never
|
obj_op_none = never
|
||||||
none_op_obj = never
|
none_op_obj = never
|
||||||
|
else:
|
||||||
|
|
||||||
|
def obj_op_none(
|
||||||
|
tx: "InstructionTranslator", a, b: ConstantVariable
|
||||||
|
):
|
||||||
|
if b.value is None or b.value is True or b.value is False:
|
||||||
|
return ConstantVariable(none_result)
|
||||||
|
|
||||||
|
def none_op_obj(
|
||||||
|
tx: "InstructionTranslator", a: ConstantVariable, b
|
||||||
|
):
|
||||||
|
if a.value is None or a.value is True or a.value is False:
|
||||||
|
return ConstantVariable(none_result)
|
||||||
|
|
||||||
types_that_are_never_none = (
|
types_that_are_never_none = (
|
||||||
TensorVariable,
|
TensorVariable,
|
||||||
|
|
@ -538,6 +514,27 @@ class BuiltinVariable(VariableTracker):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def list_compare_nocheck(tx: "InstructionTranslator", left, right):
|
||||||
|
return BaseListVariable.list_compare(tx, op, left, right)
|
||||||
|
|
||||||
|
def list_compare_check(tx: "InstructionTranslator", left, right):
|
||||||
|
if type(left) is not type(
|
||||||
|
right
|
||||||
|
): # Mismatch in BaseListVariable subclasses
|
||||||
|
unimplemented(f"{op.__name__}({left}, {right})")
|
||||||
|
return BaseListVariable.list_compare(tx, op, left, right)
|
||||||
|
|
||||||
|
def compare_set_items(tx: "InstructionTranslator", left, right):
|
||||||
|
return ConstantVariable(op(left.set_items, right.set_items))
|
||||||
|
|
||||||
|
def compare_via_method(tx: "InstructionTranslator", left, right):
|
||||||
|
return left.call_method(tx, f"__{op.__name__}__", [right], {})
|
||||||
|
|
||||||
|
if op.__name__.startswith("is_"):
|
||||||
|
compare_user_defined = compare_by_value
|
||||||
|
else:
|
||||||
|
compare_user_defined = compare_via_method
|
||||||
|
|
||||||
op_var = BuiltinVariable(op)
|
op_var = BuiltinVariable(op)
|
||||||
result.extend(
|
result.extend(
|
||||||
[
|
[
|
||||||
|
|
@ -560,13 +557,19 @@ class BuiltinVariable(VariableTracker):
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
((size_or_tuple, size_or_tuple), list_compare_nocheck),
|
||||||
|
(
|
||||||
|
(variables.BaseListVariable, variables.BaseListVariable),
|
||||||
|
list_compare_check,
|
||||||
|
),
|
||||||
|
((has_set_items, has_set_items), compare_set_items),
|
||||||
(
|
(
|
||||||
(UserDefinedObjectVariable, UserDefinedObjectVariable),
|
(UserDefinedObjectVariable, UserDefinedObjectVariable),
|
||||||
compare_by_value,
|
compare_user_defined,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
(UserDefinedClassVariable, UserDefinedClassVariable),
|
(UserDefinedClassVariable, UserDefinedClassVariable),
|
||||||
compare_by_value,
|
compare_user_defined,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
|
|
@ -594,6 +597,8 @@ class BuiltinVariable(VariableTracker):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if op.__name__.startswith("is_"):
|
||||||
|
|
||||||
def handle_is(tx: "InstructionTranslator", left, right):
|
def handle_is(tx: "InstructionTranslator", left, right):
|
||||||
# If the two objects are of different type, we can safely return False
|
# If the two objects are of different type, we can safely return False
|
||||||
# and True for `is` and `is not`, respectively
|
# and True for `is` and `is not`, respectively
|
||||||
|
|
@ -1701,8 +1706,6 @@ class BuiltinVariable(VariableTracker):
|
||||||
member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
|
member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
|
||||||
) and torch._dynamo.trace_rules.is_aten_op_or_tensor_method(member):
|
) and torch._dynamo.trace_rules.is_aten_op_or_tensor_method(member):
|
||||||
return variables.TorchInGraphFunctionVariable(member, source=source)
|
return variables.TorchInGraphFunctionVariable(member, source=source)
|
||||||
elif name in cmp_name_to_op_mapping:
|
|
||||||
return variables.GetAttrVariable(obj, name, source=source)
|
|
||||||
elif isinstance(obj, DummyModule):
|
elif isinstance(obj, DummyModule):
|
||||||
# TODO(mlazos) - Do we need this?
|
# TODO(mlazos) - Do we need this?
|
||||||
if obj.is_torch or name not in obj.value.__dict__:
|
if obj.is_torch or name not in obj.value.__dict__:
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ from torch._dynamo.source import AttrSource, GetItemSource
|
||||||
|
|
||||||
from .. import variables
|
from .. import variables
|
||||||
from ..exc import raise_observed_exception, unimplemented
|
from ..exc import raise_observed_exception, unimplemented
|
||||||
from ..utils import cmp_name_to_op_mapping, common_constant_types, istype, np
|
from ..utils import common_constant_types, istype, np
|
||||||
from .base import VariableTracker
|
from .base import typestr, VariableTracker
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -192,7 +192,8 @@ its type to `common_constant_types`.
|
||||||
search = args[0].as_python_constant()
|
search = args[0].as_python_constant()
|
||||||
result = search in self.value
|
result = search in self.value
|
||||||
return ConstantVariable.create(result)
|
return ConstantVariable.create(result)
|
||||||
return super().call_method(tx, name, args, kwargs)
|
|
||||||
|
unimplemented(f"const method call {typestr(self.value)}.{name}")
|
||||||
|
|
||||||
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||||
result = hasattr(self.value, name)
|
result = hasattr(self.value, name)
|
||||||
|
|
@ -226,8 +227,6 @@ class EnumVariable(VariableTracker):
|
||||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||||
if not hasattr(self.value, name):
|
if not hasattr(self.value, name):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
if name in cmp_name_to_op_mapping:
|
|
||||||
return variables.GetAttrVariable(self, name)
|
|
||||||
member = getattr(self.value, name)
|
member = getattr(self.value, name)
|
||||||
source = self.source and AttrSource(self.source, name)
|
source = self.source and AttrSource(self.source, name)
|
||||||
return VariableTracker.build(tx, member, source=source)
|
return VariableTracker.build(tx, member, source=source)
|
||||||
|
|
|
||||||
|
|
@ -1116,9 +1116,6 @@ class StreamVariable(VariableTracker):
|
||||||
self.value = value
|
self.value = value
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def python_type(self):
|
|
||||||
return torch.Stream
|
|
||||||
|
|
||||||
def call_method(
|
def call_method(
|
||||||
self,
|
self,
|
||||||
tx,
|
tx,
|
||||||
|
|
@ -1127,8 +1124,15 @@ class StreamVariable(VariableTracker):
|
||||||
kwargs: "dict[str, VariableTracker]",
|
kwargs: "dict[str, VariableTracker]",
|
||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
assert hasattr(self.value, name), f"no stream method found named {name}"
|
assert hasattr(self.value, name), f"no stream method found named {name}"
|
||||||
|
assert name in [
|
||||||
|
"wait_stream",
|
||||||
|
"synchronize",
|
||||||
|
"query",
|
||||||
|
"record_event",
|
||||||
|
"wait_event",
|
||||||
|
], f" unsupported stream method {name}"
|
||||||
|
|
||||||
from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
|
from ..utils import proxy_args_kwargs
|
||||||
from .builder import wrap_fx_proxy_cls
|
from .builder import wrap_fx_proxy_cls
|
||||||
|
|
||||||
if name in ("wait_stream", "synchronize", "wait_event"):
|
if name in ("wait_stream", "synchronize", "wait_event"):
|
||||||
|
|
@ -1152,17 +1156,8 @@ class StreamVariable(VariableTracker):
|
||||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
|
else:
|
||||||
# NB : Checking for mutation is necessary because we compare
|
unimplemented(self.device + " stream method " + name + " unsupported")
|
||||||
# constant values
|
|
||||||
other = args[0]
|
|
||||||
if not isinstance(other, StreamVariable):
|
|
||||||
return variables.ConstantVariable.create(NotImplemented)
|
|
||||||
return variables.ConstantVariable.create(
|
|
||||||
cmp_name_to_op_mapping[name](self.value, other.value)
|
|
||||||
)
|
|
||||||
|
|
||||||
return super().call_method(tx, name, args, kwargs)
|
|
||||||
|
|
||||||
def as_proxy(self):
|
def as_proxy(self):
|
||||||
return self.proxy
|
return self.proxy
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from ..bytecode_transformation import create_call_function, create_instruction
|
||||||
from ..exc import raise_observed_exception, unimplemented
|
from ..exc import raise_observed_exception, unimplemented
|
||||||
from ..guards import GuardBuilder, install_guard
|
from ..guards import GuardBuilder, install_guard
|
||||||
from ..source import is_from_local_source
|
from ..source import is_from_local_source
|
||||||
from ..utils import cmp_name_to_op_mapping, dict_keys, dict_values, specialize_symnode
|
from ..utils import dict_keys, dict_values, specialize_symnode
|
||||||
from .base import ValueMutationNew, VariableTracker
|
from .base import ValueMutationNew, VariableTracker
|
||||||
from .constant import ConstantVariable
|
from .constant import ConstantVariable
|
||||||
|
|
||||||
|
|
@ -751,9 +751,7 @@ class DictKeySetVariable(SetVariable):
|
||||||
return dict_keys
|
return dict_keys
|
||||||
|
|
||||||
def as_python_constant(self):
|
def as_python_constant(self):
|
||||||
return dict.fromkeys(
|
unimplemented("DictKeySetVariable.as_python_constant")
|
||||||
{k.vt.as_python_constant() for k in self.set_items}, None
|
|
||||||
).keys()
|
|
||||||
|
|
||||||
def call_method(
|
def call_method(
|
||||||
self,
|
self,
|
||||||
|
|
@ -839,12 +837,6 @@ class DictKeysVariable(DictViewVariable):
|
||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
if name == "__contains__":
|
if name == "__contains__":
|
||||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||||
if name in cmp_name_to_op_mapping:
|
|
||||||
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
|
|
||||||
return ConstantVariable.create(NotImplemented)
|
|
||||||
return ConstantVariable.create(
|
|
||||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
|
||||||
)
|
|
||||||
return super().call_method(tx, name, args, kwargs)
|
return super().call_method(tx, name, args, kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
check_constant_args,
|
check_constant_args,
|
||||||
check_unspec_or_constant_args,
|
check_unspec_or_constant_args,
|
||||||
cmp_name_to_op_mapping,
|
|
||||||
counters,
|
counters,
|
||||||
identity,
|
identity,
|
||||||
is_function,
|
is_function,
|
||||||
|
|
@ -280,8 +279,6 @@ class UserFunctionVariable(BaseUserFunctionVariable):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def var_getattr(self, tx: "InstructionTranslator", name: str):
|
def var_getattr(self, tx: "InstructionTranslator", name: str):
|
||||||
if name in cmp_name_to_op_mapping:
|
|
||||||
return variables.GetAttrVariable(self, name)
|
|
||||||
source = self.source and AttrSource(self.source, name)
|
source = self.source and AttrSource(self.source, name)
|
||||||
try:
|
try:
|
||||||
subobj = inspect.getattr_static(self.fn, name)
|
subobj = inspect.getattr_static(self.fn, name)
|
||||||
|
|
@ -898,9 +895,6 @@ class FunctoolsWrapsVariable(UserFunctionVariable):
|
||||||
|
|
||||||
|
|
||||||
class CollectionsNamedTupleFunction(UserFunctionVariable):
|
class CollectionsNamedTupleFunction(UserFunctionVariable):
|
||||||
def as_python_constant(self):
|
|
||||||
return self.fn
|
|
||||||
|
|
||||||
def call_function(
|
def call_function(
|
||||||
self,
|
self,
|
||||||
tx: "InstructionTranslator",
|
tx: "InstructionTranslator",
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from ..bytecode_transformation import create_call_function, create_instruction
|
||||||
from ..exc import raise_observed_exception, unimplemented
|
from ..exc import raise_observed_exception, unimplemented
|
||||||
from ..source import AttrSource
|
from ..source import AttrSource
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
cmp_name_to_op_mapping,
|
|
||||||
get_fake_value,
|
get_fake_value,
|
||||||
guard_if_dyn,
|
guard_if_dyn,
|
||||||
istype,
|
istype,
|
||||||
|
|
@ -137,18 +136,6 @@ class BaseListVariable(VariableTracker):
|
||||||
[self] + list(args),
|
[self] + list(args),
|
||||||
kwargs,
|
kwargs,
|
||||||
)
|
)
|
||||||
elif name in cmp_name_to_op_mapping:
|
|
||||||
left = self
|
|
||||||
right = args[0]
|
|
||||||
if not isinstance(left, BaseListVariable) and not isinstance(
|
|
||||||
right, BaseListVariable
|
|
||||||
):
|
|
||||||
return variables.ConstantVariable.create(NotImplemented)
|
|
||||||
return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
|
|
||||||
tx,
|
|
||||||
[variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right],
|
|
||||||
{},
|
|
||||||
)
|
|
||||||
|
|
||||||
return super().call_method(tx, name, args, kwargs)
|
return super().call_method(tx, name, args, kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,6 @@ from ..utils import (
|
||||||
build_checkpoint_variable,
|
build_checkpoint_variable,
|
||||||
build_invoke_subgraph_variable,
|
build_invoke_subgraph_variable,
|
||||||
check_constant_args,
|
check_constant_args,
|
||||||
cmp_name_to_op_mapping,
|
|
||||||
dict_methods,
|
dict_methods,
|
||||||
get_custom_getattr,
|
get_custom_getattr,
|
||||||
has_torch_function,
|
has_torch_function,
|
||||||
|
|
@ -184,9 +183,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
obj = None
|
obj = None
|
||||||
|
|
||||||
if name in cmp_name_to_op_mapping and not isinstance(obj, types.FunctionType):
|
|
||||||
return variables.GetAttrVariable(self, name, source=source)
|
|
||||||
|
|
||||||
if isinstance(obj, staticmethod):
|
if isinstance(obj, staticmethod):
|
||||||
return VariableTracker.build(tx, obj.__get__(self.value), source)
|
return VariableTracker.build(tx, obj.__get__(self.value), source)
|
||||||
elif isinstance(obj, classmethod):
|
elif isinstance(obj, classmethod):
|
||||||
|
|
@ -791,14 +787,14 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||||
if is_standard_setattr(method) or isinstance(self.value, threading.local):
|
if is_standard_setattr(method) or isinstance(self.value, threading.local):
|
||||||
return self.method_setattr_standard(tx, *args, **kwargs)
|
return self.method_setattr_standard(tx, *args, **kwargs)
|
||||||
|
|
||||||
if method is object.__eq__ and len(args) == 1 and not kwargs:
|
if len(args) == 1 and not kwargs:
|
||||||
other = args[0]
|
if method is object.__eq__:
|
||||||
if not isinstance(other, UserDefinedObjectVariable):
|
func_var = VariableTracker.build(tx, polyfills.object_eq)
|
||||||
return variables.ConstantVariable.create(NotImplemented)
|
return func_var.call_function(tx, [self, *args], kwargs)
|
||||||
|
|
||||||
# TODO(anijain2305) - Identity checking should already be a part
|
if method is object.__ne__:
|
||||||
# of the cmp_eq polyfill function.
|
func_var = VariableTracker.build(tx, polyfills.object_ne)
|
||||||
return ConstantVariable.create(self.value is other.value)
|
return func_var.call_function(tx, [self, *args], kwargs)
|
||||||
|
|
||||||
# check for methods implemented in C++
|
# check for methods implemented in C++
|
||||||
if isinstance(method, types.FunctionType):
|
if isinstance(method, types.FunctionType):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user