[dynamo][vllm] Support typing.get_type_hints (#161362)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161362
Approved by: https://github.com/Skylion007, https://github.com/StrongerXi, https://github.com/jansel
This commit is contained in:
Animesh Jain 2025-08-26 22:59:04 -07:00 committed by PyTorch MergeBot
parent 9a12bab0d3
commit 3d406429b0
11 changed files with 125 additions and 1 deletions

View File

@ -931,6 +931,25 @@ class DictTests(torch._dynamo.test_case.TestCase):
self.assertEqual(["b", "c", "a"], list(opt_fn(x).keys()))
self.assertEqual(fn(x), opt_fn(x))
def test_mapping_proxy_ban_muation_on_dict_realization(self):
def fn(x):
class Foo:
b = 4
d = dict(Foo.__dict__)
y = torch.sin(x) * d["b"]
# This should cause a graph break, because otherwise the
# Foo.__dict__ will not be updated.
Foo.bar = 3
return Foo, y * Foo.__dict__["bar"]
opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(4)
foo1, ref = fn(x)
foo2, res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(foo1.bar, foo2.bar)
def test_overridden_get_item(self):
class MyDict(dict):
def __init__(self, *args, **kwargs):

View File

@ -136,6 +136,20 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x)
self.assertEqual(ref, res)
def test_exception_with_vars(self):
def fn(x):
try:
vars(42)
raise RuntimeError("Should not be raised")
except TypeError:
return x.sin()
x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_autocast_with_exception(self):
class Optimizer(torch.autograd.Function):
@staticmethod

View File

@ -4186,6 +4186,21 @@ class ReproTests(torch._dynamo.test_case.TestCase):
torch.compile(fn, backend=counter)(torch.randn([2, 2]), [])
self.assertEqual(counter.frame_count, 1)
def test_get_type_hints(self):
class Foo:
pass
def fn(x):
typing.get_type_hints(Foo, include_extras=True)
return torch.sin(x)
x = torch.randn(4)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_graph_break_on_jit_isinstance(self):
@torch.compile(backend="eager")
def fn(x):

View File

@ -2700,5 +2700,23 @@
"Set torch._dynamo.config.debug_force_graph_break_on_leaf_return = False"
]
}
],
"GB0270": [
{
"Gb_type": "unimplemented builtin op vars() with no arguments",
"Context": "vars: {self} {args}",
"Explanation": "Dynamo does not know how to trace builtin operator {self.fn} with no arguments",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0271": [
{
"Gb_type": "Class attribute mutation when the __dict__ was already materialized",
"Context": "str(self.value)",
"Explanation": "Dyanmo does not support tracing mutations on a class when its __dict__ is materialized",
"Hints": []
}
]
}

View File

@ -1158,6 +1158,21 @@ class BuiltinVariable(VariableTracker):
return builtin_dispatch
def call_vars(self, tx: "InstructionTranslator", *args):
if len(args) == 0:
unimplemented_v2(
gb_type="unimplemented builtin op vars() with no arguments",
context=f"vars: {self} {args}",
explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with no arguments",
hints=[*graph_break_hints.SUPPORTABLE],
)
assert len(args) == 1
# vars(obj) is obj.__dict__ if __dict__ is present else TypeError
try:
return args[0].var_getattr(tx, "__dict__")
except ObservedAttributeError:
raise_observed_exception(TypeError, tx)
def _handle_insert_op_in_graph(self, tx: "InstructionTranslator", args, kwargs):
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
@ -1881,6 +1896,17 @@ class BuiltinVariable(VariableTracker):
@staticmethod
def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
args = list(args)
if (
len(args) == 1
and isinstance(args[0], variables.GetAttrVariable)
and isinstance(args[0].obj, variables.UserDefinedClassVariable)
and not tx.output.side_effects.has_pending_mutation(args[0].obj)
):
# Forward the GetAttrVariable(foo, "__dict__") to a realized vt of
# VT(foo.__dict__). This simplifies the construction of the new
# dict.
args[0] = args[0].get_forwarded_dict(tx)
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.construct_dict),
[VariableTracker.build(tx, user_cls), *args],
@ -2173,6 +2199,18 @@ class BuiltinVariable(VariableTracker):
seq = seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
return variables.FilterVariable(fn, seq, mutation_type=ValueMutationNew())
def var_getattr(self, tx: "InstructionTranslator", name):
source = self.source and AttrSource(self.source, name)
if self.fn is object:
# for object, we can just directly read the attribute
try:
value = getattr(self.fn, name)
except AttributeError:
raise_observed_exception(AttributeError, tx)
if not callable(value):
return VariableTracker.build(tx, value, source)
return variables.GetAttrVariable(self, name, source=source)
def call_getattr(
self,
tx: "InstructionTranslator",

View File

@ -125,7 +125,7 @@ its type to `common_constant_types`.
def const_getattr(self, tx: "InstructionTranslator", name):
if not hasattr(self.value, name):
raise NotImplementedError
raise_observed_exception(AttributeError, tx, args=[name])
member = getattr(self.value, name)
if callable(member):
raise NotImplementedError

View File

@ -1182,6 +1182,15 @@ class GetAttrVariable(VariableTracker):
return super().call_method(tx, name, args, kwargs)
def get_forwarded_dict(self, tx):
assert (
self.name == "__dict__"
and isinstance(self.obj, variables.UserDefinedClassVariable)
and not tx.output.side_effects.has_pending_mutation(self.obj)
)
self.obj.ban_mutation = True
return VariableTracker.build(tx, self.obj.value.__dict__, self.source)
class MethodWrapperVariable(VariableTracker):
def __init__(self, method_wrapper, **kwargs) -> None:

View File

@ -162,6 +162,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
# Used when we materialize class.__dict__ to a MappingProxyObject. In
# this case, we don't want to allow mutation in the class because there
# is no way to reflect it in the created MappingProxyVariable.
self.ban_mutation = False
def as_python_constant(self):
return self.value
@ -449,6 +453,13 @@ class UserDefinedClassVariable(UserDefinedVariable):
args[0],
args[1:],
)
elif name == "__setattr__" and self.ban_mutation:
unimplemented_v2(
gb_type="Class attribute mutation when the __dict__ was already materialized",
context=str(self.value),
explanation="Dyanmo does not support tracing mutations on a class when its __dict__ is materialized",
hints=graph_break_hints.SUPPORTABLE,
)
return super().call_method(tx, name, args, kwargs)
def call_function(