# Owner(s): ["module: dynamo"] import types from copy import deepcopy from unittest.mock import patch import torch import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.eval_frame import unsupported from torch._dynamo.mutation_guard import GenerationTracker from torch._dynamo.testing import same from torch.nn import functional as F from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import Parameter, UninitializedParameter try: from . import test_functions except ImportError: import test_functions class BasicModule(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(10, 10) self.scale = torch.randn(1, 10) def forward(self, x): return F.relu(self.linear1(x)) * self.scale class FnMember(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(10, 10) self.activation = F.relu def forward(self, x): x = self.linear1(x) if self.activation: x = self.activation(x) return x class FnMemberCmp(torch.nn.Module): def __init__(self, activation): super().__init__() self.linear1 = torch.nn.Linear(10, 10) self.activation = activation def forward(self, x): x = self.linear1(x) if self.activation is not None: x = self.activation(x) if self.activation is None: x = torch.sigmoid(x) return x class SubmoduleExample(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = BasicModule() self.layer2 = BasicModule() self.scale = torch.randn(1, 10) def forward(self, x): x = self.layer1(x) x = self.layer2(x) return x * self.scale class IsTrainingCheck(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(10, 10) self.linear2 = torch.nn.Linear(10, 10) self.train(True) def forward(self, x): if self.training: mod = self.linear1 else: mod = self.linear2 return F.relu(mod(x)) class IsEvalCheck(IsTrainingCheck): def __init__(self): super().__init__() self.train(False) class ModuleMethodCall(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = BasicModule() self.layer2 = BasicModule() self.scale = torch.randn(1, 10) def call_and_scale(self, mod, x): x = mod(x) return x * self.scale def forward(self, x): x1 = self.call_and_scale(self.layer1, x) x2 = self.call_and_scale(self.layer2, x) return x1 + x2 class UnsupportedMethodCall(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = BasicModule() self.scale = torch.randn(1, 10) def call_and_scale(self, mod, x): x = mod(x) x = x * self.scale return unsupported(x, x) def forward(self, x): x1 = self.call_and_scale(self.layer1, x) return x + x1 class UnsupportedModule(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = BasicModule() self.scale = torch.randn(1, 10) def forward(self, x): x = self.layer1(x) * self.scale return unsupported(x, x) class UnsupportedModuleCall(torch.nn.Module): def __init__(self): super().__init__() self.mod = UnsupportedModule() def forward(self, x): return 1 + self.mod(x * 1.5) class ModuleWithStaticForward(torch.nn.Module): @staticmethod def forward(x): return x * torch.sigmoid(x) class ModuleCallModuleWithStaticForward(torch.nn.Module): def __init__(self): super().__init__() self.mod = ModuleWithStaticForward() def forward(self, x): return self.mod(x) class ModuleStaticMethodCall(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = BasicModule() self.layer2 = BasicModule() self.scale = torch.randn(1, 10) @staticmethod def call_and_scale(scale, mod, x): x = mod(x) return x * scale def forward(self, x): x1 = self.call_and_scale(self.scale, self.layer1, x) x2 = self.call_and_scale(self.scale, self.layer2, x) return x1 + x2 class ModuleClassMethodCall(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = BasicModule() self.layer2 = BasicModule() self.scale = torch.randn(1, 10) @classmethod def call_and_scale(cls, scale, mod, x): x = mod(x) return x * scale def forward(self, x): x1 = self.call_and_scale(self.scale, self.layer1, x) x2 = self.call_and_scale(self.scale, self.layer2, x) return x1 + x2 class ModuleProperty(torch.nn.Module): def __init__(self): super().__init__() self.scale = torch.randn(1, 10) @property def scale_alias(self): return self.scale def forward(self, x): return x * self.scale_alias class ConstLoop(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(10, 10) self.count = 3 def forward(self, x): for i in range(self.count): x = torch.sigmoid(self.linear1(x)) return x class ViaModuleCall(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(10, 10) def forward(self, x): return test_functions.constant3(torch.sigmoid(self.linear1(x)), x) class IsNoneLayer(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = torch.nn.Linear(10, 10) self.layer2 = None self.train(True) def forward(self, x): if self.layer1 is not None: x = self.layer1(x) if self.layer2 is not None: x = self.layer2(x) return x class LayerList(torch.nn.Module): def __init__(self): super().__init__() self.layers = [ torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 10), ] def forward(self, x): for layer in self.layers: x = layer(x) return x class ModuleList(torch.nn.Module): def __init__(self): super().__init__() self.layers = torch.nn.ModuleList( [ torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 10), torch.nn.ReLU(), ] ) def forward(self, x): for i in range(len(self.layers)): x = self.layers[i](x) for layer in self.layers: x = layer(x) for layer, val in zip(self.layers, (x, x, x, x)): x = layer(x) + val for layer, val in zip(self.layers, (1, 2, 3, 4)): x = layer(x) + val for idx, layer in enumerate(self.layers): x = layer(x) * idx for idx, layer in enumerate(self.layers[::-1]): x = layer(x) * idx return x class ModuleDict(torch.nn.Module): def __init__(self): super().__init__() self.layers = torch.nn.ModuleDict( { "0": torch.nn.Linear(10, 10), } ) def forward(self, x): # TODO(future PR): handle more logic x = self.layers["0"](x) return x class TensorList(torch.nn.Module): def __init__(self): super().__init__() self.layers = ( torch.randn((1, 10)), torch.randn((10, 1)), torch.randn((1, 10)), torch.randn((10, 1)), ) def forward(self, x): for layer in self.layers: x = x * layer return x class Children(torch.nn.Module): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(10, 10) self.l2 = torch.nn.ReLU() self.l3 = torch.nn.Linear(10, 10) self.l4 = torch.nn.ReLU() def forward(self, x): for block in self.children(): x = block(x) return x class IntArg(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = torch.nn.Linear(10, 10) def forward(self, x, offset=1): x = F.relu(self.layer1(x)) + offset return x class Seq(torch.nn.Module): def __init__(self): super().__init__() self.layers = torch.nn.Sequential( torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 10), torch.nn.ReLU(), ) def forward(self, x): return self.layers(x) class Cfg: def __init__(self): self.val = 0.5 self.count = 3 class CfgModule(torch.nn.Module): def __init__(self): super().__init__() self.cfg = Cfg() self.layer = torch.nn.Linear(10, 10) def forward(self, x): for i in range(self.cfg.count): x = self.layer(x + self.cfg.val) return x class StringMember(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(10, 10) self.mode = "some_string" def forward(self, x): if self.mode == "some_string": return F.relu(self.linear1(x)) class _Block(torch.nn.Module): def forward(self, x): return 1.5 * torch.cat(x, 1) class _DenseBlock(torch.nn.ModuleDict): _version = 2 def __init__( self, num_layers: int = 3, ) -> None: super().__init__() for i in range(num_layers): self.add_module("denselayer%d" % (i + 1), _Block()) def forward(self, init_features): features = [init_features] for name, layer in self.items(): new_features = layer(features) features.append(new_features) return torch.cat(features, 1) class DenseNetBlocks(torch.nn.Module): def __init__(self): super().__init__() self.layers = _DenseBlock() def forward(self, x): return self.layers(x) class MaterializedModule(torch.nn.Module): """Once the below lazy module is initialized with its first input, it is transformed into this module.""" param: Parameter def __init__(self): super().__init__() self.register_parameter("param", None) def forward(self, x): return x class LazyModule(LazyModuleMixin, MaterializedModule): param: UninitializedParameter cls_to_become = MaterializedModule def __init__(self): super().__init__() self.param = UninitializedParameter() def initialize_parameters(self, x): self.param.materialize(x.shape) def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool: requires_grad = any([p.requires_grad for p in module.parameters(recurse)]) return requires_grad def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool: requires_grad = any(p.requires_grad for p in module.parameters(recurse)) return requires_grad class ParametersModule1(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(10, 10) self.scale = torch.nn.Parameter(torch.randn(1, 10)) def forward(self, x): if not requires_grad1(self): return F.relu(self.linear1(x)) * self.scale else: return x + 1 class ParametersModule2(ParametersModule1): def forward(self, x): if not requires_grad2(self): return F.relu(self.linear1(x)) * self.scale else: return x + 1 class ParametersModule3(ParametersModule1): def forward(self, x): ones = torch.ones(10, dtype=next(self.parameters()).dtype) return F.relu(self.linear1(x)) * self.scale + ones class SuperModule(BasicModule): def forward(self, x): x = super().forward(x) return x + 10.0 class SuperModule2(BasicModule): def forward(self, x): return BasicModule.forward(self, x) class ComplicatedSuperParent(torch.nn.Module): @classmethod def custom_add(cls, x): x = x + x return x class SuperChildCallsClassMethod(ComplicatedSuperParent): @classmethod def child_func(cls, x): x = super().custom_add(x) return x def forward(self, x): x = self.child_func(x) return x class HasAttrModule(torch.nn.Module): def __init__(self): super().__init__() self.scale = torch.nn.Parameter(torch.randn(1, 10)) def forward(self, x): x = F.relu(x) if hasattr(self, "scale"): x *= self.scale if hasattr(self, "scale2"): x *= self.scale2 return x class EnumValues(torch.nn.ModuleDict): def __init__( self, num_layers: int = 3, ) -> None: super().__init__() for i in range(num_layers): self.add_module("denselayer%d" % (i + 1), _Block()) def forward(self, init_features): features = [init_features] for idx, layer in enumerate(self.values()): new_features = layer(features) features.append(new_features) return torch.cat(features, 1) class AccessByKeys(torch.nn.ModuleDict): def __init__( self, num_layers: int = 3, ) -> None: super().__init__() for i in range(num_layers): self.add_module("denselayer%d" % (i + 1), _Block()) def forward(self, init_features): features = [init_features] for k in self.keys(): new_features = self[k](features) features.append(new_features) return torch.cat(features, 1) class CallForwardDirectly(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = BasicModule() self.layer2 = torch.nn.Linear(10, 10) def forward(self, x): x = self.layer1.forward(x) x = self.layer2.forward(x) return x class ModuleNameString(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(10, 10) def forward(self, x): if self.__class__.__name__ == "ABC": return 10 if self.linear1.__class__.__name__ == "Linear": return F.relu(self.linear1(x) + 10) return 11 class SelfMutatingModule(torch.nn.Module): def __init__(self, layer): super().__init__() self.layer = layer self.counter = 0 def forward(self, x): result = self.layer(x) + self.counter self.counter += 1 return F.relu(result) class ModuleAttributePrecedenceBase(torch.nn.Module): def __init__(self): super().__init__() def linear(self, x): return x * 2.0 class ModuleAttributePrecedence(ModuleAttributePrecedenceBase): def __init__(self): super().__init__() self.activation = torch.nn.ReLU() self.linear = torch.nn.Linear(10, 10) self.initializer = torch.ones([10, 10]) self.scale = 0.5 def activation(self, x): return x * 1.2 def initializer(self): return torch.zeros([10, 10]) def scale(self): return 2.0 def forward(self, x): # object attribute takes precedence unless it's a nn.Module return self.activation(self.linear(self.initializer + x)) * self.scale class ModuleForwardHasGraphBreak(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = BasicModule() self.layer2 = BasicModule() self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule()) self.layer4 = torch.nn.ModuleList( [ torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 10), torch.nn.ReLU(), ] ) self.layer5 = torch.nn.ModuleDict( { "0": torch.nn.Linear(10, 10), } ) self.scale = torch.randn(1, 10) def forward(self, x): """ This is used to test if the results of functions like `named_parameters` can be reconstructed correctly after graph break. https://github.com/pytorch/torchdynamo/issues/1931 """ x = self.layer1(x) params1 = dict(self.named_parameters()) params2 = list(self.parameters()) buffers1 = dict(self.named_buffers()) buffers2 = list(self.buffers()) modules1 = dict(self.named_modules()) modules2 = list(self.modules()) torch._dynamo.graph_break() y = modules2 y = modules1 y = buffers2 y = buffers1 y = params2 y = params1 x = ( self.layer2(x) + y["layer3.1.linear1.weight"] + y["layer4.2.weight"] + y["layer5.0.weight"] ) return x * self.scale class ModulePatch1(torch.nn.Module): pass class ModulePatch2(torch.nn.Module): def forward(self, x): return x - 1 def make_test(fn, expected_ops=None): def test_fn(self): return torch._dynamo.testing.standard_test( self, fn=fn, nargs=1, expected_ops=expected_ops ) fn.eval() return test_fn class NNModuleTests(torch._dynamo.test_case.TestCase): test_seq = make_test(Seq()) test_basicmodule1 = make_test(BasicModule()) test_basicmodule2 = make_test(BasicModule()) test_submodules1 = make_test(SubmoduleExample()) test_submodules2 = make_test(SubmoduleExample()) test_modulemethod1 = make_test(ModuleMethodCall()) test_modulemethod2 = make_test(ModuleMethodCall()) test_module_call_module_with_static_forward = make_test( ModuleCallModuleWithStaticForward() ) test_module_static_method = make_test(ModuleStaticMethodCall()) test_fnmember = make_test(FnMember()) test_fnmembercmp1 = make_test(FnMemberCmp(F.relu)) test_fnmembercmp2 = make_test(FnMemberCmp(None)) test_constloop = make_test(ConstLoop()) test_istraining1 = make_test(IsTrainingCheck()) test_istraining2 = make_test(IsTrainingCheck()) test_iseval1 = make_test(IsEvalCheck()) test_iseval2 = make_test(IsEvalCheck()) test_viamodulecall = make_test(ViaModuleCall()) test_isnonelayer = make_test(IsNoneLayer()) test_layerlist = make_test(LayerList()) test_tensorlist = make_test(TensorList()) test_intarg = make_test(IntArg()) test_cfgmod = make_test(CfgModule()) test_stringmember = make_test(StringMember()) test_modulelist = make_test(ModuleList()) test_moduledict = make_test(ModuleDict()) test_super1 = make_test(SuperModule()) test_super2 = make_test(SuperModule2()) test_super_class_method = make_test(SuperChildCallsClassMethod()) test_children = make_test(Children()) test_densenet = make_test(DenseNetBlocks()) test_parameters1 = make_test(ParametersModule1()) test_parameters2 = make_test(ParametersModule2()) test_parameters3 = make_test(ParametersModule3(), expected_ops=5) test_hasattr = make_test(HasAttrModule()) test_enumvalues = make_test(EnumValues()) test_access_by_keys = make_test(AccessByKeys()) test_module_class_method = make_test(ModuleClassMethodCall()) test_module_property = make_test(ModuleProperty()) test_forward_directly = make_test(CallForwardDirectly()) test_module_name_string = make_test(ModuleNameString()) test_module_attribute_precedence = make_test(ModuleAttributePrecedence()) def test_module_forward_has_graph_break(self): m = ModuleForwardHasGraphBreak() x = torch.rand([10, 10]) ref = m(x) opt_m = torch._dynamo.optimize("eager")(m) res = opt_m(x) self.assertTrue(torch.allclose(ref, res)) def test_unsupportedmethod(self): m = UnsupportedMethodCall() i = torch.randn(10) cnt = torch._dynamo.testing.CompileCounter() opt_m = torch._dynamo.optimize(cnt)(m) r = opt_m(i) self.assertTrue(torch._dynamo.testing.same(r, m(i))) self.assertEqual(cnt.op_count, 5) def test_unsupportedmodule(self): m = UnsupportedModuleCall() i = torch.randn(10) cnt = torch._dynamo.testing.CompileCounter() opt_m = torch._dynamo.optimize(cnt)(m) r = opt_m(i) self.assertTrue(torch._dynamo.testing.same(r, m(i))) self.assertEqual(cnt.op_count, 6) def test_self_mutating1(self): m1 = torch.nn.Linear(10, 10) m2 = SelfMutatingModule(m1) m3 = SelfMutatingModule(m1) m4 = SelfMutatingModule(m1) i = torch.randn(10) out2 = [m2(i), m2(i), m2(i)] cnt = torch._dynamo.testing.CompileCounter() opt_m3 = torch._dynamo.optimize_assert(cnt)(m3) opt_m4 = torch._dynamo.optimize_assert(cnt)(m4) out3 = [opt_m3(i), opt_m3(i), opt_m3(i)] out4 = [opt_m4(i), opt_m4(i), opt_m4(i)] self.assertTrue(torch._dynamo.testing.same(out2, out3)) self.assertTrue(torch._dynamo.testing.same(out2, out4)) self.assertEqual(cnt.frame_count, 3) @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False) def test_generation_tag(self): cnt = torch._dynamo.testing.CompileCounter() # guarantee that we have installed # the generation tagging function with torch._dynamo.optimize_assert(cnt): pass m1 = torch.nn.Linear(10, 10) prev_generation = GenerationTracker.get_generation_value(m1) cur_generation = prev_generation + 1 with torch._dynamo.optimize_assert(cnt): m2 = torch.nn.Linear(10, 10) self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation) self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation) # check that newly constructed instances # also have the same generation (even if copied from an old instance) m3 = deepcopy(m1) self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation) def test_simple_torch_function(self): def foo(x): # function call, twice to test wrapping x = F.sigmoid(x) x = F.sigmoid(x) # method call, twice to test wrapping x = x.sigmoid() x = x.sigmoid() return x class TensorProxy(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): return super().__torch_function__(func, types, args, kwargs) torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy) try: x = torch.randn(1).as_subclass(TensorProxy) cnt = torch._dynamo.testing.CompileCounter() out1 = foo(x) opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) out2 = opt_foo(x) self.assertEqual(cnt.op_count, 4) self.assertTrue(torch._dynamo.testing.same(out1, out2)) finally: torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy) def test_torch_function_with_closure(self): def run(): counter = 0 def foo(x): # function call, twice to test wrapping x = F.sigmoid(x) x = F.sigmoid(x) # method call, twice to test wrapping x = x.sigmoid() x = x.sigmoid() return x class TensorProxy(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): nonlocal counter # for now, only support reads from closure cells # TODO(future PR): support writes as well counter + 1 return super().__torch_function__(func, types, args, kwargs) torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy) try: x = torch.randn(1).as_subclass(TensorProxy) x = torch.randn(1) cnt = torch._dynamo.testing.CompileCounter() out1 = foo(x) opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) out2 = opt_foo(x) self.assertEqual(cnt.op_count, 4) self.assertTrue(torch._dynamo.testing.same(out1, out2)) finally: torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy) run() @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False) def test_nn_moduledict_contains(self): class M(torch.nn.Module): def __init__(self, module_dict): super().__init__() self.module_dict = module_dict def forward(self, x): if "foo" in self.module_dict: x = torch.mul(x, 1.0) x = torch.add(x, 1.0) return x module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)}) m = M(module_dict) data = torch.randn(1) out1 = m(data) cnt = torch._dynamo.testing.CompileCounter() opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) out2 = opt_m(data) self.assertEqual(cnt.op_count, 2) self.assertTrue(torch._dynamo.testing.same(out1, out2)) module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)}) m = M(module_dict) data = torch.randn(1) out1 = m(data) cnt = torch._dynamo.testing.CompileCounter() torch._dynamo.reset() opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) out2 = opt_m(data) self.assertEqual(cnt.op_count, 1) self.assertTrue(torch._dynamo.testing.same(out1, out2)) module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)}) pre = m(data) cnt.clear() with torch._dynamo.optimize(cnt, nopython=False): opt_pre = m(data) m = M(module_dict) data = torch.randn(1) out1 = m(data) out_post = m(data) self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 1) self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) self.assertTrue(torch._dynamo.testing.same(out1, out_post)) def test_lazy_module(self): input_shape = (16, 3, 6, 7, 8) cnt = torch._dynamo.testing.CompileCounter() module = LazyModule() def test_static_module(): input = torch.ones(*input_shape) module(input) opt_test_static_module = torch._dynamo.optimize(cnt)(test_static_module) opt_test_static_module() self.assertTrue( isinstance(module, MaterializedModule), "Module should be transformed to an instance of MaterializedModule.", ) self.assertEqual(module.param.shape, input_shape) # test when mapped to UnspecializedNNModule module = LazyModule() def test_unspecialized(): nonlocal module module = LazyModule() input = torch.ones(*input_shape) module(input) opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized) opt_test_unspecialized() self.assertTrue( isinstance(module, MaterializedModule), "Module should be transformed to an instance of MaterializedModule.", ) self.assertEqual(module.param.shape, input_shape) # test with a static module in torch.* module = torch.nn.modules.LazyBatchNorm3d( affine=False, track_running_stats=False ) cnt = torch._dynamo.testing.CompileCounter() torch._dynamo.reset() def test_torch_static(): input = torch.ones(*input_shape) return module(input) # fully materialized opt_test_torch_static = torch._dynamo.optimize(cnt)(test_torch_static) opt_test_torch_static() out = opt_test_torch_static() self.assertTrue(same(out, module(torch.ones(*input_shape)))) self.assertTrue( isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d), "Module should be transformed to an instance of BatchNorm3d.", ) self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.") def test_call_fn_with_non_const_inputs_safe(self): class ModuleSpecialFwd(torch.nn.Module): def __init__(self): super(ModuleSpecialFwd, self).__init__() self.conv = torch.nn.Conv2d( in_channels=3, out_channels=20, kernel_size=(5, 5) ) def _conv_forward(self, x): return self.conv._conv_forward(x, self.conv.weight, self.conv.bias) def forward(self, x): return self._conv_forward(x) mod = ModuleSpecialFwd() rx = torch.randn([3, 10, 10]) real = mod(rx) graph, _ = torch._dynamo.export(mod, rx) self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.relu = torch.nn.ReLU() self.linear = torch.nn.Linear(10, 10) self.register_buffer("buf0", torch.randn(10, 10)) def forward(self, x): return self.relu(self.linear(x) + self.buf0) class OptimizedModuleTest(torch._dynamo.test_case.TestCase): def test_nn_module(self): mod = MockModule() cnt = torch._dynamo.testing.CompileCounter() opt_mod = torch._dynamo.optimize(cnt)(mod) self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) x = torch.randn(10, 10) self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) self.assertEqual(cnt.frame_count, 1) def test_to(self): mod = MockModule() cnt = torch._dynamo.testing.CompileCounter() opt_mod = torch._dynamo.optimize(cnt)(mod) x = torch.randn(10, 10) self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) self.assertEqual(cnt.frame_count, 1) # Ensure that there is no recompilation opt_mod(x) self.assertEqual(cnt.frame_count, 1) opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) x = torch.randn(10, 10).to(dtype=torch.float64) opt_mod(x) # Ensure that there is a recompilation self.assertEqual(cnt.frame_count, 2) # Ensure that there is no recompilation opt_mod(x) self.assertEqual(cnt.frame_count, 2) torch._dynamo.reset() opt_mod(x) self.assertEqual(cnt.frame_count, 3) def test_attr(self): class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 10) self.register_buffer("buf0", torch.randn(10, 10)) def forward(self, x): return self.r(torch.sin(x)) + self.buf0 mod = MockModule() opt_mod = torch._dynamo.optimize("eager")(mod) # Check parameteres and buffers for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()): self.assertTrue(id(p1) == id(p2)) def test_recursion(self): mod = MockModule() cnt = torch._dynamo.testing.CompileCounter() opt_mod = torch._dynamo.optimize(cnt)(mod) for _ in range(5): opt_mod = torch._dynamo.optimize(cnt)(opt_mod) opt_mod(torch.randn(10, 10)) self.assertEqual(cnt.frame_count, 1) def test_composition(self): class InnerModule(torch.nn.Module): def __init__(self): super().__init__() self.relu = torch.nn.ReLU() def forward(self, x): return self.relu(torch.sin(x)) opt_inner_mod = InnerModule() class OuterModule(torch.nn.Module): def __init__(self): super().__init__() self.mod = opt_inner_mod def forward(self, x): return self.mod(torch.cos(x)) outer_mod = OuterModule() cnt = torch._dynamo.testing.CompileCounter() opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) x = torch.randn(4) self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) self.assertEqual(cnt.frame_count, 1) def test_composition_with_opt_mod(self): class InnerModule(torch.nn.Module): def __init__(self): super().__init__() self.relu = torch.nn.ReLU() def forward(self, x): return self.relu(torch.sin(x)) inner_mod = InnerModule() cnt = torch._dynamo.testing.CompileCounter() opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) class OuterModule(torch.nn.Module): def __init__(self): super().__init__() self.mod = opt_inner_mod def forward(self, x): return self.mod(torch.cos(x)) outer_mod = OuterModule() opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) x = torch.randn(4) self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) # There will be a graph break for the inner mod being OptimizedModule self.assertEqual(cnt.frame_count, 2) def test_module_patch(self): mod = ModulePatch1() mod.forward = types.MethodType(ModulePatch2.forward, mod) def fn(x): return mod(x) self.assertTrue( torch.allclose( torch._dynamo.optimize("eager", nopython=True)(fn)(torch.ones(10)), torch.zeros(1), ) ) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()