Commit Graph

76 Commits

PyTorch MergeBot
4035a53cca Revert "Recursively print graph module and its submodule (#81080)"
This reverts commit fe7262329c.

Reverted https://github.com/pytorch/pytorch/pull/81080 on behalf of https://github.com/DanilBaibak due to Break internal build
2022-07-18 14:46:26 +00:00
Sherlock Huang
fe7262329c Recursively print graph module and its submodule (#81080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080
Approved by: https://github.com/ezyang
2022-07-18 01:19:03 +00:00
Horace He
e7e835e50a Fix to folder by adding custom_builtins to dump (#81433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81433
Approved by: https://github.com/jamesr66a
2022-07-14 21:39:13 +00:00
PyTorch MergeBot
58532256e9 Revert "Add __all__ for torch.distributed and fx modules (#80460)"
This reverts commit 5d40c3d5c8.

Reverted https://github.com/pytorch/pytorch/pull/80460 on behalf of https://github.com/malfet due to Broke MacOS testing, see https://github.com/pytorch/pytorch/runs/7105579664?check_suite_focus=true
2022-06-29 16:20:55 +00:00
anjali411
5d40c3d5c8 Add __all__ for torch.distributed and fx modules (#80460)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80460
Approved by: https://github.com/albanD, https://github.com/rohan-varma
2022-06-29 02:53:56 +00:00
James Reed
7311390d35 [WIP] Make constructor calls in experimental MetaTracer serializable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76789

Approved by: https://github.com/pbelevich
2022-05-11 00:19:47 +00:00
Michael Suo
fb0f285638 [lint] upgrade mypy to latest version
Fixes https://github.com/pytorch/pytorch/issues/75927.

Had to fix some bugs and add some ignores.

To check if clean:
```
lintrunner --paths-cmd='git grep -Il .' --take MYPY,MYPYSTRICT
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76753

Approved by: https://github.com/malfet
2022-05-03 20:51:34 +00:00
PyTorch MergeBot
3d7428d9ac Revert "[lint] upgrade mypy to latest version"
This reverts commit 9bf18aab94.

Reverted https://github.com/pytorch/pytorch/pull/76753 on behalf of https://github.com/suo
2022-05-03 20:01:18 +00:00
Michael Suo
9bf18aab94 [lint] upgrade mypy to latest version
Fixes https://github.com/pytorch/pytorch/issues/75927.

Had to fix some bugs and add some ignores.

To check if clean:
```
lintrunner --paths-cmd='git grep -Il .' --take MYPY,MYPYSTRICT
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76753

Approved by: https://github.com/malfet
2022-05-03 19:43:28 +00:00
James Reed
bf730e5039 Fix unnecessary recursion in GraphModule.__call__
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76068

Approved by: https://github.com/Chillee
2022-04-21 03:25:39 +00:00
Pavel Belevich
96c8f64459 Remove with_traceback(None) in wrapped_call to show the root cause error
Before:
```
Traceback (most recent call last):
  File "/Users/pbelevich/PycharmProjects/PiPPy/test/t5_test.py", line 37, in <module>
    t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 251, in forward
    return self.executor.run(*executor_args)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 155, in run
    return super().run(*args, initial_env=initial_env)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 121, in run
    self.env[node] = self.run_node(node)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 148, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 170, in call_module
    return super().call_module(target, args, kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 265, in call_module
    return submod(*args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e.with_traceback(None)
AttributeError: 'NoneType' object has no attribute 'dtype'
```
After:
```
Traceback (most recent call last):
  File "/Users/pbelevich/PycharmProjects/PiPPy/test/t5_test.py", line 37, in <module>
    t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 251, in forward
    return self.executor.run(*executor_args)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 155, in run
    return super().run(*args, initial_env=initial_env)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 121, in run
    self.env[node] = self.run_node(node)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 148, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 170, in call_module
    return super().call_module(target, args, kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 265, in call_module
    return submod(*args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 622, in wrapped_call
    return super(cls, self).__call__(*args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "<eval_with_key>.42", line 74, in forward
  File "/Users/pbelevich/PycharmProjects/pbelevich-transformers/src/transformers/utils/fx.py", line 180, in wrapper
    return func(*args, **kwargs)
  File "/Users/pbelevich/PycharmProjects/pbelevich-transformers/src/transformers/modeling_utils.py", line 256, in create_extended_attention_mask_for_decoder
    causal_mask = causal_mask.to(attention_mask.dtype)
AttributeError: 'NoneType' object has no attribute 'dtype'
```

The last lines of the stack trace now show where the problem actually is.
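
A minimal standalone sketch (not the PyTorch code itself) of the difference: `with_traceback(None)` discards the frames that identify the root cause, while a bare re-raise preserves them.
```
def inner():
    return None.dtype  # AttributeError: 'NoneType' object has no attribute 'dtype'

def call_stripped():
    try:
        inner()
    except Exception as e:
        # Traceback now begins at this raise; inner()'s frame is lost.
        raise e.with_traceback(None)

def call_preserved():
    try:
        inner()
    except Exception as e:
        # Original frames, including inner(), remain visible.
        raise e
```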
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74655
Approved by: https://github.com/ansley, https://github.com/rohan-varma
2022-03-25 14:40:45 +00:00
Animesh Jain
7ebab9247d FX graph module - prevent infinite recursion (#73866)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73866

super(type(self), self) in wrapped_call leads to infinite recursion for a subclass of an FX graph module. This happens when we call _stateless.functional_call on an FX module: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/_stateless.py
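
To make the failure mode concrete, a self-contained sketch (assumed class names, not the torch.fx source): `type(self)` is always the most-derived class, so a subclass instance resolves `super(type(self), self).__call__` back to the same method and never reaches the real base class.
```
class Module:
    def __call__(self):
        return "module result"

class GraphModule(Module):
    def __call__(self):
        return super(type(self), self).__call__()

class SubGraphModule(GraphModule):
    pass

print(GraphModule()())  # "module result": super(GraphModule, ...) -> Module
SubGraphModule()()      # RecursionError: super(SubGraphModule, ...) -> GraphModule
```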

Test Plan:
Tests added in https://github.com/pytorch/pytorch/pull/62436

Imported from OSS

Reviewed By: jansel

Differential Revision: D34737828

fbshipit-source-id: 871b897e1210173ccc83fe34d53fc41af00a39ee
(cherry picked from commit 3d0c5fc71503fa2782b497a9d39ce26288fd219b)
2022-03-09 06:09:57 +00:00
Horace He
d635d0f86e Refactor FX codegen into extensible Codegen object (#72566)
Summary:
The goal of this is to make FX's codegen extensible. I've refactored it into a class with 5 extensibility points on it.

```
class CodeGen(object):
    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        """
        Given the free variables and a return annotation, generates the beginning of the FX function.
        By default, `generate_prologue(['a', 'b'], '') == 'def forward(a, b):'`
        """
    def generate_output(self, output_args: Argument) -> str:
        """
        Given the output arguments, generates the return statement of the FX function.
        """
    def process_inputs(self, args: Any) -> Any:
        """
        Transforms the inputs so that the graph can take them as arguments, as
        non-default codegen may result in the inputs to the function being
        different from the inputs to the graph.

        If the graph was directly runnable, this invariant should hold true
        `f.process_outputs(f.graph(*f.process_inputs(*inputs))) == f(*inputs)`
        """
    def process_outputs(self, outputs: Any) -> Any:
        """
        Transforms the outputs of the graph to be identical to the codegen.

        See ``process_inputs`` for more details.
        """
    def additional_globals(self) -> List[Tuple[str, Any]]:
        """
        If your codegen uses extra global values, add them here.
        For example, return [('List', typing.List)] if you need ``List`` in the global context.
        """
```

So, for example, the `ListCodeGen` we want for AOTAutograd looks like this
```
        class ListCodeGen(CodeGen):
            def generate_prologue(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert(len(inputs) == 1)
                return inputs[0]
```
and
```
        def f(a, b):
            return a + b

        nf = fx.symbolic_trace(f)
        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        print(nf.code)
```
would result in
```
def forward(self, args_list: List[torch.Tensor]):
    a, b = args_list
    add = a + b;  a = b = None
    return add
```

Backwards compatibility changes - I added `process_outputs` and `process_inputs` to `fx.Graph`, while removing `flatten_inputs` and `flatten_outputs` - those didn't have `backwards_compatibility` on them, so I *think* it's probably fine?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72566

Reviewed By: desertfire

Differential Revision: D34160424

Pulled By: Chillee

fbshipit-source-id: ebf6411312b373e3fbcb13288a34befa449a2375
(cherry picked from commit 13cd12eaa1)
2022-02-11 18:13:29 +00:00
Horace He
df6eb9bbab Fixed to_folder not saving dtype (#69983)
Summary:
As above.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69983

Reviewed By: pbelevich, ngimel

Differential Revision: D33466529

Pulled By: Chillee

fbshipit-source-id: 2d2f0ad5b8e2492aba4c19fa034c8b6c0848a568
2022-01-06 22:15:56 -08:00
James Reed
3eb9443619 [FX] Fix issue where GraphModule.delete_all_unused_submodules deletes submodules from called leaf modules (#66430)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66430

On the whole, I'm not totally satisfied with this approach. I think we should be building a prefix tree data structure during initial iteration over the submodules and querying that when deleting submodules. But I think this approach works and I want to see if we can get it in before 1.10
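
For illustration, a hedged sketch of the prefix-tree idea floated here (not the code that landed): record every used qualified name in a trie, then keep any submodule whose path lies on the way to a used leaf.
```
class Trie:
    def __init__(self):
        self.children = {}

    def insert(self, qualname: str):
        node = self
        for atom in qualname.split("."):
            node = node.children.setdefault(atom, Trie())

    def on_used_path(self, qualname: str) -> bool:
        node = self
        for atom in qualname.split("."):
            if atom not in node.children:
                return False
            node = node.children[atom]
        return True

used = Trie()
used.insert("encoder.layers.0.linear")
assert used.on_used_path("encoder.layers.0")      # ancestor of a used module: keep
assert not used.on_used_path("decoder.layers.1")  # never used: safe to delete
```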

Test Plan: Imported from OSS

Reviewed By: Chillee

Differential Revision: D31546137

Pulled By: jamesr66a

fbshipit-source-id: f08b8409a3cf511277017ccccb916097b7c4c4fe
2021-10-11 19:37:51 -07:00
James Reed
0c4e4e588e [FX] Rename reduce functions back to their old, public names (#64324)
Summary:
Unfortunately, pickle serializes these functions by name, so the earlier rename broke loading of existing pickles. Also put them under backward-compatibility enforcement.
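
A minimal illustration (hypothetical function name) of why a rename breaks things: pickle stores functions by qualified name, not by code, so every saved object refers to these helpers by their original public names.
```
import pickle

def reduce_graph_module(body):
    return body  # placeholder standing in for the real helper

payload = pickle.dumps(reduce_graph_module)
# Unpickling performs a lookup equivalent to
# getattr(sys.modules[__name__], "reduce_graph_module"); if the helper had
# been renamed to _reduce_graph_module, that lookup would raise
# AttributeError and old pickles would fail to load.
assert pickle.loads(payload) is reduce_graph_module
```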

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64324

Test Plan: Local repro https://fb.workplace.com/groups/3440841732711443/permalink/4018921611570116/

Reviewed By: SplitInfinity, TailofJune

Differential Revision: D30684185

Pulled By: jamesr66a

fbshipit-source-id: 900701220155d15115cd0c07cf7774a2891bd04f
2021-08-31 22:36:11 -07:00
Jay Leverett
44fcb00a56 Fix redundant class definition in GraphModule singleton constructor (#64274)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63883

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64274

Reviewed By: jamesr66a

Differential Revision: D30675970

Pulled By: jayleverett

fbshipit-source-id: e74ef2a28013f0fa7c58d14f38e66cfe48d26b74
2021-08-31 17:34:14 -07:00
James Reed
538647fe1f [WIP][FX] BC guarantees for 1.10 (#63888)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63888

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D30523133

Pulled By: jamesr66a

fbshipit-source-id: b04cc0d842a74862f42ecba98b757310cd2ec7b0
2021-08-30 19:56:46 -07:00
James Reed
4e37a015c7 [FX] Fix _replicate_for_data_parallel (#63821)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63821

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D30502115

Pulled By: jamesr66a

fbshipit-source-id: 0f004f95def6e1ba21ccbeab40cb0a739a0ad20c
2021-08-24 13:48:15 -07:00
David Esiobu
79693bb86a Use linecache.lazycache to cache generated code. (#63453)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63453

Instead of patching linecache.getlines, use linecache.lazycache and
parts of the loader protocol described in PEP-302
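
A hedged sketch of the mechanism (illustrative names, not the exact torch.fx code): register a loader whose `get_source()` returns the generated source, so tracebacks through `exec`'d code can resolve real lines on demand.
```
import linecache

class GeneratedCodeLoader:
    def __init__(self, src: str):
        self.src = src

    def get_source(self, module_name) -> str:  # PEP-302 loader protocol
        return self.src

src = "def forward(x):\n    return x + 1\n"
filename = "<eval_with_key>.0"
globals_ns = {"__name__": filename, "__loader__": GeneratedCodeLoader(src)}
linecache.lazycache(filename, globals_ns)  # defers to the loader when asked
exec(compile(src, filename, "exec"), globals_ns)
print(linecache.getline(filename, 2))  # "    return x + 1"
```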

Test Plan:
python3 test/test_fx.py

Imported from OSS

Reviewed By: suo

Differential Revision: D30388176

fbshipit-source-id: 92933711ecf3a21a07e1d6b0d1185ab0efd8341c
2021-08-19 09:17:01 -07:00
James Reed
d661e646ad [FX] Fix GraphModule deepcopy to use deepcopied graph (#63090)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63090

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D30252471

Pulled By: jamesr66a

fbshipit-source-id: cafd7d7917935a5ea6ffa2a7fe9e9b2a9578b3e3
2021-08-18 13:17:14 -07:00
Bradley Davis
7a1ab9f5d7 [fx] store Tracer class on Graph and GraphModule for package deserialization [v2, the re-do] (#63121)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63121

Re-introducing this diff with a small change: skip setting Tracer classes on GraphModules when the Tracer class is not defined at module level (which prevents pickling).

Previous, reverted Pull Request: https://github.com/pytorch/pytorch/pull/62497

Reviewed By: houseroad

Differential Revision: D30252776

fbshipit-source-id: 42d2bc846e4b32d00563419c38c02b63cd0986e6
2021-08-12 17:28:50 -07:00
Lu Fang
847a7cfa10 Back out "[fx] store Tracer class on Graph and GraphModule for package deserialization" (#63053)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63053

Original commit changeset: eca09424ad30

The original diff - D30019214 (6286d33878) breaks the publish flow in model saving.

Test Plan: ci

Differential Revision: D30236517

fbshipit-source-id: 3e05db02fc1cbbc2ed262c83bf56d555277abb34
2021-08-10 21:58:08 -07:00
Bradley Davis
6286d33878 [fx] store Tracer class on Graph and GraphModule for package deserialization (#62497)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62497

Previously named: add support for custom tracer in __reduce_package__

Stores a Tracer class on a Graph created by a Tracer, and copies the Tracer class into the GraphModule's state so that when a GraphModule is packaged by torch.package, it can be reconstructed with the same Tracer and GraphModule class name.

Reviewed By: suo

Differential Revision: D30019214

fbshipit-source-id: eca09424ad30feb93524d481268b066ea55b892a
2021-08-09 13:07:30 -07:00
Bradley Davis
093495d3f0 [fx] prevent implicit submodule inlining when submodule is a GraphModule (#62436)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62436

## Problem

Given two modules and a tracer that indiscriminately marks all modules as leaves:
```
class InnerModule(torch.nn.Module):

    def forward(self, t):
        return t + t

class MyModule(torch.nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, t):
        x = self.inner(t)
        y = self.inner(t)
        return x + y

class MyTracer(torch.fx.Tracer):
    def is_leaf_module(self, module, name):
        return True
```

One might generally expect the following behavior (note call_module nodes):
```
print(">> Outer GraphModule (with inner module as nn.Module):")
inner = InnerModule()
m = MyModule(inner)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())

>> Outer GraphModule (with inner module as nn.Module):
opcode         name     target                   args              kwargs
-------------  -------  -----------------------  ----------------  --------
placeholder    t        t                        ()                {}
call_module    inner    inner                    (t,)              {}
call_module    inner_1  inner                    (t,)              {}
call_function  add      <built-in function add>  (inner, inner_1)  {}
output         output   output                   (add,)            {}
None
```

However, when the inner module is first symbolically traced, the symbolic trace of the outer module ignores `is_leaf_module` entirely and traces through the whole module (note the call_function nodes).
```
print(">> Inner module as GraphModule:")
inner = InnerModule()
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
print(inner_gm.graph.print_tabular())

print(">> Outer GraphModule (with inner module as GraphModule):")
m = MyModule(inner_gm)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())

>> Inner module as GraphModule:
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    t       t                        ()      {}
call_function  add     <built-in function add>  (t, t)  {}
output         output  output                   (add,)  {}
None

>> Outer GraphModule (with inner module as GraphModule):
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    t       t                        ()            {}
call_function  add     <built-in function add>  (t, t)        {}
call_function  add_1   <built-in function add>  (t, t)        {}
call_function  add_2   <built-in function add>  (add, add_1)  {}
output         output  output                   (add_2,)      {}
None
```

This is surprising behavior and at first glance violates the tracer's intent. As I understand it, `torch.fx.symbolic_trace.Tracer.trace` intends to patch `torch.nn.Module.__call__` with a `module_call_wrapper()` that records a `call_module` node if the module is a leaf, and otherwise executes `torch.fx.symbolic_trace._orig_module_call` (a reference to `torch.nn.Module.__call__` captured at module load time).

**Every submodule should be a leaf, but no `call_module` nodes are created when that submodule is a `GraphModule`. Why?**

Upon further inspection, I found:

- The constructor for GraphModule includes a path to `GraphModule.recompile()` via the setter for an `fx.Graph`:
```
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))

File "/torch/fx/graph_module.py", line 252, in __init__
self.graph = graph

File "/torch/nn/modules/module.py", line 1183, in __setattr__
object.__setattr__(self, name, value)

File "/torch/fx/graph_module.py", line 277, in graph
self.recompile()
```
- `recompile()` wraps the `__call__` method by holding a reference to the `__call__` method at the time of recompilation:
```
cls = type(self)
cls_call = cls.__call__
...
def wrapped_call(self, *args, **kwargs):
    try:
        return cls_call(self, *args, **kwargs)
    except Exception as e:
        ...
cls.__call__ = wrapped_call
```
- Recompilation of the inner GraphModule happens on initialization, before creation or tracing of the outer module. Adding some old-fashioned print debug statements gives:
```
Inner Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
recompile: cls.__call__ now wraps _orig_module_call, <function Module._call_impl at 0x7faaebfee8b0>

Outer Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
tracing: patching method <class 'torch.nn.modules.module.Module'>.__call__ <function Module._call_impl at 0x7faaebfee8b0> with <function Module._call_impl at 0x7fa9d42bce50>

outer module MRO before tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>

outer module MRO during tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>

inner module MRO before tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>

inner module MRO during tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
```

- The outer module is patched correctly, but the inner module's first element in its MRO is the `wrapped_call` from `recompile` that still invokes `<function Module._call_impl at 0x7faaebfee8b0>` directly. Therefore, no call_module nodes are created.
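
To see the capture problem in isolation, a self-contained sketch (assumed names, not the torch.fx source): capturing a bound call target at recompile() time freezes the pre-patch `__call__`, so later monkey-patching by the tracer is never observed by the wrapped class.
```
class Module:
    def __call__(self):
        return "original"

class GraphModuleImpl(Module):
    pass

# recompile(): capture __call__ as it exists right now
cls_call = GraphModuleImpl.__call__

def wrapped_call(self):
    return cls_call(self)

GraphModuleImpl.__call__ = wrapped_call

# the tracer later patches the base class...
Module.__call__ = lambda self: "patched (records call_module)"

print(GraphModuleImpl()())  # "original" -- the tracer's patch is bypassed
```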

## In Practice

In practice, this behavior affects the ability of `torch.package` to package `GraphModules` whose submodules are `GraphModules`. In our case, the `GraphModule` submodules are not passed through a constructor, but created separately and installed on the root `GraphModule` via `setattr`. This means that prior to packaging, there appear to be no issues with the module, since the root's graph was created before any call_module targets were replaced with `GraphModules`.

When unpackaging such a model with `torch.package`, `torch.fx.graph_module._deserialize_graph_module` uses an inline `KeepModules` tracer that sets all submodules to leaves; the unpackaged module is implicitly and surprisingly inlined in the process.

## Potential Solution

This behavior was previously not understood by us, and so the current workaround is a gnarly process of wrapping all submodules in an `nn.Module` with a manually installed forward method.

Changing `wrapped_call` to `return super(type(self), self).__call__(*args, **kwargs)` whenever `__call__` is inherited at least appears to solve the issue. Does this seem like an acceptable approach?

## Other Thoughts
- Repeated calls to `recompile` create nested `wrapped_call`s, all for the purpose of error handling. This seems unnecessary ¯\\_(ツ)\_/¯
- If a root module with an overridden `__call__` method is symbolically traced, the override is ignored

Test Plan:
```
buck test:
    ✓ ListingSuccess: caffe2/test:fx - main (12.570)
    ✓ Pass: caffe2/test:fx - test_tracing_graphmodules_as_leaf_submodules (test_fx.TestFX) (11.982)
```

Reviewed By: ansley

Differential Revision: D29997935

fbshipit-source-id: 1988fbb025b14188da26a3e73e94fb789c3c1f74
2021-08-02 13:37:08 -07:00
Harut Movsisyan
810e19979d Torch deploy for fx.graph_module with non-torch dependencies (#61680)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61680

This diff enables torch deploy for fx.graph_module with non-torch dependencies. Here are the issues currently preventing this, which are fixed in this change:
- Pickle is used as an internal format to transmit objects between interpreters. It needs to serialize Python code, but to get the source code for imports from python_code.globals it needs access to the PackageImporter. Currently a regular `__reduce__` function is used, which has no notion of a custom importer.
- When deserializing pickled objects on an interpreter, empty globals are passed to exec, so exec cannot resolve non-torch imports located in the package. We need to be able to point exec to our custom PackageImporter.
- Subclasses extending fx.graph_module should be able to optionally provide their own Tracer (extending fx.Tracer).

As a solution, a new reducer (`__reduce_deploy__`) is introduced for the torch deploy workflow. The reducer is registered in _deploy.py (the entry point for the C++ torch deploy API) when saving an object to transmit it between interpreters. It allows us to pass a proper PackageImporter for each interpreter when pickling/unpickling fx.graph_module, and it defines an API for passing a custom fx.Tracer when needed.
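
A hedged sketch of the mechanism this relies on (hypothetical names, not the deploy code): a per-`Pickler` `dispatch_table` lets one serialization context install a reducer that knows which importer to use, without changing the class globally.
```
import io
import pickle

class FakeGraphModule:
    def __init__(self, code="def forward(x): return x"):
        self.code = code

def make_reducer(importer_tag):
    def reduce_deploy(gm):
        # rebuild through a callable that carries the importer identity
        return (rebuild_with_importer, (importer_tag, gm.code))
    return reduce_deploy

def rebuild_with_importer(importer_tag, code):
    print(f"rebuilding with importer {importer_tag!r}")
    return FakeGraphModule(code)

buf = io.BytesIO()
pickler = pickle.Pickler(buf)
pickler.dispatch_table = {FakeGraphModule: make_reducer("interp-0")}
pickler.dump(FakeGraphModule())
gm2 = pickle.loads(buf.getvalue())  # prints: rebuilding with importer 'interp-0'
```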

Test Plan:
Added UT to cover changes.
```
buck test //caffe2/torch/csrc/deploy:test_deploy
```
```
buck test caffe2/test:fx
```

Reviewed By: suo

Differential Revision: D29690088

fbshipit-source-id: 3a8dbe02d5d7e085534aa61b7773c86f0f8c19b0
2021-07-21 10:29:48 -07:00
Nikita Shulga
4e94e84f65 Type annotate torch.nn.Module ctor (#61334)
Summary:
Annotate generic types
Fix some type violations
Override `_modules` and `_parameters` in `Sequential`, `ModuleList`, `ModuleDict`, etc

Fixes https://github.com/pytorch/pytorch/issues/45497

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61334

Reviewed By: albanD

Differential Revision: D29579533

Pulled By: malfet

fbshipit-source-id: 5cd8ca918b260ca35cfdd873dee8851d39d17de2
2021-07-16 13:59:06 -07:00
Ansley Ussery
5268b5a29a Add parsing logic for Tuple[()] annotation (#58340)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58340

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D28459502

Pulled By: ansley

fbshipit-source-id: 4bb188448d66269b42b068858b895debac86e9ee
2021-05-25 12:12:43 -07:00
James Reed
7b73fdf597 [FX] Fix retracing wrapped functions (#58061)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58061

Test Plan: Imported from OSS

Reviewed By: yuhc

Differential Revision: D28358801

Pulled By: jamesr66a

fbshipit-source-id: c7c9a8a80e5bfe1eb1f6d2cf858ac7e57153a860
2021-05-17 19:50:16 -07:00
Horace He
8d363d37da [FX] Adds PyTree support to FX through concrete_args (#55888)
Summary:
```
class Foo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, x):
        for k in x:
            for v in x[k]:
                v += y
        return x

example_dict = {'x': {'a': [fx.HOLE], 'z': [fx.HOLE, fx.HOLE]}}
new_f = fx.symbolic_trace(Foo(), concrete_args=example_dict)
print(new_f.code)
new_f(torch.randn(5), {'x': {'a': [torch.randn(5)], 'z': [torch.randn(5), torch.randn(5)]}})

fx.symbolic_trace(new_f, concrete_args=example_dict)
```

prints out
```
def forward(self, y, x):
    y, tree_2, tree_3, tree_4 = pytree.tree_flatten([y, x])[0]
    add = tree_2 + y
    add_1 = tree_3 + y
    add_2 = tree_4 + y;  y = None
    return {'a': [tree_2], 'z': [tree_3, tree_4]}
```

Currently, I store `in_spec` as an extra attribute on `fx.Graph`, and then include it when we do the codegen. I'm not sure if this is the right approach - it introduces a divergence between what's in `fx.Graph` and what's in the python code.

Perhaps the best API is something explicit like `fx.Graph.flatten_args`, but that does make calling things a bit ... more verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55888

Reviewed By: jamesr66a

Differential Revision: D27884694

Pulled By: Chillee

fbshipit-source-id: f9e8a70c63a8df63c9f9bd0a6459255daa5a8df8
2021-05-07 04:48:35 -07:00
Sam Estep
75024e228c Add lint for unqualified type: ignore (#56290)
Summary:
The other half of https://github.com/pytorch/pytorch/issues/56272.
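
For context, a minimal before/after of what the new lint accepts (hypothetical snippet):
```
# Runs fine at runtime; mypy flags the assignment either way.
n: int = "not an int"  # type: ignore[assignment]
# Unqualified form, now rejected by the lint:
# n: int = "not an int"  # type: ignore
print(n)
```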

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56290

Test Plan:
CI should pass on the tip of this PR, and we know that the lint works because the following CI runs (before this PR was finished) failed:

- https://github.com/pytorch/pytorch/runs/2384511062
- https://github.com/pytorch/pytorch/actions/runs/765036024

Reviewed By: seemethere

Differential Revision: D27867219

Pulled By: samestep

fbshipit-source-id: e648f07b6822867e70833e23ddafe7fb7eaca235
2021-04-21 08:07:23 -07:00
Nikita Shulga
add49e7e4e Enforce PEP263 for PyTorch python codebase (#55346)
Summary:
All python files containing non-ASCII characters should be correctly annotated with `# -*- coding: utf-8 -*-` comment

Delete a number of superfluous UTF-8 characters, most commonly the UTF-8 closing quotation mark U+2019 (’) used instead of the ASCII apostrophe ', for example `Module’s` → `Module's`
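
For reference, the required declaration is the PEP 263 encoding comment on the first or second line of the file (illustrative snippet):
```
# -*- coding: utf-8 -*-
# PEP 263: declare the file encoding when the source contains
# non-ASCII characters.
note = "Module's parameters"  # ASCII apostrophe, not U+2019
print(note)
```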

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55346

Reviewed By: samestep

Differential Revision: D27582044

Pulled By: malfet

fbshipit-source-id: c1cd89655915858ff3a41f675cdfffff795a8e44
2021-04-06 18:31:38 -07:00
Shiyan Deng
15b087cdd2 [fx]Allow rewrite a symbolic traced module (#54011)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54011

After symbolic tracing, `fn` seems to already have "forward" in its globals. In this case, `new_keys` would have length of 0 and we take "forward" from `global_dict` directly as `fn_compiled`.

Test Plan: Added a new test in test_fx_experimental.

Reviewed By: ansley

Differential Revision: D27049012

fbshipit-source-id: 7fbeb50ebb717900ff5fc0a8a0925d6a97f5a6dd
2021-04-05 18:35:51 -07:00
James Reed
05a03a6c8c [FX][EZ] Fix type correctness on GraphModule.graph (#54305)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54305

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D27181176

Pulled By: jamesr66a

fbshipit-source-id: ed91cfed193984249c07a5bafc7aa732bfe0194d
2021-03-19 11:48:15 -07:00
Jordan Fix
1053c96693 [GraphModule] Back out changes to module root version of __init__ (#53791)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53791

Reviewed By: houseroad

Differential Revision: D26970869

fbshipit-source-id: 80684516f57fd2d1aca794f17fe488b2fe2b2f64
2021-03-10 23:18:56 -08:00
Jordan Fix
3b0e4a6ed4 [GraphModule] Improve buffer registration during init (#53444)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53444

GraphModule construction has two options when constructing the base nn.Module: a dict of names to attrs to assign to the GraphModule, or another nn.Module to copy attrs from.

- For the dict case, add logic to explicitly register `nn.Tensors` that are not `nn.Parameter` as buffers on the GraphModule, else fall back to `__setattr__`.
- For the other `nn.Module` case, update so that it checks in the other module whether the attr to copy in is a buffer, and register it as such, else fall back to `__setattr__`.
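
A minimal sketch of the dict case described above (illustrative, not the merged code):
```
import torch

class MiniGraphModule(torch.nn.Module):
    def __init__(self, root: dict):
        super().__init__()
        for name, value in root.items():
            if isinstance(value, torch.Tensor) and not isinstance(value, torch.nn.Parameter):
                self.register_buffer(name, value)  # plain tensor -> buffer
            else:
                setattr(self, name, value)  # Parameter etc. -> __setattr__

gm = MiniGraphModule({"scale": torch.ones(3),
                      "weight": torch.nn.Parameter(torch.ones(3))})
print(list(dict(gm.named_buffers())))     # ['scale']
print(list(dict(gm.named_parameters())))  # ['weight']
```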

Test Plan: Added tests for fetching params and buffers from a GraphModule using both dict and module `__init__`s

Reviewed By: jamesr66a

Differential Revision: D26860055

fbshipit-source-id: 8d9999f91fef20aaa10969558006fc356247591f
2021-03-09 21:05:01 -08:00
Sam Estep
8c798e0622 Forbid trailing whitespace (#53406)
Summary:
Context: https://github.com/pytorch/pytorch/pull/53299#discussion_r587882857

These are the only hand-written parts of this diff:
- the addition to `.github/workflows/lint.yml`
- the file endings changed in these four files (to appease FB-internal land-blocking lints):
  - `GLOSSARY.md`
  - `aten/src/ATen/core/op_registration/README.md`
  - `scripts/README.md`
  - `torch/csrc/jit/codegen/fuser/README.md`

The rest was generated by running this command (on macOS):
```
git grep -I -l ' $' -- . ':(exclude)**/contrib/**' ':(exclude)third_party' | xargs gsed -i 's/ *$//'
```

I looked over the auto-generated changes and didn't see anything that looked problematic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53406

Test Plan:
This run (after adding the lint but before removing existing trailing spaces) failed:
- https://github.com/pytorch/pytorch/runs/2043032377

This run (on the tip of this PR) succeeded:
- https://github.com/pytorch/pytorch/runs/2043296348

Reviewed By: walterddr, seemethere

Differential Revision: D26856620

Pulled By: samestep

fbshipit-source-id: 3f0de7f7c2e4b0f1c089eac9b5085a58dd7e0d97
2021-03-05 17:22:55 -08:00
Ansley Ussery
85109ce427 Support submodule manipulation in GraphModule (#52358)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52358

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D26759260

Pulled By: ansley

fbshipit-source-id: 25d2b9124a7d957704f1700a45dca143aaed391d
2021-03-04 14:52:35 -08:00
Michael Suo
958d9a8364 [fx/package] make GraphModules packageable (#51976)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51976

FX serializes things by serializing Python code as a string and exec'ing
it on load. This accomplishes one goal (we don't have to pickle the
graph object directly) but breaks the pickle abstraction in ways that
are not composable with `torch.package`.

In particular:
1. `forward` is serialized by saving Python code. On load, it's
installed
by  `exec`ing that code. This `exec` call needs to have the right
importer installed, otherwise it will not import modules from the
`torch.package` but instead import from the Python environment.
2. Any types/functions used are emitted as `import` statements in the
generated Python code. These are effectively dynamic dependencies of the
`GraphModule` being saved, and need to be registered as such so that the
`PackageImporter` will package them.

To address these, this PR introduces a new protocol for the
importer/exporter: `__reduce_package__`.

A class can implement `__reduce_package__` to customize how it is placed
in the importer/exproter. It functions very similarly to `__reduce__`,
except:
- `__reduce_package__` takes one argument, which is the
`PackageExporter`
instance. Users can use this instance to save stuff to the package to
implement their serialization. `__reduce__` takes no args.
- Only the 2-element tuple version of the return value for `__reduce__`
is supported (this could be extended if necessary).
- When the reduction function is called on load, an additional argument
is added to the beginning of the args tuple. This is the
`PackageImporter`
instance doing the loading.

The `__reduce_package__` protocol is defined using `persistent_id` and
`persistent_load`, which ensures that we can still use the cpickle
implementation of the pickler by default.
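
A hedged sketch of the protocol's shape (illustrative class and resource names; the real hooks differ in detail):
```
class PackagedGraphModule:
    def __init__(self, code: str):
        self.code = code

    def __reduce_package__(self, exporter):
        # save our generated source into the package being written
        exporter.save_text("generated", "forward.py", self.code)
        # 2-element tuple (reconstructor, args); on load, the
        # PackageImporter instance is prepended to args automatically
        return (_load_packaged_graph_module, (self.code,))

def _load_packaged_graph_module(importer, code):
    # importer is the PackageImporter doing the loading; it can be used
    # to resolve names from inside the package
    return PackagedGraphModule(code)
```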

Pull Request resolved: #51971

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D26340591

Pulled By: suo

fbshipit-source-id: 5872a7d22e832056399a7372bae8a57807717882
2021-02-23 22:43:00 -08:00
Michael Suo
ecf3ca00d8 [fx] Separate globals assignment from code generation (#51974)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51974

Right now, when an FX `Graph` references an external object, we will emit
code like:

    import foo
    def forward(input: foo.bar.baz):
        ...

This is problematic in a world with `torch.package`, since then name
`foo.bar.baz` may reference a name from any number of packages.

This PR lays the groundwork for FX-package integration by separating the
resolution of external references from the generation of the function
code.

When generating a Graph's Python source, we keep track of all external
references and assign them unique names. At the end, we have a
dictionary mapping names -> actual objects. This becomes the `globals`
namespace we pass to `exec` when installing the forward function in a
`GraphModule`. This is nice because we can always be sure that `exec` is
seeing the same objects that were referenced from the `Graph`, no import
statements needed.

At serialization time, we use a `ModuleEnv` to resolve the globals dict
to a set of import statements that can be run to reproduce the `globals`
namespace. This is only used on serialization/deserialization, and those
functions are expected to check that the import statements produce
the correct results.

Concretely, the code above will now look like:

    from foo.bar import baz as foo_bar_baz
    def forward(input: foo_bar_baz):
        ...
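
A sketch of the exec-time half of this scheme (hypothetical helper name): the generated source refers to externals through unique names that are bound directly in the globals passed to `exec`, so no import statements are needed at definition time.
```
import math

globals_ns = {}

def register_global(name, obj):
    # bind the live object under a unique generated name
    qual = f"_global_{name}"
    globals_ns[qual] = obj
    return qual

sqrt_name = register_global("sqrt", math.sqrt)
src = f"def forward(x):\n    return {sqrt_name}(x)\n"
exec(src, globals_ns)  # exec sees exactly the objects the Graph referenced
print(globals_ns["forward"](16.0))  # 4.0
```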

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D26340593

Pulled By: suo

fbshipit-source-id: fe247f75205d0a03fd067bdd0f95491e8edf1436
2021-02-23 13:48:03 -08:00
Ansley Ussery
4cc10563e7 Customize traceback for calls to symbolically-traced code (#51648)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51648

The following code will throw during the call to `traced(5)`:
```python
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.W = torch.nn.Parameter(torch.randn(5))

    def forward(self, x):
        return torch.dot(self.W, x)

traced = fx.symbolic_trace(M())
traced(5)
```

Traceback before:
```
Traceback (most recent call last):
  File "test/tinytest.py", line 26, in <module>
    traced(5)
  File "/home/ansley/local/pytorch/torch/fx/graph_module.py", line 338, in wrapped_call
    return self._cls_call(self, *args, **kwargs)
  File "/home/ansley/local/pytorch/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "<eval_with_key_0>", line 4, in forward
TypeError: dot(): argument 'tensor' (position 2) must be Tensor, not int
```

Traceback after:
```
Traceback (most recent call last):
  File "/home/ansley/local/pytorch/torch/fx/graph_module.py", line 338, in wrapped_call
    return torch.nn.Module.__call__(self, *args, **kwargs)
  File "/home/ansley/local/pytorch/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "<eval_with_key_1>", line 4, in forward
    dot_1 = torch.dot(w, x);  w = x = None
TypeError: dot(): argument 'tensor' (position 2) must be Tensor, not int

Call using an FX-traced Module, line 4 of the traced Module’s generated forward function:
    w = self.W
    dot_1 = torch.dot(w, x);  w = x = None

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    relu_1 = dot_1.relu();  dot_1 = None

    return relu_1
```

(Note that the same `TypeError` is thrown despite modifying the traceback.)

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D26424005

Pulled By: ansley

fbshipit-source-id: 368f46ba81fb3111bd09654825bb2ac5595207d1
2021-02-12 18:31:23 -08:00
Ansley Ussery
4ac489091a Improve call provenance during GraphModule scripting (#50538)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50538

Test Plan: Imported from OSS

Reviewed By: pbelevich, SplitInfinity

Differential Revision: D25935403

Pulled By: ansley

fbshipit-source-id: 2baf5e0ba0fa3918e645fc713a9e80d10bbc84e5
2021-01-21 12:03:19 -08:00
James Reed
5205cc1c62 [FX] Fix NoneType annotation in generated code (#50777)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50777

Test Plan: Imported from OSS

Reviewed By: Chillee

Differential Revision: D25966026

Pulled By: jamesr66a

fbshipit-source-id: 8e36521eee03eade7e1b602e801229c085b03488
2021-01-19 23:16:58 -08:00
James Reed
ae9f39eb58 [FX][1/2] Make docstrings pretty when rendered (#48738)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48738

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D25280867

Pulled By: jamesr66a

fbshipit-source-id: d08641c19a6c69b4042389c800a48e699f0be628
2020-12-05 17:23:40 -08:00
Horace He
092e52a4da [fx]added prototype of to_folder (#47544)
Summary:
What this does: given an FX `GraphModule` `foo`, you can call `foo.to_folder('foo_folder', 'Foo')` to dump the current FX module into runnable Python code.

That is
```
foo = <fxModule>
foo.to_folder('bar', 'Foo')
from bar import Foo
foo2 = Foo()

forall x, foo2(x) == foo(x)
```

This has several use cases, largely lifted from jamesr66a's doc here: https://fb.quip.com/U6KHAFaP2cWa (FB-internal).

1. As we apply more heavy-weight function transformations with FX, figuring out what's going on can be quite a difficult experience. In particular, things that can typically be used for debugging (like `print` or `import pdb; pdb.set_trace()`) no longer work. This is particularly necessary if you're using an FX transform like `grad` or `vmap`. With this, you simply open up the dumped file and add `print`/`pdb` statements wherever you'd like.

2. This also provides an immense amount of user control. Some potential use-cases:
-  Let's say an existing FX transform has some bug, or generates suboptimal code. Instead of needing to modify that FX transform, writing another FX pass that fixes the suboptimal code, or simply giving up on FX, they can workaround it by simply modifying the resulting code themselves.
- This allows users to check in their FX modules into source control.
- You could even imagine using this as part of some code-gen type workflow, where you write a function, `vmap` it to get the function you actually want, and then simply copy the output of the `vmap` function without needing FX at all in the final code.

An example:
```python
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.W = torch.nn.Parameter(torch.randn(2))
        self.linear = nn.Linear(2, 2)
        self.attr = torch.randn(2)
        self.attr2 = torch.randn(2)

    def forward(self, x):
        return self.linear(self.W + (self.attr + self.attr2) + x)

mod = fx.symbolic_trace(Test())
mod.to_folder('foo', 'Foo')
```
results in
```python
import torch
class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        state_dict = torch.load('foo/state_dict.pt')
        self.linear = torch.load('foo/linear.pt') # Linear(in_features=2, out_features=2, bias=True)
        self.__tensor_constant0 = state_dict['__tensor_constant0']
        self.W = torch.nn.Parameter(state_dict['W'])

    def forward(self, x):
        w = self.W
        tensor_constant0 = self.__tensor_constant0
        add_1 = w + tensor_constant0
        add_2 = add_1 + x
        linear_1 = self.linear(add_2)
        return linear_1
```
Some current issues:
1. How do you actually ... save things like modules or parameters? I don't think FX is in the business of tracking initializations and such. Thus, the only way I see to do it is to dump the parameters/modules as blobs, and then load them in the generated initialization. This is a somewhat subpar user experience, and perhaps prevents it from being in some use cases (ie: you would need to check in the blobs into source control to save the model).

2. Currently, the only "atomic" modules we have are those in `torch.nn`. However, if we want to allow flexibility in this, and for example, allow "atomic" modules that are user-defined, then it's not clear how to allow those to be dumped in a way that we can then load elsewhere.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47544

Reviewed By: jamesr66a

Differential Revision: D25232917

Pulled By: Chillee

fbshipit-source-id: fd2b61a5f40e614fc94256a2957ed1d57fcf5492
2020-12-04 18:33:27 -08:00
Mehdi Mirzazadeh
c5834b6a23 Look in named-buffers of module for tensors (#47641)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47641

ghstack-source-id: 116450114

Test Plan: Presubmit tests

Reviewed By: jamesr66a

Differential Revision: D24848318

fbshipit-source-id: f6ede3def9d6f1357c4fd3406f97721dea06b9f1
2020-11-11 19:08:16 -08:00
James Reed
d1351c66a8 [FX] Add a bunch of docstrings (#47719)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47719

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D24875400

Pulled By: jamesr66a

fbshipit-source-id: a1dd43d2eee914a441eff43c4f2efe61a399e8a5
2020-11-11 10:59:57 -08:00
Horace He
373246733d [FX] get the correct error message (#47108)
Summary:
Currently, code like
```
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.W = torch.nn.Parameter(torch.randn(5))

    def forward(self, x):
        return torch.dot(self.W, x)

mod = Test()
print(fx.symbolic_trace(Test())(5))
```
gives an error like the below, which does not show the actual code that throws the error.
```
Traceback (most recent call last):
  File "t.py", line 20, in <module>
    print(fx.symbolic_trace(Test())(5))
  File "/home/chilli/fb/pytorch/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 191, in debug_forward
    return src_forward(self, *args, **kwargs)
  File "<eval_with_key_0>", line 5, in forward
TypeError: dot(): argument 'tensor' (position 2) must be Tensor, not int
```

This is particularly annoying when your function has already been transformed several times.

So, the really annoying thing is that the error clearly has the requisite information in `exception.__traceback__` - it just isn't printing it.

I think the right way of doing this is simply replacing `sys.excepthook`. This appears to be the standard way to modify exception messages.
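
A hedged sketch of that idea (not what ultimately shipped here): wrap the default hook so the failing source line is surfaced alongside the normal traceback.
```
import sys
import traceback

def fx_excepthook(exc_type, exc, tb):
    frames = traceback.extract_tb(tb)
    if frames:
        last = frames[-1]
        print(f"error at {last.filename}:{last.lineno}: {last.line}",
              file=sys.stderr)
    sys.__excepthook__(exc_type, exc, tb)  # then the standard report

sys.excepthook = fx_excepthook
```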

**Scratch the below**

The 2 methods in the PR right now are:
1. Just prepend the final part of the traceback to the beginning of your error message. Looks like
```
Traceback (most recent call last):
  File "t.py", line 20, in <module>
    print(fx.symbolic_trace(Test())(5))
  File "/home/chilli/fb/pytorch/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 197, in debug_forward
    raise e
  File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 192, in debug_forward
    return src_forward(self, *args, **kwargs)
  File "<eval_with_key_0>", line 5, in forward
TypeError:   File "<eval_with_key_0>", line 5, in forward
    dot_1 = torch.dot(w, x)
dot(): argument 'tensor' (position 2) must be Tensor, not int
```

2. Use the `from exception` feature in Python. Looks like
```
Traceback (most recent call last):
  File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 192, in debug_forward
    return src_forward(self, *args, **kwargs)
  File "<eval_with_key_0>", line 5, in forward
TypeError:   File "<eval_with_key_0>", line 5, in forward
    dot_1 = torch.dot(w, x)
dot(): argument 'tensor' (position 2) must be Tensor, not int

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "t.py", line 20, in <module>
    print(fx.symbolic_trace(Test())(5))
  File "/home/chilli/fb/pytorch/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 197, in debug_forward
    raise Exception(last_tb) from e
Exception:   File "<eval_with_key_0>", line 5, in forward
    dot_1 = torch.dot(w, x)
```

I think the first one looks better, but it's pretty hacky since we're shoving the traceback in the message.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47108

Reviewed By: jamesr66a

Differential Revision: D24751019

Pulled By: Chillee

fbshipit-source-id: 83e6ed0165f98632a77c73de75504fd6263fff40
2020-11-05 10:59:01 -08:00
James Reed
d0df29ac22 [FX] Put inf and nan in globals instead of with an import string (#47035)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47035

Chillee thought the `from math import inf, nan` string at the top of `.code` was annoying, so here's an alternative: put those values in `globals` before we `exec`.
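
A minimal sketch of the trick (hypothetical snippet, not the torch.fx internals):
```
import math

globals_ns = {"inf": math.inf, "nan": math.nan}
exec("def forward():\n    return (inf, nan != nan)", globals_ns)
print(globals_ns["forward"]())  # (inf, True) -- no import line in the code
```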

Test Plan: Imported from OSS

Reviewed By: dzhulgakov

Differential Revision: D24611278

Pulled By: jamesr66a

fbshipit-source-id: c25ef89e649bdd3e79fe91aea945a30fa7106961
2020-10-29 00:35:41 -07:00
James Reed
b04ae953b4 [FX][WIP] Mutable Graph APIs (#45227)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45227

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D23880730

Pulled By: jamesr66a

fbshipit-source-id: eb4e8c14d7f6b1deb1ddd6cf38a360413a1705ed
2020-10-05 17:07:08 -07:00