[Set] Handle exception in ConstantVariable operation (#152987)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152987
Approved by: https://github.com/williamwen42, https://github.com/anijain2305
ghstack dependencies: #150792
This commit is contained in:
Guilherme Leobas 2025-05-15 22:47:14 -03:00 committed by PyTorch MergeBot
parent 477f13c3fb
commit cf7021a0ee
31 changed files with 47 additions and 10 deletions

View File

@ -1669,6 +1669,18 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
return a + b return a + b
return a - b return a - b
@make_test
def test_set_invalid_ConstantVariable_op(a, b):
s = set({"banana", "apple", "orange"})
try:
s - 1
except TypeError:
return a + b
except Exception:
return a - b
else:
return a * b
@make_test @make_test
def test_set_update_bytecode(x): def test_set_update_bytecode(x):
# This produces bytecode SET_UPDATE since python 3.9 # This produces bytecode SET_UPDATE since python 3.9

View File

@ -22,7 +22,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING
from .. import graph_break_hints, variables from .. import graph_break_hints, variables
from ..current_scope_id import current_scope_id from ..current_scope_id import current_scope_id
from ..exc import unimplemented_v2 from ..exc import raise_observed_exception, unimplemented_v2
from ..guards import GuardBuilder, install_guard from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, Source from ..source import AttrSource, Source
from ..utils import cmp_name_to_op_mapping, istype from ..utils import cmp_name_to_op_mapping, istype
@ -515,11 +515,18 @@ class VariableTracker(metaclass=VariableTrackerMeta):
hints=[], hints=[],
) )
try:
return variables.ConstantVariable.create( return variables.ConstantVariable.create(
cmp_name_to_op_mapping[name]( cmp_name_to_op_mapping[name](
self.as_python_constant(), other.as_python_constant() self.as_python_constant(), other.as_python_constant()
) )
) )
except Exception as e:
raise_observed_exception(
type(e),
tx,
args=[list(map(variables.ConstantVariable.create, e.args))],
)
hints = [ hints = [
f"Avoid calling `{self.python_type_name()}.{name}` in your code.", f"Avoid calling `{self.python_type_name()}.{name}` in your code.",
"Please report an issue to PyTorch.", "Please report an issue to PyTorch.",

View File

@ -57,7 +57,7 @@ from ..utils import (
str_methods, str_methods,
tensortype_to_dtype, tensortype_to_dtype,
) )
from .base import ValueMutationNew, VariableTracker from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
from .constant import ConstantVariable from .constant import ConstantVariable
from .ctx_manager import EventVariable, StreamVariable from .ctx_manager import EventVariable, StreamVariable
from .dicts import ( from .dicts import (
@ -901,6 +901,12 @@ class BuiltinVariable(VariableTracker):
*[x.as_python_constant() for x in args], *[x.as_python_constant() for x in args],
) )
except Exception as exc: except Exception as exc:
raise_observed_exception(
type(exc),
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
except AsPythonConstantNotImplementedError as exc:
unimplemented_v2( unimplemented_v2(
gb_type="constant fold exception", gb_type="constant fold exception",
context=f"attempted to run function {fn} with arguments {args}", context=f"attempted to run function {fn} with arguments {args}",
@ -922,14 +928,20 @@ class BuiltinVariable(VariableTracker):
k: v.as_python_constant() for k, v in kwargs.items() k: v.as_python_constant() for k, v in kwargs.items()
}, },
) )
except Exception as exc: except AsPythonConstantNotImplementedError as exc:
unimplemented_v2( unimplemented_v2(
gb_type="constant fold exception", gb_type="constant fold exception",
context=f"attempted to run function {fn} with arguments {args} {kwargs}", context=f"attempted to run function {fn} with arguments {args}",
explanation="Encountered exception when attempting to constant fold.", explanation="Encountered exception when attempting to constant fold.",
hints=[*graph_break_hints.DYNAMO_BUG], hints=[*graph_break_hints.DYNAMO_BUG],
from_exc=exc, from_exc=exc,
) )
except Exception as exc:
raise_observed_exception(
type(exc),
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
return VariableTracker.build(tx, res) return VariableTracker.build(tx, res)
handlers.append(constant_fold_handler) handlers.append(constant_fold_handler)

View File

@ -190,7 +190,12 @@ its type to `common_constant_types`.
) )
return SymNodeVariable.create(tx, proxy, add_target) return SymNodeVariable.create(tx, proxy, add_target)
else: else:
try:
return ConstantVariable.create(op(self.value, add_target)) return ConstantVariable.create(op(self.value, add_target))
except Exception as e:
raise_observed_exception(
type(e), tx, args=list(map(ConstantVariable.create, e.args))
)
elif isinstance(self.value, bytes) and name == "decode": elif isinstance(self.value, bytes) and name == "decode":
method = getattr(self.value, name) method = getattr(self.value, name)
return ConstantVariable.create(method(*const_args, **const_kwargs)) return ConstantVariable.create(method(*const_args, **const_kwargs))

View File

@ -46,6 +46,7 @@ from torch.utils import _pytree as pytree
from .. import graph_break_hints, variables from .. import graph_break_hints, variables
from ..exc import ( from ..exc import (
IncorrectUsage, IncorrectUsage,
ObservedException,
UncapturedHigherOrderOpError, UncapturedHigherOrderOpError,
unimplemented, unimplemented,
unimplemented_v2, unimplemented_v2,
@ -72,7 +73,7 @@ def raise_hard_error_if_graph_break(reason):
def graph_break_as_hard_error(*args, **kwargs): def graph_break_as_hard_error(*args, **kwargs):
try: try:
return fn(*args, **kwargs) return fn(*args, **kwargs)
except Unsupported as e: except (Unsupported, ObservedException) as e:
msg = " Scroll up to find out what causes the graph break." msg = " Scroll up to find out what causes the graph break."
raise UncapturedHigherOrderOpError(reason + msg) from e raise UncapturedHigherOrderOpError(reason + msg) from e