import torch
import unittest
import operator
import numbers
import pickle
import copy
import sys
import functools
import contextlib
from pathlib import Path
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph
from torch.fx.experimental import GraphManipulation
from torch.fx.experimental import shape_prop
from torch.fx.experimental.subgraph_creation_example import split_module
from torch.fx.immutable_collections import immutable_dict, immutable_list
from copy import deepcopy

from torch.fx.proxy import TraceError

from fx.quantization import Quantizer

from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS
from torch.testing._internal.jit_utils import JitTestCase

try:
    from torchvision.models import resnet18
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


class SimpleTest(torch.nn.Module):
    """Minimal module (add a constant, then relu) shared by several tests."""

    def forward(self, x):
        return torch.relu(x + 3.0)


def a_non_torch_leaf(a, b):
    """Plain Python function used as a `call_function` target in test_custom_import."""
    return a + b


class Pair(NamedTuple):
    """NamedTuple argument type used by test_fn_type_annotations."""
    x : torch.Tensor
    y : torch.Tensor


class TestFX(JitTestCase):
    def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
        """Check that an nn.Module's results match the GraphModule version
        for a given set of args/kwargs.

        The module (or plain callable) is run eagerly, symbolically traced,
        linted, run again through the traced GraphModule, and the two results
        are compared with assertEqual.
        """
        kwargs = kwargs if kwargs else {}
        ref_outs = m(*args, **kwargs)
        gm = symbolic_trace(m)
        gm.graph.lint(gm)
        test_outs = gm(*args, **kwargs)
        self.assertEqual(ref_outs, test_outs)

    def test_graph_module(self):
        """Tracing (and scripting) succeeds for modules using submodules,
        parameters, tuple returns and mixed positional/keyword signatures."""
        class MySub(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.w + x

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = torch.nn.Linear(4, 3)
                self.sub_mod = MySub()
                self.w = torch.nn.Parameter(torch.rand(3))

            def forward(self, A, B, c):
                t = torch.sigmoid(A) + self.lin(c)
                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

        m = MyModule()
        gm = symbolic_trace(m)

        # Only checks that scripting the traced module does not raise.
        torch.jit.script(gm)

        class M2(torch.nn.Module):
            def forward(self, A):
                m, idx = torch.max(A, 0)
                return m + 1, idx + 1

        m2 = M2()
        symbolic_trace(m2)

        class T(torch.nn.Module):

            def forward(self, A, b=4, *args, c=5, **kwargs):
                x = A + 1 + args[0] + kwargs['3']
                return x

        t = T()
        symbolic_trace(t)

    def test_custom_import(self):
        """A hand-built graph may call a non-torch Python function."""
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(torch.sin(x + y), gm(x, y))

    def test_args_kwargs(self):
        """`*args` / `**kwargs` in forward() are traceable."""
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + kwargs['foo']
                return x

        t = T()
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_args_kwargs_no_self(self):
        """forward() that receives `self` through *args is rejected with a clear error."""
        class T(torch.nn.Module):
            def forward(*args, **kwargs):  # noqa: B902
                self = args[0]
                return torch.relu(args[1])

        t = T()
        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_fx_shifts(self):
        """`<<` and `>>` operators are traceable."""
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x << 3, x >> 3

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_dict(self):
        """Dict inputs and dict outputs are traceable."""
        class MyDictMod(torch.nn.Module):
            def forward(self, d):
                return d['3'].relu(), {'4' : d['3'].neg()}

        input_dict = {'3': torch.rand(3, 4)}
        m = MyDictMod()

        self.checkGraphModule(m, (input_dict,))

    def test_disallow_override(self):
        """A Tracer subclass can veto nodes (here: in-place ops) in create_node."""
        # Custom delegate to disallow in-place tensor operations
        class NoMutableCallTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                # Use a separate local for the target's name so the caller's
                # `name` argument is forwarded untouched.
                target_name = target if isinstance(target, str) else torch.typename(target)
                # In-place ops are recognised by the trailing-underscore convention.
                if target_name[-1] == '_':
                    raise RuntimeError('In-place operations are not supported')
                return super().create_node(kind, target, args, kwargs, name, type_expr)

        # Test method
        class MyInplaceMod(torch.nn.Module):
            def forward(self, x):
                x.add_(3.0)
                return x

        m = MyInplaceMod()

        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m)

        # Test free function
        class MyInplaceMod2(torch.nn.Module):
            def forward(self, x):
                torch.log_(x)
                return x
        m2 = MyInplaceMod2()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m2)

        # Test symbolic node as an arg
        class MyInplaceMod3(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(3, 4)
                y.add_(x)
                return x
        m3 = MyInplaceMod3()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m3)

    def test_leaf_module(self):
        """Overriding is_leaf_module to return False traces through every module."""
        # Custom delegate to make it so that there are no leaf modules, everything
        # should get traced through
        class NoLeafModulesTracer(Tracer):
            def is_leaf_module(self, m, qualname):
                return False

        class MyReluMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        mrm = MyReluMod()
        sym = NoLeafModulesTracer().trace(mrm)
        for node in sym.nodes:
            self.assertNotEqual(node.op, 'call_module')
        # Lint against the traced module (not the graph itself), matching
        # how test_node_tagging lints a Tracer-produced graph.
        sym.lint(mrm)

    def test_graph_edit_with_proxy(self):
        """Proxy objects can extend a copied graph with new operations."""
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        gm.graph.lint(gm)
        self.assertEqual(gm(3, 4), 14)

    def test_graph_unique_names(self):
        """Node names remain unique after graph_copy plus Proxy-generated nodes."""
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        seen_names : Set[str] = set()
        for node in gm.graph.nodes:
            self.assertNotIn(node.name, seen_names)
            seen_names.add(node.name)

    def test_graph_unique_names_manual(self):
        """Node names remain unique after graph_copy when names were given manually."""
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
        c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        graph2 = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        graph2.graph_copy(graph, val_map)
        seen_names : Set[str] = set()
        for node in graph2.nodes:
            self.assertNotIn(node.name, seen_names)
            seen_names.add(node.name)

    @skipIfNoTorchVision
    def test_resnet(self):
        """resnet18 can be traced, scripted, and quantized with consistent outputs."""
        resnet = resnet18()
        resnet.train()

        res_graph = symbolic_trace(resnet)
        res_script = torch.jit.script(res_graph)

        ip = torch.rand(1, 3, 224, 224)

        a = resnet(ip)
        b = res_graph(ip)
        c = res_script(ip)
        self.assertEqual(a, b)
        self.assertEqual(a, c)

        quantizer = Quantizer(res_graph)

        for _ in range(10):
            quantizer.observe((torch.rand(1, 3, 224, 224),))

        qgraph = quantizer.quantize()
        qgraph.graph.lint(qgraph)
        qgraph_script = torch.jit.script(qgraph)

        d = qgraph(ip)
        e = qgraph_script(ip)

        # Quantized output is only required to be loosely close to the reference.
        assert (a - d).abs().max() < 2
        self.assertEqual(d, e)

    def test_unpack(self):
        """Tuple-unpacking an input inside forward() is traceable."""
        class M(torch.nn.Module):
            def forward(self, a, b):
                c, d = a
                return c + d + b

        a = (torch.rand(1), torch.rand(1))
        b = torch.rand(1)
        m = M()
        self.checkGraphModule(m, (a, b))

    def test_native_callable(self):
        """Lower an FX graph to a native (TorchBind) interpreter and compare results."""
        if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        torch_root = Path(__file__).resolve().parent.parent
        p = torch_root / 'build' / 'lib' / 'libtorchbind_test.so'
        torch.ops.load_library(str(p))
        # This test exercises the case where we use FX to translate from Python
        # code to some native callable object
        #
        # For the purposes of testing, we use ElementwiseInterpreter defined
        # in test_custom_class.cpp.
        #
        # We test that we can
        # 1) Construct a native callable from FX IR
        # 2) Construct a drop-in replacement module that delegates to the
        #    native callable rather than the original code
        # 3) Run both the original code and native callable wrapper with
        #    equivalent results
        # 4) TorchScript compile the native callable wrapper and confirm
        #    equivalent results with the reference
        # 5) TorchScript serialize and deserialize the native callable
        #    and confirm equivalent results with the reference

        # We use this simple Module as a reference computation
        class MySimpleMod(torch.nn.Module):
            def forward(self, x):
                return 3.0 * x + x

        msm = MySimpleMod()

        # This is what a lowering pass might look like: a function that takes
        # a valid nn.Module, symbolically traces it, lowers the Module to some
        # representation, and wraps that representation up into another
        # nn.Module instance that handles dispatch to the compiled/lowered code.
        def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format ======
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {
                operator.add : "add",
                operator.mul : "mul"
            }

            output_node : Optional[Node] = None
            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    # `target` is a callable here, so it must be formatted
                    # rather than concatenated onto the message string.
                    assert target in target_to_name, f"Unsupported call target {target}"
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.Tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append((target_to_name[target], arg_names, out_name))
                elif n.op == 'output':
                    if output_node is not None:
                        raise RuntimeError('Multiple output nodes!')
                    output_node = n
                else:
                    raise RuntimeError('Unsupported opcode ' + n.op)

            if output_node is None:
                raise RuntimeError('No output node found!')

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            assert isinstance(output_node.args[0], torch.fx.Node)
            interpreter.set_output_name(output_node.args[0].name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value

            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = [graph.create_node('placeholder', name) for name in fn_input_names]

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            call_node = graph.create_node(
                op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

            # Register output
            graph.output(call_node)

            graph.lint(wrapper)

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)

        # Lower GraphModule to C++ interpreter
        lowered = lower_to_elementwise_interpreter(msm)

        # Compare correctness with original module
        x = torch.rand(3, 4)
        ref_out = msm(x)
        test_out = lowered(x)
        torch.testing.assert_allclose(test_out, ref_out)

        # Test TorchScript compilation
        scripted_lowered = torch.jit.script(lowered)
        script_out = scripted_lowered(x)
        torch.testing.assert_allclose(script_out, ref_out)

        # Test TorchScript ser/de
        import_copy = self.getExportImportCopy(scripted_lowered)
        imported_out = import_copy(x)
        torch.testing.assert_allclose(imported_out, ref_out)

    def test_reserved_getattr(self):
        """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
        class M(torch.nn.Module):
            def forward(self, a):
                return a.foo.bar.baz

        m = M()
        m_g = symbolic_trace(m)
        m_g.graph.lint(m_g)
        for node in m_g.graph.nodes:
            self.assertTrue(node.name != "getattr")

    def test_node_tagging(self):
        """A Tracer subclass can attach extra attributes to every node it creates."""
        class TaggingTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                # Forward type_expr too so the override does not drop it.
                n = super().create_node(kind, target, args, kwargs, name, type_expr)
                n.tag = 'foo'
                return n

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = TaggingTracer().trace(m)
        g.lint(m)
        for n in g.nodes:
            self.assertTrue(hasattr(n, 'tag'))
            self.assertEqual(n.tag, 'foo')

    def test_tensor_attribute(self):
        """Plain tensor attributes (direct and on a submodule) are usable after tracing."""
        class TensorAttribute(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tensor = torch.rand(3, 4)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.tensor)

        ta = TensorAttribute()
        traced = symbolic_trace(ta)
        traced(torch.rand(4, 4))

        class WrapperForQualname(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ta = TensorAttribute()

            def forward(self, x):
                return torch.nn.functional.linear(x, self.ta.tensor)

        wfq = WrapperForQualname()
        traced2 = symbolic_trace(wfq)
        traced2.graph.lint(traced2)
        traced2(torch.rand(4, 4))

    def test_symbolic_trace_sequential(self):
        """nn.Sequential of custom modules traces and matches eager output."""
        class Simple(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        seq = torch.nn.Sequential(
            Simple(),
            Simple(),
            Simple()
        )
        traced = symbolic_trace(seq)
        traced.graph.lint(traced)
        x = torch.rand(3, 4)
        self.assertEqual(traced(x), seq(x))

    def test_tensor_constant(self):
        """A tensor created inside forward() is usable after tracing."""
        class ConstTensor(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.linear(x, torch.zeros(3, 4))

        ct = ConstTensor()
        traced = symbolic_trace(ct)
        traced.graph.lint(traced)
        traced(torch.rand(4, 4))

    def test_pickle_graphmodule(self):
        """A GraphModule survives a pickle round trip with identical output."""
        class Nested(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.st = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.st(x)

        n = Nested()
        traced = symbolic_trace(n)
        traced.graph.lint(traced)
        # Round trip of a locally produced object only; no untrusted data involved.
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)
        loaded.graph.lint(loaded)
        x = torch.rand(3, 4)
        self.assertEqual(loaded(x), traced(x))

    def test_deepcopy_graphmodule_with_transform(self):
        """deepcopy of a transformed GraphModule yields a distinct class with equal output."""
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint(traced)

        def transform(traced):
            # Copy the graph and append a `neg` call on its output.
            new_graph = torch.fx.Graph()
            val_map : Dict[Node, Node] = {}
            output_value = new_graph.graph_copy(traced.graph, val_map)
            neg_out = new_graph.create_node(
                op='call_method', target='neg', args=(output_value,), kwargs={})
            new_graph.output(neg_out)
            return GraphModule(traced, new_graph)
        transformed = transform(traced)
        transformed.graph.lint(transformed)
        copied = copy.deepcopy(transformed)
        self.assertNotEqual(id(type(transformed)), id(type(copied)))
        x = torch.randn(3, 4)
        self.assertEqual(copied(x), transformed(x))

    def test_deepcopy_with_submods_params(self):
        """deepcopy works for a GraphModule with nested submodules and parameters."""
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))

            def forward(self, x):
                return torch.relu(x) + self.param

        class Baz(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.bar = Bar()

            def forward(self, x):
                return self.bar(x) - self.param

        baz = Baz()
        traced = symbolic_trace(baz)
        traced.graph.lint(traced)
        copied = copy.deepcopy(traced)
        copied.graph.lint(copied)

    def test_unpack_list_better_error(self):
        """`*proxy` expansion raises a TraceError with a descriptive message."""
        class SomeArgs(torch.nn.Module):
            def forward(self, a, b):
                return torch.rand(3, 4)

        class UnpacksList(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sa = SomeArgs()

            def forward(self, x : list):
                return self.sa(*x)

        ul = UnpacksList()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ul)

    def test_unpack_dict_better_error(self):
        """`**proxy` expansion raises a TraceError with a descriptive message."""
        class SomeKwargs(torch.nn.Module):
            def forward(self, x=3, y=4):
                return torch.rand(3, 4)

        class UnpacksDict(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sk = SomeKwargs()

            def forward(self, x : dict):
                return self.sk(**x)

        ud = UnpacksDict()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ud)

    def test_script_tensor_constant(self):
        """A traced module holding an anonymous tensor constant can be scripted."""
        # TorchScript seems to ignore attributes that start with `__`.
        # We used to call anonymous Tensor values `__tensor_constant*`, but
        # they were getting ignored by script. Now they're called
        # `_tensor_constant*`
        class IHaveATensorConstant(torch.nn.Module):
            def forward(self, x):
                return x + torch.rand(3, 4)

        traced = torch.fx.symbolic_trace(IHaveATensorConstant())
        torch.jit.script(traced)

    def test_torch_custom_ops(self):
        """Calls through `torch.ops.aten.*` are traceable."""
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.sigmoid(a)
                c = torch.ops.aten.cat([a, b])
                return torch.ops.aten.cat((c, c))
        m = M()
        input = torch.randn(3)
        ref_out = m(input)
        gm = symbolic_trace(m)
        gm.graph.lint(gm)
        out = gm(input)
        self.assertEqual(out, ref_out)

    def test_replace_target_nodes_with(self):
        """GraphManipulation.replace_target_nodes_with swaps add for mul."""
        class testModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = testModule()
        traced = symbolic_trace(m)
        input1 = torch.randn(1)
        input2 = torch.randn(1)
        assert (input1 + input2) == traced(input1, input2)
        GraphManipulation.replace_target_nodes_with(
            fx_module=traced,
            old_op="call_function",
            old_target=operator.add,
            new_op="call_function",
            new_target=operator.mul,
        )
        assert (input1 * input2) == traced(input1, input2)

    def test_pretty_print(self):
        """str(GraphModule) includes the module repr and the generated code."""
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint(traced)
        printed = str(traced)
        self.assertIn('GraphModuleImpl()', printed)
        self.assertIn('torch.relu', printed)

    def test_pretty_print_graph(self):
        """str(Graph) mentions args, kwargs and user counts."""
        class KwargPrintTest(torch.nn.Module):
            def forward(self, x):
                return torch.squeeze(x + 3.0, dim=2)
        st = KwargPrintTest()
        traced = symbolic_trace(st)
        traced.graph.lint(traced)
        stringed = str(traced.graph)
        for s in ['args', 'kwargs', '#users']:
            self.assertIn(s, stringed)

    def test_graph_fns(self):
        """Graph convenience constructors build a runnable graph covering each call kind."""
        g = Graph()
        a = g.placeholder('a')
        b = g.call_module('linear', (a,))
        c = g.get_attr('bias')
        d = g.call_method('add', (b, c))
        e = g.call_function(torch.sin, (d,))
        g.output(e)
        mod = torch.nn.Module()
        mod.linear = torch.nn.Linear(3, 4)
        mod.bias = torch.rand(4)
        gm = GraphModule(mod, g)
        gm.graph.lint(gm)
        input = torch.rand(3)
        r = gm(input)
        ref = torch.sin(mod.linear(input) + mod.bias)
        self.assertEqual(r, ref)

    def test_remove_uses(self):
        """Erasing a node removes it from the `users` of its inputs."""
        g : torch.fx.Graph = Graph()
        x : torch.fx.Node = g.placeholder('x')
        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
        g.output(neg)

        neg.replace_all_uses_with(relu)
        g.erase_node(neg)

        self.assertTrue(neg not in relu.users)

    def test_construct_root_dict(self):
        """GraphModule can be built from a dict mapping qualified names to values."""
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)

        linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
        add_param : torch.Tensor = torch.rand(3, 4)
        gm : torch.fx.GraphModule = torch.fx.GraphModule(
            {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
        gm.graph.lint(gm)

        self.assertIn('self.foo.bar.baz', gm.code)

        x : torch.Tensor = torch.rand(3, 3)
        out : torch.Tensor = gm(x)
        ref_out : torch.Tensor = linear_mod(x) + add_param
        self.assertEqual(out, ref_out)

    def test_symbolic_trace_assert(self):
        """torch.Assert is traceable and still fires at runtime on the traced module."""
        message = "assert_foobar"

        class AssertsTensorShape(torch.nn.Module):
            def forward(self, x):
                torch.Assert(x.shape[1] > 4, message)
                return x

        m = AssertsTensorShape()
        # verify traceability
        traced = symbolic_trace(m)
        # verify assertion on traced model works correctly at runtime
        traced(torch.rand(4, 5))
        with self.assertRaisesRegex(AssertionError, message):
            traced(torch.rand(4, 3))

    def test_copy_no_remap(self):
        """node_copy without an argument remap leaves foreign nodes that lint rejects."""
        traced = symbolic_trace(SimpleTest())
        g = traced.graph
        copied = torch.fx.Graph()
        for node in g.nodes:
            copied.node_copy(node)
        with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
            copied.lint()

    def test_wrong_topo(self):
        """lint rejects a graph in which a node is used before it is defined."""
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        nodes = list(graph.nodes)
        # Move `c` after `d`, which uses it, to break topological order.
        nodes[3].append(nodes[2])
        with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
            graph.lint()

    def test_example_shape_prop(self):
        """ShapeProp annotates nodes with shapes; the output shape matches a real run."""
        class TestCase(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.randn(3, 4)
                self.submod = torch.nn.Linear(4, 4)

            def forward(self, x):
                return torch.neg(self.submod(x.relu() + self.attr))
        tc = TestCase()
        tc_traced = symbolic_trace(tc)
        ref_out = tc_traced(torch.rand(3, 4))
        shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))

        # Make sure we're testing all opcodes
        opcodes = set()
        output_shape : Optional[torch.Size] = None
        for node in tc_traced.graph.nodes:
            opcodes.add(node.op)
            if node.op == 'output':
                output_shape = node.args[0].shape
        self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
                                   'call_module', 'output'})

        # Test shape propagation and make sure results match actual
        self.assertEqual(output_shape, ref_out.shape)

    def test_fn_type_annotations(self):
        """forward() type annotations are preserved so the traced module is scriptable."""
        class Foo(torch.nn.Module):
            def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
                return {'a': p.x + p.y + z + i}

        foo_scripted = torch.jit.script(Foo())
        foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)

        fxed = symbolic_trace(Foo())
        fxed_scripted = torch.jit.script(fxed)
        fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)

    def test_fn_type_annotation_empty(self):
        """A free function with an annotated List argument traces and scripts."""
        def forward(a : List[torch.Tensor]):
            return a[0]
        torch.jit.script(symbolic_trace(forward))

    def test_wrapped_method(self):
        """A decorator wrapping forward() is traced through (its relu shows up in the graph)."""
        def wrap_with_relu(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                return torch.relu(fn(*args, **kwargs))
            return wrapper

        class Foo(torch.nn.Module):
            @wrap_with_relu
            def forward(self, x, w):
                return torch.matmul(x, w)

        f = Foo()
        traced = symbolic_trace(f)
        self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))

    def test_sequential(self):
        """A traced nn.Sequential can be deep-copied."""
        m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
        gm = torch.fx.symbolic_trace(m)
        copy.deepcopy(gm)

    def test_ctx_mgr(self):
        """forward() decorated with a context manager is traceable."""
        @contextlib.contextmanager
        def do_nothing():
            yield

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            @do_nothing()
            def forward(self, x):
                return torch.relu(x)

        m = M()
        self.checkGraphModule(m, (torch.rand(3, 4),))

    def test_typename_print(self):
        """A node's type_expr is rendered in the graph's string form."""
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
                                              type_expr=List[float])
        graph.output(b)
        self.assertTrue('typing.List[float]' in str(graph))

    def test_inf_nan(self):
        """inf / -inf / nan float constants survive tracing and code generation."""
        class FooMod(torch.nn.Module):
            def forward(self, x):
                return x + float('inf'), x + float('-inf'), x + float('nan')

        fm = FooMod()
        self.checkGraphModule(fm, (torch.rand(3, 4),))

    def test_inf_nan_kwds(self):
        """Nodes explicitly named `inf` and `nan` do not clash with the float constants."""
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf')
        c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan')
        graph.output((b, c))

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        x = torch.rand(3, 4)
        self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))

    def test_subgraph_creation(self):
        """split_module output matches the unsplit traced module."""
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x, y):
                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                w = self.linear(y).clamp(min=0.0, max=1.0)
                return z + w

        # symbolically trace model
        my_module = MyModule()
        my_module_traced = symbolic_trace(my_module)

        # round-robin partitioning: each call returns the next partition index
        partition_counter = 0
        NPARTITIONS = 3

        def mod_partition(node: Node):
            nonlocal partition_counter
            partition = partition_counter % NPARTITIONS
            partition_counter = (partition_counter + 1) % NPARTITIONS
            return partition

        # split module in module with submodules
        module_with_submodules = split_module(my_module_traced, my_module, mod_partition)

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        orig_out = my_module_traced(x, y)
        submodules_out = module_with_submodules(x, y)

        self.assertEqual(orig_out, submodules_out)

    def test_deepcopy_recursion_depth(self):
        """deepcopy of a graph deeper than the recursion limit works and preserves users."""
        depth = sys.getrecursionlimit() + 20

        g = torch.fx.Graph()
        x = g.placeholder('x')
        for _ in range(depth):
            x = g.call_function(torch.relu, (x,))
        g.output(x)

        copied_graph = copy.deepcopy(g)

        # Map each original node to its copy (same position in the node list).
        val_map = dict(zip(g.nodes, copied_graph.nodes))

        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
            orig_users = set(orig_node.users.keys())
            orig_users_equiv = {val_map[u] for u in orig_users}
            new_users = set(new_node.users.keys())
            self.assertEqual(orig_users_equiv, new_users)

    @skipIfNoTorchVision
    def test_replace_uses(self):
        """Replace every relu call in a traced resnet18 with neg, then erase the originals."""
        rn18 = resnet18()

        class LowerReluTracer(torch.fx.Tracer):
            def is_leaf_module(self, m : torch.nn.Module, qualname : str):
                # Trace into ReLU modules so they appear as function calls.
                if isinstance(m, torch.nn.ReLU):
                    return False
                return super().is_leaf_module(m, qualname)

        rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))

        to_erase = []
        for node in rn18_traced.graph.nodes:
            if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
                kwargs = dict(node.kwargs)
                # Neg doesn't have in-place
                kwargs.pop('inplace', None)
                with rn18_traced.graph.inserting_before(node):
                    # Pass the filtered copy; the node's own kwargs still carry `inplace`.
                    new_node = rn18_traced.graph.call_function(
                        the_function=torch.neg, args=node.args, kwargs=kwargs)
                node.replace_all_uses_with(replace_with=new_node)
                to_erase.append(node)

        for node in to_erase:
            rn18_traced.graph.erase_node(node)

    def test_insertion_point(self):
        """Nodes created under inserting_before() land before the given node."""
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        graph.output(b)

        with graph.inserting_before(b):
            neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
            _, *relu_args = b.args
            b.args = (neg, *relu_args)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input = torch.randn(33, 44)
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))

    def test_move_before(self):
        """Node.prepend moves an already-created node in front of another node."""
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        graph.output(b)

        neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
        _, *relu_args = b.args
        b.args = (neg, *relu_args)
        b.prepend(neg)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input = torch.randn(33, 44)
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))

    def test_erase_node_error(self):
        """Erasing a node that still has users raises a RuntimeError."""
        st = SimpleTest()
        traced = symbolic_trace(st)

        for node in traced.graph.nodes:
            # Test deleting with uses both in another Node and at the output
            if node.target in [operator.add, torch.relu]:
                with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
                    traced.graph.erase_node(node)

    def test_copy_it(self):
        """immutable_dict / immutable_list compare equal to their deep copies."""
        d = immutable_dict([(3, 4), (5, 6)])
        lst = immutable_list([(3, 4), (5, 6)])

        self.assertEqual(d, deepcopy(d))
        self.assertEqual(lst, deepcopy(lst))

    def test_find_uses(self):
        """Node.users lists every node that consumes the node."""
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))

        y = torch.relu(x)
        z = x + x
        u = torch.neg(x)
        graph.output((y + z + u).node)
        graph.lint()

        users_of_x = x.node.users
        self.assertEqual(len(users_of_x), 3)
        expected_ops = {'relu', 'add', 'neg'}
        for use in users_of_x:
            self.assertTrue(any(use.name.startswith(prefix) for prefix in expected_ops))

    def test_inline_graph(self):
        """graph_copy with a val_map can splice one graph onto another's output."""
        class InlineInto(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        class ToInline(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        inline_into = symbolic_trace(InlineInto())
        to_inline = symbolic_trace(ToInline())

        combined_graph = torch.fx.Graph()
        output_node = combined_graph.graph_copy(inline_into.graph, {})

        input_node = list(to_inline.graph.nodes)[0]
        assert input_node and input_node.op == 'placeholder'

        # Map the inlined graph's placeholder onto the first graph's output value.
        val_map = {input_node : output_node}
        output = combined_graph.graph_copy(to_inline.graph, val_map)
        combined_graph.output(output)

        combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)

        input = torch.rand(3, 4)
        self.assertEqual(combined_module(input), input.relu().neg())

    def test_multi_insert_point(self):
        """Several nodes created under one inserting_before() keep their creation order."""
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))
        relu = torch.relu(x)

        with graph.inserting_before(relu.node):
            y = torch.neg(x)
            z = torch.tanh(y)

        graph.output((relu.node, z.node))
        graph.lint()

        expected_ops = ['x', 'neg', 'tanh', 'relu']
        for node, expected in zip(graph.nodes, expected_ops):
            self.assertIn(expected, node.name)

    def test_reassign_args_kwargs_uses(self):
        """Reassigning Node.args keeps the `users` bookkeeping of inputs up to date."""
        graph = torch.fx.Graph()
        x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
        z = x + y
        zed = z + z + z
        graph.output(zed.node)
        graph.lint()

        # zed = z + z + z -> zed = z + z + x
        zed.node.args = (zed.node.args[0], x.node)
        self.assertEqual(list(x.node.users.keys()), [z.node, zed.node])

        # z = x + y -> z = y + y
        z.node.args = (y.node, y.node)
        self.assertEqual(list(x.node.users.keys()), [zed.node])

    def test_trace_function(self):
        """A free function (not an nn.Module) can be symbolically traced."""
        def foo(x, y):
            return torch.relu(x) + y
        x, y = torch.randn(3, 4), torch.randn(3, 4)
        self.checkGraphModule(foo, (x, y))

    def test_direct_param_use(self):
        """Parameters reached through a submodule are not turned into constant nodes."""
        class TransposeTest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.b = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.b

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = TransposeTest()

            def forward(self, x):
                return self.a.b, self.a.b.t(), self.a.b.view(12)

        traced = torch.fx.symbolic_trace(Foo())
        assert all('constant' not in node.target for node in traced.graph.nodes)

    def test_single_default_arg(self):
        """forward() with a single defaulted argument traces with and without it."""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, y=1):
                return y

        m = M()
        self.checkGraphModule(m, ())
        self.checkGraphModule(m, (3,))

    def test_multiple_default_args(self):
        """forward() with several defaulted arguments traces for each arity."""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, y=1, z=2):
                return y + z

        m = M()
        self.checkGraphModule(m, ())
        self.checkGraphModule(m, (3,))
        self.checkGraphModule(m, (3, 4))

    def test_regular_and_default_args(self):
        """forward() mixing required and defaulted arguments traces correctly."""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y=1):
                return x + y

        m = M()
        self.checkGraphModule(m, (2,))
        self.checkGraphModule(m, (2, 3))

    def test_string_literal_return(self):
        """forward() returning a string literal round-trips through tracing."""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self):
                return "foo"

        m = M()
        self.checkGraphModule(m, ())


if __name__ == '__main__':
    run_tests()