mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Set] Handle exception in ConstantVariable operation (#152987)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152987 Approved by: https://github.com/williamwen42, https://github.com/anijain2305 ghstack dependencies: #150792
This commit is contained in:
parent
477f13c3fb
commit
cf7021a0ee
|
|
@ -1669,6 +1669,18 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||||
return a + b
|
return a + b
|
||||||
return a - b
|
return a - b
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_set_invalid_ConstantVariable_op(a, b):
|
||||||
|
s = set({"banana", "apple", "orange"})
|
||||||
|
try:
|
||||||
|
s - 1
|
||||||
|
except TypeError:
|
||||||
|
return a + b
|
||||||
|
except Exception:
|
||||||
|
return a - b
|
||||||
|
else:
|
||||||
|
return a * b
|
||||||
|
|
||||||
@make_test
|
@make_test
|
||||||
def test_set_update_bytecode(x):
|
def test_set_update_bytecode(x):
|
||||||
# This produces bytecode SET_UPDATE since python 3.9
|
# This produces bytecode SET_UPDATE since python 3.9
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from .. import graph_break_hints, variables
|
from .. import graph_break_hints, variables
|
||||||
from ..current_scope_id import current_scope_id
|
from ..current_scope_id import current_scope_id
|
||||||
from ..exc import unimplemented_v2
|
from ..exc import raise_observed_exception, unimplemented_v2
|
||||||
from ..guards import GuardBuilder, install_guard
|
from ..guards import GuardBuilder, install_guard
|
||||||
from ..source import AttrSource, Source
|
from ..source import AttrSource, Source
|
||||||
from ..utils import cmp_name_to_op_mapping, istype
|
from ..utils import cmp_name_to_op_mapping, istype
|
||||||
|
|
@ -515,11 +515,18 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||||
hints=[],
|
hints=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
return variables.ConstantVariable.create(
|
return variables.ConstantVariable.create(
|
||||||
cmp_name_to_op_mapping[name](
|
cmp_name_to_op_mapping[name](
|
||||||
self.as_python_constant(), other.as_python_constant()
|
self.as_python_constant(), other.as_python_constant()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise_observed_exception(
|
||||||
|
type(e),
|
||||||
|
tx,
|
||||||
|
args=[list(map(variables.ConstantVariable.create, e.args))],
|
||||||
|
)
|
||||||
hints = [
|
hints = [
|
||||||
f"Avoid calling `{self.python_type_name()}.{name}` in your code.",
|
f"Avoid calling `{self.python_type_name()}.{name}` in your code.",
|
||||||
"Please report an issue to PyTorch.",
|
"Please report an issue to PyTorch.",
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ from ..utils import (
|
||||||
str_methods,
|
str_methods,
|
||||||
tensortype_to_dtype,
|
tensortype_to_dtype,
|
||||||
)
|
)
|
||||||
from .base import ValueMutationNew, VariableTracker
|
from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
|
||||||
from .constant import ConstantVariable
|
from .constant import ConstantVariable
|
||||||
from .ctx_manager import EventVariable, StreamVariable
|
from .ctx_manager import EventVariable, StreamVariable
|
||||||
from .dicts import (
|
from .dicts import (
|
||||||
|
|
@ -901,6 +901,12 @@ class BuiltinVariable(VariableTracker):
|
||||||
*[x.as_python_constant() for x in args],
|
*[x.as_python_constant() for x in args],
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
raise_observed_exception(
|
||||||
|
type(exc),
|
||||||
|
tx,
|
||||||
|
args=list(map(ConstantVariable.create, exc.args)),
|
||||||
|
)
|
||||||
|
except AsPythonConstantNotImplementedError as exc:
|
||||||
unimplemented_v2(
|
unimplemented_v2(
|
||||||
gb_type="constant fold exception",
|
gb_type="constant fold exception",
|
||||||
context=f"attempted to run function {fn} with arguments {args}",
|
context=f"attempted to run function {fn} with arguments {args}",
|
||||||
|
|
@ -922,14 +928,20 @@ class BuiltinVariable(VariableTracker):
|
||||||
k: v.as_python_constant() for k, v in kwargs.items()
|
k: v.as_python_constant() for k, v in kwargs.items()
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except AsPythonConstantNotImplementedError as exc:
|
||||||
unimplemented_v2(
|
unimplemented_v2(
|
||||||
gb_type="constant fold exception",
|
gb_type="constant fold exception",
|
||||||
context=f"attempted to run function {fn} with arguments {args} {kwargs}",
|
context=f"attempted to run function {fn} with arguments {args}",
|
||||||
explanation="Encountered exception when attempting to constant fold.",
|
explanation="Encountered exception when attempting to constant fold.",
|
||||||
hints=[*graph_break_hints.DYNAMO_BUG],
|
hints=[*graph_break_hints.DYNAMO_BUG],
|
||||||
from_exc=exc,
|
from_exc=exc,
|
||||||
)
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise_observed_exception(
|
||||||
|
type(exc),
|
||||||
|
tx,
|
||||||
|
args=list(map(ConstantVariable.create, exc.args)),
|
||||||
|
)
|
||||||
return VariableTracker.build(tx, res)
|
return VariableTracker.build(tx, res)
|
||||||
|
|
||||||
handlers.append(constant_fold_handler)
|
handlers.append(constant_fold_handler)
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,12 @@ its type to `common_constant_types`.
|
||||||
)
|
)
|
||||||
return SymNodeVariable.create(tx, proxy, add_target)
|
return SymNodeVariable.create(tx, proxy, add_target)
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
return ConstantVariable.create(op(self.value, add_target))
|
return ConstantVariable.create(op(self.value, add_target))
|
||||||
|
except Exception as e:
|
||||||
|
raise_observed_exception(
|
||||||
|
type(e), tx, args=list(map(ConstantVariable.create, e.args))
|
||||||
|
)
|
||||||
elif isinstance(self.value, bytes) and name == "decode":
|
elif isinstance(self.value, bytes) and name == "decode":
|
||||||
method = getattr(self.value, name)
|
method = getattr(self.value, name)
|
||||||
return ConstantVariable.create(method(*const_args, **const_kwargs))
|
return ConstantVariable.create(method(*const_args, **const_kwargs))
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ from torch.utils import _pytree as pytree
|
||||||
from .. import graph_break_hints, variables
|
from .. import graph_break_hints, variables
|
||||||
from ..exc import (
|
from ..exc import (
|
||||||
IncorrectUsage,
|
IncorrectUsage,
|
||||||
|
ObservedException,
|
||||||
UncapturedHigherOrderOpError,
|
UncapturedHigherOrderOpError,
|
||||||
unimplemented,
|
unimplemented,
|
||||||
unimplemented_v2,
|
unimplemented_v2,
|
||||||
|
|
@ -72,7 +73,7 @@ def raise_hard_error_if_graph_break(reason):
|
||||||
def graph_break_as_hard_error(*args, **kwargs):
|
def graph_break_as_hard_error(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
except Unsupported as e:
|
except (Unsupported, ObservedException) as e:
|
||||||
msg = " Scroll up to find out what causes the graph break."
|
msg = " Scroll up to find out what causes the graph break."
|
||||||
raise UncapturedHigherOrderOpError(reason + msg) from e
|
raise UncapturedHigherOrderOpError(reason + msg) from e
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user