[dynamo] Replace unimplemented with unimplemented_v2 in torch/_dynamo/variables/script_object.py (#159343)

Fixes part of #147913

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159343
Approved by: https://github.com/williamwen42

Co-authored-by: William Wen <william.wen42@gmail.com>
This commit is contained in:
zeshengzong 2025-08-01 21:30:37 +00:00 committed by PyTorch MergeBot
parent 8c6c2e40eb
commit 595a65f5c2
3 changed files with 62 additions and 10 deletions

View File

@ -1330,7 +1330,7 @@ class TestCompileTorchbind(TestCase):
return tq
with self.assertRaisesRegex(
RuntimeError, "call method __setattr__ on script object is not safe"
RuntimeError, "Weird method call on TorchScript object"
):
torch.compile(setattr_f, backend=backend)(_empty_tensor_queue())
@ -1343,7 +1343,7 @@ class TestCompileTorchbind(TestCase):
return tq._not_defined_attr
with self.assertRaisesRegex(
RuntimeError, "doesn't define method _not_defined_attr"
RuntimeError, "FakeScriptObject missing method implementation"
):
torch.compile(setattr_f, backend=backend)(_empty_tensor_queue())
@ -1434,7 +1434,7 @@ def forward(self, token, obj, x):
x = torch.randn(2, 3)
with self.assertRaisesRegex(
RuntimeError, "FakeScriptObject doesn't define method"
RuntimeError, "FakeScriptObject missing method implementation"
):
torch.compile(f, backend=backend)(_empty_tensor_queue(), x)

View File

@ -2652,5 +2652,36 @@
"Try to construct nn.Parameter() outside the compiled region. If this is not possible, turn `graph_break_on_nn_param_ctor` off"
]
}
],
"GB0265": [
{
"Gb_type": "FakeScriptObject missing method implementation",
"Context": "value={self.value}, method={name}",
"Explanation": "TorchScript object {self.value} doesn't define the method {name}.",
"Hints": [
"Ensure the method {name} is implemented in {self.value}.",
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0266": [
{
"Gb_type": "Weird method call on TorchScript object",
"Context": "value={self.value}, method={name}",
"Explanation": "This particular method call ({name}) is not supported (e.g. calling `__setattr__`). Most method calls to TorchScript objects should be supported.",
"Hints": [
"Avoid calling this method."
]
}
],
"GB0267": [
{
"Gb_type": "Attempted to access non-callable attribute of TorchScript object",
"Context": "value={self.value}, method={name}",
"Explanation": "Attribute accesses of TorchScript objects to non-callable attributes are not supported.",
"Hints": [
"Use method calls instead of attribute access."
]
}
]
}

View File

@ -25,7 +25,8 @@ import functools
import torch
from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
from .. import graph_break_hints
from ..exc import unimplemented_v2, UnsafeScriptObjectError, Unsupported
from .base import VariableTracker
from .user_defined import UserDefinedObjectVariable
@ -75,14 +76,24 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
method = getattr(self.value, name, None)
if method is None:
unimplemented(
f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?"
unimplemented_v2(
gb_type="FakeScriptObject missing method implementation",
context=f"value={self.value}, method={name}",
explanation=f"TorchScript object {self.value} doesn't define the method {name}.",
hints=[
f"Ensure the method {name} is implemented in {self.value}.",
*graph_break_hints.USER_ERROR,
],
)
if not callable(method):
unimplemented(
"Only method calls on TorchScript objects can be supported safely."
" Please use method calls instead of attribute access."
unimplemented_v2(
gb_type="Attempted to access non-callable attribute of TorchScript object",
context=f"value={self.value}, method={name}",
explanation="Attribute accesses of TorchScript objects to non-callable attributes are not supported.",
hints=[
"Use method calls instead of attribute access.",
],
)
return TorchHigherOrderOperatorVariable.make(
@ -100,4 +111,14 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
"Dynamo cannot safely trace script object due to graph break."
)
def call_method(self, tx, name, args, kwargs):
unimplemented(f"call method {name} on script object is not safe.")
unimplemented_v2(
gb_type="Weird method call on TorchScript object",
context=f"value={self.value}, method={name}",
explanation=(
f"This particular method call ({name}) is not supported (e.g. calling `__setattr__`). "
"Most method calls to TorchScript objects should be supported."
),
hints=[
"Avoid calling this method.",
],
)