Revert "[dynamo][api] Better support of torch.nn.Module (#88629)"

This reverts commit c83348597b.

Reverted https://github.com/pytorch/pytorch/pull/88629 on behalf of https://github.com/anijain2305 due to job failing on master https://github.com/pytorch/pytorch/actions/runs/3449914495/jobs/5758267231
This commit is contained in:
PyTorch MergeBot 2022-11-12 07:52:53 +00:00
parent 6b775c42dd
commit ae2c668cc0
5 changed files with 20 additions and 204 deletions

View File

@ -904,133 +904,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
def forward(self, x):
return self.relu(self.linear(x) + self.buf0)
class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
def test_nn_module(self):
mod = MockModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch._dynamo.optimize(cnt)(mod)
self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
x = torch.randn(10, 10)
self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
self.assertEqual(cnt.frame_count, 1)
def test_to(self):
mod = MockModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch._dynamo.optimize(cnt)(mod)
x = torch.randn(10, 10)
self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
self.assertEqual(cnt.frame_count, 1)
# Ensure that there is no recompilation
opt_mod(x)
self.assertEqual(cnt.frame_count, 1)
opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64)
self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
x = torch.randn(10, 10).to(dtype=torch.float64)
opt_mod(x)
# Ensure that there is a recompilation
self.assertEqual(cnt.frame_count, 2)
def test_attr(self):
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
def forward(self, x):
return self.r(torch.sin(x)) + self.buf0
mod = MockModule()
opt_mod = torch._dynamo.optimize("eager")(mod)
# Check parameteres and buffers
for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()):
self.assertTrue(id(p1) == id(p2))
def test_recursion(self):
mod = MockModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch._dynamo.optimize(cnt)(mod)
for _ in range(5):
opt_mod = torch._dynamo.optimize(cnt)(opt_mod)
opt_mod(torch.randn(10, 10))
self.assertEqual(cnt.frame_count, 1)
def test_composition(self):
class InnerModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(torch.sin(x))
opt_inner_mod = InnerModule()
class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod = opt_inner_mod
def forward(self, x):
return self.mod(torch.cos(x))
outer_mod = OuterModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
x = torch.randn(4)
self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
self.assertEqual(cnt.frame_count, 1)
def test_composition_with_opt_mod(self):
class InnerModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(torch.sin(x))
inner_mod = InnerModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod)
class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod = opt_inner_mod
def forward(self, x):
return self.mod(torch.cos(x))
outer_mod = OuterModule()
opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
x = torch.randn(4)
self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
# There will be a graph break for the inner mod being OptimizedModule
self.assertEqual(cnt.frame_count, 2)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -7,7 +7,6 @@ from .eval_frame import (
export,
optimize,
optimize_assert,
OptimizedModule,
reset_code,
run,
skip,
@ -26,7 +25,6 @@ __all__ = [
"reset",
"list_backends",
"skip",
"OptimizedModule",
]

View File

@ -486,16 +486,8 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
"""
Check two models have same accuracy.
"""
from .eval_frame import OptimizedModule
from .testing import named_parameters_for_optimized_module
from .utils import same
if isinstance(gm, OptimizedModule):
gm.named_parameters = named_parameters_for_optimized_module(gm)
if isinstance(opt_gm, OptimizedModule):
opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)
try:

View File

@ -5,7 +5,6 @@ import inspect
import logging
import os
import sys
import textwrap
import threading
import traceback
import types
@ -45,27 +44,6 @@ compile_lock = threading.RLock()
most_recent_backend = None
class OptimizedModule(torch.nn.Module):
"""
Wraps the original nn.Module object and later patches its
forward method to optimized self.forward method.
"""
def __init__(self, mod):
super().__init__()
# Installs the params/buffer
self._orig_mod = mod
def __getattr__(self, name):
if name == "_orig_mod":
return self._modules["_orig_mod"]
return getattr(self._orig_mod, name)
def forward(self, *args, **kwargs):
# This will be monkey patched later
raise RuntimeError("Should not be here")
def remove_from_cache(f):
"""
Make sure f.__code__ is not cached to force a recompile
@ -140,15 +118,31 @@ class _TorchDynamoContext:
# Optimize the forward method of torch.nn.Module object
if isinstance(fn, torch.nn.Module):
mod = fn
new_mod = OptimizedModule(mod)
new_mod.forward = self(mod.forward)
optimized_forward = self(mod.forward)
class TorchDynamoNNModuleWrapper:
"""
A wrapper that redirects the forward call to the optimized
forward, while for rest it redirects the calls to the original
module.
"""
def __getattr__(self, name):
return getattr(mod, name)
def forward(self, *args, **kwargs):
return optimized_forward(*args, **kwargs)
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
new_mod = TorchDynamoNNModuleWrapper()
# Save the function pointer to find the original callable while nesting
# of decorators.
new_mod._torchdynamo_orig_callable = mod.forward
new_mod._torchdynamo_orig_callable = mod
return new_mod
assert callable(fn)
callback = self.callback
on_enter = self.on_enter
backend_ctx_ctor = self.extra_ctx_ctor
@ -190,34 +184,6 @@ class _TorchDynamoContext:
# If the function is called using torch._dynamo.optimize decorator, we
# should prevent any type of skipping.
if callback not in (None, False):
if not hasattr(fn, "__code__"):
raise RuntimeError(
textwrap.dedent(
"""
torch._dynamo.optimize is called on a non function object.
If this is a callable class, please optimize the individual methods that you are interested in optimizing.
>> class CallableClass:
>> def __init__(self):
>> super().__init__()
>> self.relu = torch.nn.ReLU()
>>
>> def __call__(self, x):
>> return self.relu(torch.sin(x))
>>
>> def print_hello(self):
>> print("Hello world")
>>
>> mod = CallableClass()
If you want to optimize the __call__ function
>> mod.__call__ = torch._dynamo.optimize(mod.__call__)
"""
)
)
always_optimize_code_objects[fn.__code__] = True
return _fn

View File

@ -32,17 +32,6 @@ def clone_me(x):
return x.detach().clone().requires_grad_(x.requires_grad)
def named_parameters_for_optimized_module(mod):
assert isinstance(mod, eval_frame.OptimizedModule)
return mod._orig_mod.named_parameters
def remove_optimized_module_prefix(name):
prefix = "_orig_mod."
assert name.startswith(prefix)
return name[len(prefix) :]
def collect_results(model, prediction, loss, example_inputs):
results = []
results.append(prediction)
@ -55,8 +44,6 @@ def collect_results(model, prediction, loss, example_inputs):
grads = dict()
params = dict()
for name, param in model.named_parameters():
if isinstance(model, eval_frame.OptimizedModule):
name = remove_optimized_module_prefix(name)
param_copy = param
grad = param.grad
# Treat None and zero grad as same