Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "[dynamo][api] Better support of torch.nn.Module (#88629)"
This reverts commit c83348597b.
Reverted https://github.com/pytorch/pytorch/pull/88629 on behalf of https://github.com/anijain2305 due to job failing on master https://github.com/pytorch/pytorch/actions/runs/3449914495/jobs/5758267231
This commit is contained in:
parent 6b775c42dd
commit ae2c668cc0
@@ -904,133 +904,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
        self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))


class MockModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.linear = torch.nn.Linear(10, 10)
        self.register_buffer("buf0", torch.randn(10, 10))

    def forward(self, x):
        return self.relu(self.linear(x) + self.buf0)


class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
    def test_nn_module(self):
        mod = MockModule()
        cnt = torch._dynamo.testing.CompileCounter()
        opt_mod = torch._dynamo.optimize(cnt)(mod)
        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)

        x = torch.randn(10, 10)
        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
        self.assertEqual(cnt.frame_count, 1)

    def test_to(self):
        mod = MockModule()
        cnt = torch._dynamo.testing.CompileCounter()
        opt_mod = torch._dynamo.optimize(cnt)(mod)
        x = torch.randn(10, 10)
        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
        self.assertEqual(cnt.frame_count, 1)

        # Ensure that there is no recompilation
        opt_mod(x)
        self.assertEqual(cnt.frame_count, 1)

        opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64)
        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
        x = torch.randn(10, 10).to(dtype=torch.float64)
        opt_mod(x)
        # Ensure that there is a recompilation
        self.assertEqual(cnt.frame_count, 2)

    def test_attr(self):
        class MockModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.register_buffer("buf0", torch.randn(10, 10))

            def forward(self, x):
                return self.r(torch.sin(x)) + self.buf0

        mod = MockModule()
        opt_mod = torch._dynamo.optimize("eager")(mod)

        # Check parameteres and buffers
        for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()):
            self.assertTrue(id(p1) == id(p2))

    def test_recursion(self):
        mod = MockModule()
        cnt = torch._dynamo.testing.CompileCounter()
        opt_mod = torch._dynamo.optimize(cnt)(mod)

        for _ in range(5):
            opt_mod = torch._dynamo.optimize(cnt)(opt_mod)
        opt_mod(torch.randn(10, 10))
        self.assertEqual(cnt.frame_count, 1)

    def test_composition(self):
        class InnerModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(torch.sin(x))

        opt_inner_mod = InnerModule()

        class OuterModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mod = opt_inner_mod

            def forward(self, x):
                return self.mod(torch.cos(x))

        outer_mod = OuterModule()
        cnt = torch._dynamo.testing.CompileCounter()
        opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)

        x = torch.randn(4)
        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
        self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
        self.assertEqual(cnt.frame_count, 1)

    def test_composition_with_opt_mod(self):
        class InnerModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(torch.sin(x))

        inner_mod = InnerModule()
        cnt = torch._dynamo.testing.CompileCounter()
        opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod)

        class OuterModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mod = opt_inner_mod

            def forward(self, x):
                return self.mod(torch.cos(x))

        outer_mod = OuterModule()
        opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)

        x = torch.randn(4)
        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
        self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
        # There will be a graph break for the inner mod being OptimizedModule
        self.assertEqual(cnt.frame_count, 2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
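The tests removed above lean on torch._dynamo.testing.CompileCounter, a test backend that counts compiled frames instead of transforming them, so frame_count records how many times dynamo compiled or recompiled the wrapped code. A minimal sketch of that pattern on a plain function, assuming the torch._dynamo.optimize API from this point in the repository's history:

import torch
import torch._dynamo
from torch._dynamo.testing import CompileCounter

def fn(x):
    return torch.sin(x) + 1

cnt = CompileCounter()
opt_fn = torch._dynamo.optimize(cnt)(fn)

opt_fn(torch.randn(4))
opt_fn(torch.randn(4))                        # same shape/dtype: guards hit, no recompile
print(cnt.frame_count)                        # expected: 1 (fn has no graph breaks)

opt_fn(torch.randn(4, dtype=torch.float64))   # dtype change fails the guards
print(cnt.frame_count)                        # expected: 2, one recompilation (mirrors test_to above)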
@@ -7,7 +7,6 @@ from .eval_frame import (
    export,
    optimize,
    optimize_assert,
    OptimizedModule,
    reset_code,
    run,
    skip,
@@ -26,7 +25,6 @@ __all__ = [
    "reset",
    "list_backends",
    "skip",
    "OptimizedModule",
]

@@ -486,16 +486,8 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
    """
    Check two models have same accuracy.
    """
    from .eval_frame import OptimizedModule
    from .testing import named_parameters_for_optimized_module
    from .utils import same

    if isinstance(gm, OptimizedModule):
        gm.named_parameters = named_parameters_for_optimized_module(gm)

    if isinstance(opt_gm, OptimizedModule):
        opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)

    ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

    try:
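The two isinstance branches above exist because wrapping a module inside OptimizedModule registers it as a submodule named _orig_mod, so named_parameters() on the wrapper yields prefixed names; same_two_models therefore swaps in the original module's named_parameters so the two models can be compared name by name. A standalone sketch of the naming effect, with Wrapper as a hypothetical stand-in for OptimizedModule:

import torch

class Wrapper(torch.nn.Module):
    """Hypothetical stand-in for OptimizedModule's parameter registration."""

    def __init__(self, mod):
        super().__init__()
        self._orig_mod = mod  # nn.Module assignment registers mod as a submodule

mod = torch.nn.Linear(4, 4)
wrapped = Wrapper(mod)

print([n for n, _ in mod.named_parameters()])      # ['weight', 'bias']
print([n for n, _ in wrapped.named_parameters()])  # ['_orig_mod.weight', '_orig_mod.bias']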
@@ -5,7 +5,6 @@ import inspect
import logging
import os
import sys
import textwrap
import threading
import traceback
import types
@@ -45,27 +44,6 @@ compile_lock = threading.RLock()
most_recent_backend = None


class OptimizedModule(torch.nn.Module):
    """
    Wraps the original nn.Module object and later patches its
    forward method to optimized self.forward method.
    """

    def __init__(self, mod):
        super().__init__()
        # Installs the params/buffer
        self._orig_mod = mod

    def __getattr__(self, name):
        if name == "_orig_mod":
            return self._modules["_orig_mod"]
        return getattr(self._orig_mod, name)

    def forward(self, *args, **kwargs):
        # This will be monkey patched later
        raise RuntimeError("Should not be here")


def remove_from_cache(f):
    """
    Make sure f.__code__ is not cached to force a recompile
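As the docstring above says, OptimizedModule installs the original module's parameters and buffers by registering it as _orig_mod and routes every other attribute access to it through __getattr__; only forward is later replaced with the compiled entry point. A self-contained sketch of that delegation pattern (DelegatingWrapper is a hypothetical analogue whose forward simply calls through instead of being monkey patched):

import torch

class DelegatingWrapper(torch.nn.Module):
    """Hypothetical analogue of OptimizedModule's attribute delegation."""

    def __init__(self, mod):
        super().__init__()
        self._orig_mod = mod  # installs params/buffers via submodule registration

    def __getattr__(self, name):
        if name == "_orig_mod":
            # The wrapped module lives in self._modules, not __dict__,
            # so fetch it there to avoid recursing into __getattr__.
            return self._modules["_orig_mod"]
        return getattr(self._orig_mod, name)

    def forward(self, *args, **kwargs):
        # In OptimizedModule this gets monkey patched to the optimized forward;
        # here it just calls the original module.
        return self._orig_mod(*args, **kwargs)

net = torch.nn.Linear(4, 4)
wrapped = DelegatingWrapper(net)
x = torch.randn(2, 4)

assert torch.equal(wrapped(x), net(x))  # forward delegates to the original module
assert wrapped.weight is net.weight     # attribute reads fall through __getattr__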
@@ -140,15 +118,31 @@ class _TorchDynamoContext:
        # Optimize the forward method of torch.nn.Module object
        if isinstance(fn, torch.nn.Module):
            mod = fn
            new_mod = OptimizedModule(mod)
            new_mod.forward = self(mod.forward)
            optimized_forward = self(mod.forward)

            class TorchDynamoNNModuleWrapper:
                """
                A wrapper that redirects the forward call to the optimized
                forward, while for rest it redirects the calls to the original
                module.
                """

                def __getattr__(self, name):
                    return getattr(mod, name)

                def forward(self, *args, **kwargs):
                    return optimized_forward(*args, **kwargs)

                def __call__(self, *args, **kwargs):
                    return self.forward(*args, **kwargs)

            new_mod = TorchDynamoNNModuleWrapper()
            # Save the function pointer to find the original callable while nesting
            # of decorators.
            new_mod._torchdynamo_orig_callable = mod.forward
            new_mod._torchdynamo_orig_callable = mod
            return new_mod

        assert callable(fn)

        callback = self.callback
        on_enter = self.on_enter
        backend_ctx_ctor = self.extra_ctx_ctor
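The practical difference between the two code paths in the hunk above: OptimizedModule (removed by this revert) is itself an nn.Module, so .to(), parameters(), and isinstance checks keep working on the optimized object, while the restored TorchDynamoNNModuleWrapper is a plain object that forwards everything to the original module through __getattr__. A dynamo-free sketch of the plain-wrapper behavior, where optimize_forward is a hypothetical stand-in for self(mod.forward):

import torch

mod = torch.nn.Linear(4, 4)

def optimize_forward(forward):
    """Hypothetical stand-in for the dynamo-compiled forward; returns it unchanged."""
    return forward

optimized_forward = optimize_forward(mod.forward)

class PlainWrapper:
    """Sketch of the plain-object wrapper restored by this revert."""

    def __getattr__(self, name):
        return getattr(mod, name)  # everything except forward/__call__ goes to mod

    def forward(self, *args, **kwargs):
        return optimized_forward(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

wrapped = PlainWrapper()
x = torch.randn(2, 4)

print(torch.equal(wrapped(x), mod(x)))       # True: calls route through the wrapped forward
print(isinstance(wrapped, torch.nn.Module))  # False: the wrapper is not an nn.Module
print(wrapped.weight.shape)                  # torch.Size([4, 4]): attribute reads forwarded to mod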
@@ -190,34 +184,6 @@ class _TorchDynamoContext:
        # If the function is called using torch._dynamo.optimize decorator, we
        # should prevent any type of skipping.
        if callback not in (None, False):
            if not hasattr(fn, "__code__"):
                raise RuntimeError(
                    textwrap.dedent(
                        """

                        torch._dynamo.optimize is called on a non function object.
                        If this is a callable class, please optimize the individual methods that you are interested in optimizing.

                        >> class CallableClass:
                        >>     def __init__(self):
                        >>         super().__init__()
                        >>         self.relu = torch.nn.ReLU()
                        >>
                        >>     def __call__(self, x):
                        >>         return self.relu(torch.sin(x))
                        >>
                        >>     def print_hello(self):
                        >>         print("Hello world")
                        >>
                        >> mod = CallableClass()

                        If you want to optimize the __call__ function

                        >> mod.__call__ = torch._dynamo.optimize(mod.__call__)

                        """
                    )
                )
            always_optimize_code_objects[fn.__code__] = True

        return _fn
@@ -32,17 +32,6 @@ def clone_me(x):
    return x.detach().clone().requires_grad_(x.requires_grad)


def named_parameters_for_optimized_module(mod):
    assert isinstance(mod, eval_frame.OptimizedModule)
    return mod._orig_mod.named_parameters


def remove_optimized_module_prefix(name):
    prefix = "_orig_mod."
    assert name.startswith(prefix)
    return name[len(prefix) :]


def collect_results(model, prediction, loss, example_inputs):
    results = []
    results.append(prediction)
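To make the two helpers above concrete: remove_optimized_module_prefix undoes the _orig_mod. prefix shown earlier, so parameter names coming from an OptimizedModule line up with the eager model's names in collect_results below. A quick standalone check, re-implemented here rather than imported from torch._dynamo.testing:

prefix = "_orig_mod."

def remove_optimized_module_prefix(name):
    assert name.startswith(prefix)
    return name[len(prefix):]

assert remove_optimized_module_prefix("_orig_mod.linear.weight") == "linear.weight"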
@@ -55,8 +44,6 @@ def collect_results(model, prediction, loss, example_inputs):
    grads = dict()
    params = dict()
    for name, param in model.named_parameters():
        if isinstance(model, eval_frame.OptimizedModule):
            name = remove_optimized_module_prefix(name)
        param_copy = param
        grad = param.grad
        # Treat None and zero grad as same