mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Set] Handle exception in ConstantVariable operation (#152987)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152987 Approved by: https://github.com/williamwen42, https://github.com/anijain2305 ghstack dependencies: #150792
This commit is contained in:
parent
477f13c3fb
commit
cf7021a0ee
|
|
@ -1669,6 +1669,18 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
return a + b
|
||||
return a - b
|
||||
|
||||
@make_test
|
||||
def test_set_invalid_ConstantVariable_op(a, b):
|
||||
s = set({"banana", "apple", "orange"})
|
||||
try:
|
||||
s - 1
|
||||
except TypeError:
|
||||
return a + b
|
||||
except Exception:
|
||||
return a - b
|
||||
else:
|
||||
return a * b
|
||||
|
||||
@make_test
|
||||
def test_set_update_bytecode(x):
|
||||
# This produces bytecode SET_UPDATE since python 3.9
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING
|
|||
|
||||
from .. import graph_break_hints, variables
|
||||
from ..current_scope_id import current_scope_id
|
||||
from ..exc import unimplemented_v2
|
||||
from ..exc import raise_observed_exception, unimplemented_v2
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import AttrSource, Source
|
||||
from ..utils import cmp_name_to_op_mapping, istype
|
||||
|
|
@ -515,11 +515,18 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
|||
hints=[],
|
||||
)
|
||||
|
||||
try:
|
||||
return variables.ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](
|
||||
self.as_python_constant(), other.as_python_constant()
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
raise_observed_exception(
|
||||
type(e),
|
||||
tx,
|
||||
args=[list(map(variables.ConstantVariable.create, e.args))],
|
||||
)
|
||||
hints = [
|
||||
f"Avoid calling `{self.python_type_name()}.{name}` in your code.",
|
||||
"Please report an issue to PyTorch.",
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ from ..utils import (
|
|||
str_methods,
|
||||
tensortype_to_dtype,
|
||||
)
|
||||
from .base import ValueMutationNew, VariableTracker
|
||||
from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import EventVariable, StreamVariable
|
||||
from .dicts import (
|
||||
|
|
@ -901,6 +901,12 @@ class BuiltinVariable(VariableTracker):
|
|||
*[x.as_python_constant() for x in args],
|
||||
)
|
||||
except Exception as exc:
|
||||
raise_observed_exception(
|
||||
type(exc),
|
||||
tx,
|
||||
args=list(map(ConstantVariable.create, exc.args)),
|
||||
)
|
||||
except AsPythonConstantNotImplementedError as exc:
|
||||
unimplemented_v2(
|
||||
gb_type="constant fold exception",
|
||||
context=f"attempted to run function {fn} with arguments {args}",
|
||||
|
|
@ -922,14 +928,20 @@ class BuiltinVariable(VariableTracker):
|
|||
k: v.as_python_constant() for k, v in kwargs.items()
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
except AsPythonConstantNotImplementedError as exc:
|
||||
unimplemented_v2(
|
||||
gb_type="constant fold exception",
|
||||
context=f"attempted to run function {fn} with arguments {args} {kwargs}",
|
||||
context=f"attempted to run function {fn} with arguments {args}",
|
||||
explanation="Encountered exception when attempting to constant fold.",
|
||||
hints=[*graph_break_hints.DYNAMO_BUG],
|
||||
from_exc=exc,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise_observed_exception(
|
||||
type(exc),
|
||||
tx,
|
||||
args=list(map(ConstantVariable.create, exc.args)),
|
||||
)
|
||||
return VariableTracker.build(tx, res)
|
||||
|
||||
handlers.append(constant_fold_handler)
|
||||
|
|
|
|||
|
|
@ -190,7 +190,12 @@ its type to `common_constant_types`.
|
|||
)
|
||||
return SymNodeVariable.create(tx, proxy, add_target)
|
||||
else:
|
||||
try:
|
||||
return ConstantVariable.create(op(self.value, add_target))
|
||||
except Exception as e:
|
||||
raise_observed_exception(
|
||||
type(e), tx, args=list(map(ConstantVariable.create, e.args))
|
||||
)
|
||||
elif isinstance(self.value, bytes) and name == "decode":
|
||||
method = getattr(self.value, name)
|
||||
return ConstantVariable.create(method(*const_args, **const_kwargs))
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ from torch.utils import _pytree as pytree
|
|||
from .. import graph_break_hints, variables
|
||||
from ..exc import (
|
||||
IncorrectUsage,
|
||||
ObservedException,
|
||||
UncapturedHigherOrderOpError,
|
||||
unimplemented,
|
||||
unimplemented_v2,
|
||||
|
|
@ -72,7 +73,7 @@ def raise_hard_error_if_graph_break(reason):
|
|||
def graph_break_as_hard_error(*args, **kwargs):
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except Unsupported as e:
|
||||
except (Unsupported, ObservedException) as e:
|
||||
msg = " Scroll up to find out what causes the graph break."
|
||||
raise UncapturedHigherOrderOpError(reason + msg) from e
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user