Mirror of https://github.com/zebrajr/pytorch.git
Synced 2025-12-06 12:20:52 +01:00
Revert "[dynamo] Bug fix for _torchdynamo_inline source handling (#135612)"
This reverts commit 5c3d0a2ded.

Reverted https://github.com/pytorch/pytorch/pull/135612 on behalf of https://github.com/clee2000 because it broke inductor/test_cpu_select_algorithm.py::TestSelectAlgorithmCPU::test_linear_input_transpose_bias_True_cpu_float32 ([GH job link](https://github.com/pytorch/pytorch/actions/runs/10805518363/job/29982386304), [HUD commit link](5c3d0a2ded)); the failure was not caught on the PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/135612#issuecomment-2344039370)).
This commit is contained in:
parent f96e8041b1
commit 596e93b506
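For context, wrappers produced by Dynamo carry the original callable under a `_torchdynamo_inline` attribute, and the reverted change moved where that attribute gets unwrapped. A minimal sketch of the mechanism, assuming the wrapper sets the attribute the way the diff comments below describe:

```python
import inspect

import torch


def gn(x):
    return 2 * x


# torch.compile returns a wrapper function; per the diff comments below, the
# original callable is stashed on it under `_torchdynamo_inline`.
wrapped = torch.compile(gn, backend="eager")

# getattr_static reads the attribute without running __getattr__ or descriptors.
inner = inspect.getattr_static(wrapped, "_torchdynamo_inline", wrapped)
print(inner is gn)  # True, assuming the wrapper records the original
```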
The revert deletes the regression test that #135612 had added to ReproTests, which exercised torch.compile called from inside a frame that is itself being compiled (see the sketch after the diff):

```diff
@@ -5912,22 +5912,6 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
             actual[1].untyped_storage().data_ptr(),
         )
 
-    def test_torch_compile_in_compile_frame(self):
-        # TODO(anijain2305/yanboliang) - Dont graph break on torch.compile.
-        def gn(x, c=None):
-            if c is None:
-                c = 2
-            return c * x
-
-        def outer_func(x):
-            return torch.compile(gn)(x)
-
-        compile_outer = torch.compile(outer_func, backend="eager")
-        x = torch.randn(4)
-        ref = outer_func(x)
-        res = compile_outer(x)
-        self.assertEqual(ref, res)
-
 
 instantiate_parametrized_tests(ReproTests)
 
```
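The pattern under test, in isolation: per the test's own TODO, Dynamo currently graph-breaks at the inner torch.compile call rather than inlining through it, but the computed result is still correct. A sketch using the same eager backend:

```python
import torch


@torch.compile(backend="eager")
def outer(x):
    # Per the test's TODO, Dynamo graph-breaks here rather than tracing
    # through the nested torch.compile call; the result is still correct.
    inner = torch.compile(lambda y: 2 * y)
    return inner(x)


print(outer(torch.randn(4)))
```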
In VariableBuilder, the revert restores the eager unwrap of `_torchdynamo_inline` near the top of the wrapping logic, before any other dispatch runs:

```diff
@@ -547,6 +547,11 @@ class VariableBuilder:
         if id_dispatch is not None:
             return id_dispatch(self, value)
 
+        # Note - There are some nested values where types mismatch!
+        # We want to get those out and wrap those.
+        if is_function_or_wrapper(value):
+            value = inspect.getattr_static(value, "_torchdynamo_inline", value)
+
         # Everything else (NB: order matters!)
         if is_traceable_wrapper_subclass(value) or istype(
             value, config.traceable_tensor_subclasses
```
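The hunk above relies on `inspect.getattr_static`, which looks an attribute up without invoking `__getattr__` or the descriptor protocol, so the builder never runs arbitrary user code just to probe for the wrapper attribute. A standard-library illustration (the `Lazy` class here is hypothetical):

```python
import inspect


class Lazy:
    def __getattr__(self, name):
        # Plain getattr() would run this and fabricate a value.
        return f"computed:{name}"


obj = Lazy()
print(getattr(obj, "_torchdynamo_inline", None))                 # 'computed:_torchdynamo_inline'
print(inspect.getattr_static(obj, "_torchdynamo_inline", None))  # None
```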
Further down in VariableBuilder, it removes the dedicated `WrapperUserFunctionVariable` dispatch for `_torchdynamo_inline` that the reverted change had introduced:

```diff
@@ -983,13 +988,6 @@ class VariableBuilder:
         elif is_lru_cache_wrapped_function(value):
             self.install_guards(GuardBuilder.TYPE_MATCH)
             return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
-        elif is_function_or_wrapper(value) and inspect.getattr_static(
-            value, "_torchdynamo_inline", False
-        ):
-            self.install_guards(GuardBuilder.TYPE_MATCH)
-            return WrapperUserFunctionVariable(
-                value, "_torchdynamo_inline", source=self.source
-            )
         elif is_function_or_wrapper(value):
             value, attr_name = unwrap_with_attr_name_if_wrapper(value)
             # For these wrappers, Dynamo points to the wrapped function,
```
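The surviving `is_lru_cache_wrapped_function` branch points Dynamo at the function stored under `__wrapped__`, which `functools.lru_cache` sets on its wrapper via `functools.update_wrapper`. A quick standard-library check:

```python
import functools


@functools.lru_cache(maxsize=None)
def fib(n: int) -> int:
    return n if n < 2 else fib(n - 1) + fib(n - 2)


# lru_cache exposes the undecorated function as __wrapped__; calling it
# skips the cache for that call (recursive calls still hit the cached name).
print(fib(10))              # 55, memoized
print(fib.__wrapped__(10))  # 55, top-level call computed without the cache
```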
Finally, UserFunctionVariable regains the unconditional unpack of `@torch._dynamo.optimize()(fn)` wrappers:

```diff
@@ -152,6 +152,8 @@ class UserFunctionVariable(BaseUserFunctionVariable):
         assert isinstance(
             fn, (types.FunctionType, torch.jit.ScriptFunction)
         ), f"expected FunctionType found {typestr(fn)} {fn}"
+        # unpack @torch._dynamo.optimize()(fn) wrapped function
+        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
         self.fn: types.FunctionType = fn
 
     def as_python_constant(self):
```
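With the unpack restored, a wrapper produced by `torch._dynamo.optimize()` is reduced to the underlying plain function before Dynamo stores it, which keeps the `FunctionType` assertion meaningful. A sketch of the unwrap step in isolation, assuming the wrapper sets `_torchdynamo_inline` as the restored comment describes:

```python
import inspect
import types

import torch


def gn(x):
    return x + 1


# A wrapper produced the way the restored comment describes.
opt = torch._dynamo.optimize("eager")(gn)

# The restored unpack: fall back to the value itself when the attribute is absent.
fn = inspect.getattr_static(opt, "_torchdynamo_inline", opt)
assert isinstance(fn, types.FunctionType)
print(fn is gn)  # True, assuming the wrapper records the original
```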