mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo] Support typing.Mapping & Support function as argument (#88963)
These missing features come from https://github.com/pytorch/benchmark/pull/1302, where we'd like to enable E2E hf_bert dynamo train/eval. The dependent [HuggingFace accelerate library](https://huggingface.co/docs/accelerate/index) requires these improvements. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88963 Approved by: https://github.com/jansel
This commit is contained in:
parent
126e44173d
commit
b72f5b9ae3
|
|
@ -2792,6 +2792,42 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
res = opt_fn(x)
|
||||
self.assertTrue(torch.allclose(ref, res))
|
||||
|
||||
def test_user_function_variable_supports_function_argument(self):
|
||||
def add1(x):
|
||||
return x + 1
|
||||
|
||||
def add2(x):
|
||||
return x + 2
|
||||
|
||||
def gn(x, f=add1):
|
||||
if f is add1:
|
||||
return x + 1
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
def fn(x, f):
|
||||
return gn(x, f)
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
ref = fn(x, add2)
|
||||
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
||||
res = opt_fn(x, add2)
|
||||
self.assertTrue(torch.allclose(ref, res))
|
||||
|
||||
def test_typing_variable_isinstance(self):
|
||||
def fn(x, m):
|
||||
if isinstance(m, typing.Mapping):
|
||||
return x + 1
|
||||
else:
|
||||
return x - 1
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
m = {"x": torch.randn(3)}
|
||||
ref = fn(x, m)
|
||||
opt_fn = torch._dynamo.optimize("eager")(fn)
|
||||
res = opt_fn(x, m)
|
||||
self.assertTrue(torch.allclose(ref, res))
|
||||
|
||||
def test_repro_graph_breaks_in__get_item_by_idx(self):
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import re
|
|||
import sys
|
||||
import time
|
||||
import types
|
||||
import typing
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
|
|
@ -275,6 +276,13 @@ def istype(obj, allowed_types):
|
|||
return type(obj) is allowed_types
|
||||
|
||||
|
||||
def is_typing(value):
|
||||
if sys.version_info < (3, 9):
|
||||
return isinstance(value, typing._GenericAlias)
|
||||
else:
|
||||
return isinstance(value, typing._SpecialGenericAlias)
|
||||
|
||||
|
||||
def is_numpy_int_type(value):
|
||||
return istype(
|
||||
value,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import operator
|
|||
import re
|
||||
import types
|
||||
from abc import ABCMeta
|
||||
from typing import Any, List, Union
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
from functorch.experimental.ops import PyOperator
|
||||
|
|
@ -43,6 +43,7 @@ from ..utils import (
|
|||
global_key_name,
|
||||
is_namedtuple,
|
||||
is_numpy_int_type,
|
||||
is_typing,
|
||||
istensor,
|
||||
istype,
|
||||
odict_values,
|
||||
|
|
@ -360,7 +361,8 @@ class VariableBuilder:
|
|||
value,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
elif value is List:
|
||||
elif is_typing(value):
|
||||
# typing.List, typing.Mapping, etc.
|
||||
return TypingVariable(
|
||||
value,
|
||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ def wrap_bound_arg(val, options):
|
|||
return cls([wrap_bound_arg(x, options) for x in val], **options)
|
||||
elif variables.ConstantVariable.is_literal(val):
|
||||
return variables.ConstantVariable(val, **options)
|
||||
elif isinstance(val, types.FunctionType):
|
||||
return variables.UserFunctionVariable(val, **options)
|
||||
elif isinstance(val, enum.Enum):
|
||||
return variables.EnumVariable(val, **options)
|
||||
elif isinstance(val, (type, abc.ABCMeta)):
|
||||
|
|
|
|||
|
|
@ -654,6 +654,12 @@ class TypingVariable(VariableTracker):
|
|||
)
|
||||
unimplemented("typing")
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class NumpyVariable(VariableTracker):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user