mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][vllm] Support typing.get_type_hints (#161362)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161362 Approved by: https://github.com/Skylion007, https://github.com/StrongerXi, https://github.com/jansel
This commit is contained in:
parent
9a12bab0d3
commit
3d406429b0
|
|
@ -931,6 +931,25 @@ class DictTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(["b", "c", "a"], list(opt_fn(x).keys()))
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_mapping_proxy_ban_muation_on_dict_realization(self):
|
||||
def fn(x):
|
||||
class Foo:
|
||||
b = 4
|
||||
|
||||
d = dict(Foo.__dict__)
|
||||
y = torch.sin(x) * d["b"]
|
||||
# This should cause a graph break, because otherwise the
|
||||
# Foo.__dict__ will not be updated.
|
||||
Foo.bar = 3
|
||||
return Foo, y * Foo.__dict__["bar"]
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
x = torch.randn(4)
|
||||
foo1, ref = fn(x)
|
||||
foo2, res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(foo1.bar, foo2.bar)
|
||||
|
||||
def test_overridden_get_item(self):
|
||||
class MyDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -136,6 +136,20 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_exception_with_vars(self):
|
||||
def fn(x):
|
||||
try:
|
||||
vars(42)
|
||||
raise RuntimeError("Should not be raised")
|
||||
except TypeError:
|
||||
return x.sin()
|
||||
|
||||
x = torch.randn(4)
|
||||
ref = fn(x)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_autocast_with_exception(self):
|
||||
class Optimizer(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -4186,6 +4186,21 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
torch.compile(fn, backend=counter)(torch.randn([2, 2]), [])
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
|
||||
def test_get_type_hints(self):
|
||||
class Foo:
|
||||
pass
|
||||
|
||||
def fn(x):
|
||||
typing.get_type_hints(Foo, include_extras=True)
|
||||
return torch.sin(x)
|
||||
|
||||
x = torch.randn(4)
|
||||
ref = fn(x)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_graph_break_on_jit_isinstance(self):
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
|
|
|
|||
|
|
@ -2700,5 +2700,23 @@
|
|||
"Set torch._dynamo.config.debug_force_graph_break_on_leaf_return = False"
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0270": [
|
||||
{
|
||||
"Gb_type": "unimplemented builtin op vars() with no arguments",
|
||||
"Context": "vars: {self} {args}",
|
||||
"Explanation": "Dynamo does not know how to trace builtin operator {self.fn} with no arguments",
|
||||
"Hints": [
|
||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0271": [
|
||||
{
|
||||
"Gb_type": "Class attribute mutation when the __dict__ was already materialized",
|
||||
"Context": "str(self.value)",
|
||||
"Explanation": "Dyanmo does not support tracing mutations on a class when its __dict__ is materialized",
|
||||
"Hints": []
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1158,6 +1158,21 @@ class BuiltinVariable(VariableTracker):
|
|||
|
||||
return builtin_dispatch
|
||||
|
||||
def call_vars(self, tx: "InstructionTranslator", *args):
|
||||
if len(args) == 0:
|
||||
unimplemented_v2(
|
||||
gb_type="unimplemented builtin op vars() with no arguments",
|
||||
context=f"vars: {self} {args}",
|
||||
explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with no arguments",
|
||||
hints=[*graph_break_hints.SUPPORTABLE],
|
||||
)
|
||||
assert len(args) == 1
|
||||
# vars(obj) is obj.__dict__ if __dict__ is present else TypeError
|
||||
try:
|
||||
return args[0].var_getattr(tx, "__dict__")
|
||||
except ObservedAttributeError:
|
||||
raise_observed_exception(TypeError, tx)
|
||||
|
||||
def _handle_insert_op_in_graph(self, tx: "InstructionTranslator", args, kwargs):
|
||||
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
|
||||
|
||||
|
|
@ -1881,6 +1896,17 @@ class BuiltinVariable(VariableTracker):
|
|||
|
||||
@staticmethod
|
||||
def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
|
||||
args = list(args)
|
||||
if (
|
||||
len(args) == 1
|
||||
and isinstance(args[0], variables.GetAttrVariable)
|
||||
and isinstance(args[0].obj, variables.UserDefinedClassVariable)
|
||||
and not tx.output.side_effects.has_pending_mutation(args[0].obj)
|
||||
):
|
||||
# Forward the GetAttrVariable(foo, "__dict__") to a realized vt of
|
||||
# VT(foo.__dict__). This simplifies the construction of the new
|
||||
# dict.
|
||||
args[0] = args[0].get_forwarded_dict(tx)
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.construct_dict),
|
||||
[VariableTracker.build(tx, user_cls), *args],
|
||||
|
|
@ -2173,6 +2199,18 @@ class BuiltinVariable(VariableTracker):
|
|||
seq = seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
|
||||
return variables.FilterVariable(fn, seq, mutation_type=ValueMutationNew())
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
source = self.source and AttrSource(self.source, name)
|
||||
if self.fn is object:
|
||||
# for object, we can just directly read the attribute
|
||||
try:
|
||||
value = getattr(self.fn, name)
|
||||
except AttributeError:
|
||||
raise_observed_exception(AttributeError, tx)
|
||||
if not callable(value):
|
||||
return VariableTracker.build(tx, value, source)
|
||||
return variables.GetAttrVariable(self, name, source=source)
|
||||
|
||||
def call_getattr(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ its type to `common_constant_types`.
|
|||
|
||||
def const_getattr(self, tx: "InstructionTranslator", name):
|
||||
if not hasattr(self.value, name):
|
||||
raise NotImplementedError
|
||||
raise_observed_exception(AttributeError, tx, args=[name])
|
||||
member = getattr(self.value, name)
|
||||
if callable(member):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -1182,6 +1182,15 @@ class GetAttrVariable(VariableTracker):
|
|||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def get_forwarded_dict(self, tx):
|
||||
assert (
|
||||
self.name == "__dict__"
|
||||
and isinstance(self.obj, variables.UserDefinedClassVariable)
|
||||
and not tx.output.side_effects.has_pending_mutation(self.obj)
|
||||
)
|
||||
self.obj.ban_mutation = True
|
||||
return VariableTracker.build(tx, self.obj.value.__dict__, self.source)
|
||||
|
||||
|
||||
class MethodWrapperVariable(VariableTracker):
|
||||
def __init__(self, method_wrapper, **kwargs) -> None:
|
||||
|
|
|
|||
|
|
@ -162,6 +162,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
def __init__(self, value, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.value = value
|
||||
# Used when we materialize class.__dict__ to a MappingProxyObject. In
|
||||
# this case, we don't want to allow mutation in the class because there
|
||||
# is no way to reflect it in the created MappingProxyVariable.
|
||||
self.ban_mutation = False
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
|
@ -449,6 +453,13 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
args[0],
|
||||
args[1:],
|
||||
)
|
||||
elif name == "__setattr__" and self.ban_mutation:
|
||||
unimplemented_v2(
|
||||
gb_type="Class attribute mutation when the __dict__ was already materialized",
|
||||
context=str(self.value),
|
||||
explanation="Dyanmo does not support tracing mutations on a class when its __dict__ is materialized",
|
||||
hints=graph_break_hints.SUPPORTABLE,
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def call_function(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user