pytorch/torch/fx/experimental/debug.py
Kefei Lu 76e9dbb0f4 [torch.fx] add code-gen customizability and support for setting breakpoint in code-gen'd forward() call (#67139)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67139

This diff enables setting a breakpoint in the graph module's generated Python code. See the test plan for usage.

To support this, and other customizations of the generated code, a code-transformer hook is added to `fx.Graph`. It lets `fx.Graph`'s code-gen behavior be customized in composable, functional ways. See the test plan for its usage.
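A toy model may help show how such a hook composes. The sketch below is an assumption-laden illustration, not fx's implementation. `ToyGraph`, its `python_code()` method and the `TransformCodeFunc` alias are hypothetical. Only two things come from this change: the `on_generate_code(make_transformer)` calling convention and the `_on_generate_code` attribute name. The new transformer is built from the current one, applied to the list of body lines at code-gen time, and the previous transformer is restored when the `with` block exits.

```python
from contextlib import contextmanager
from typing import Callable, Iterator, List, Optional

# Hypothetical alias: a body transformer maps generated source lines to new lines.
TransformCodeFunc = Callable[[List[str]], List[str]]


class ToyGraph:
    """Hypothetical stand-in for fx.Graph, modelling only the code-gen hook."""

    def __init__(self, body: List[str]) -> None:
        self._body = body
        self._on_generate_code: Optional[TransformCodeFunc] = None

    @contextmanager
    def on_generate_code(
        self,
        make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc],
    ) -> Iterator[None]:
        # Build the new transformer from the current one, so callers can compose.
        old = self._on_generate_code
        self._on_generate_code = make_transformer(old)
        try:
            yield
        finally:
            # Always restore the previous transformer, even if the block raises.
            self._on_generate_code = old

    def python_code(self) -> str:
        body = list(self._body)
        if self._on_generate_code is not None:
            body = self._on_generate_code(body)
        # Indent each body line under the forward() signature.
        return "def forward(self, x):\n" + "".join("    " + line for line in body)


# Usage: register a transformer that prepends a comment, generate code, then exit.
g = ToyGraph(["return x\n"])
with g.on_generate_code(lambda cur: lambda body: ["# hook\n", *body]):
    inside = g.python_code()
outside = g.python_code()
```

Inside the `with` block the generated `forward` starts with `# hook`; after it exits, `_on_generate_code` is back to `None`. This mirrors the "restored to None" check in the test plan.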

Test Plan:
### Use of `fx.experimental.debug.set_trace`

```
In [2]: from torch.fx.experimental.debug import set_trace

In [3]: set_trace(ttop)
Out[3]:
top(
  (a): Sub()
)

In [4]: ttop(1)
> /data/users/kefeilu/fbsource33/fbcode/buck-out/dev/gen/caffe2/torch/fb/fx2trt/<eval_with_key>.10(6)forward()
(Pdb) l
  1
  2
  3
  4     def forward(self, x):
  5         import pdb; pdb.set_trace()
  6  ->     a = self.a(x);  x = None
  7         getitem = a[0]
  8         getitem_1 = a[0];  a = None
  9         add = getitem + getitem_1;  getitem = getitem_1 = None
 10         return add
 11
(Pdb)
```

### Use of `on_generate_code`

```
In [1]: def insert_pdb(body):
   ...:     return ['import pdb; pdb.set_trace()\n', *body]
   ...:

In [8]: type(ttop)
Out[8]: torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl

In [10]: with ttop.graph.on_generate_code(lambda _: insert_pdb):
    ...:     ttop.recompile()
    ...:     print(f"== _on_generate_code should not be None: { ttop.graph._on_generate_code }")
    ...:     print(ttop.code)
    ...:

== _on_generate_code should not be None: <function insert_pdb at 0x7fc9895ddd30>

def forward(self, x):
    import pdb; pdb.set_trace()
    a = self.a(x);  x = None
    getitem = a[0]
    getitem_1 = a[0];  a = None
    add = getitem + getitem_1;  getitem = getitem_1 = None
    return add

In [11]: ttop.graph._on_generate_code  # restored to None

In [12]: ttop(1) # this should drop into pdb
> /data/users/kefeilu/fbsource33/fbcode/buck-out/dev/gen/caffe2/torch/fb/fx2trt/<eval_with_key>.6(6)forward()
(Pdb) l
  1
  2
  3
  4     def forward(self, x):
  5         import pdb; pdb.set_trace()
  6  ->     a = self.a(x);  x = None
  7         getitem = a[0]
  8         getitem_1 = a[0];  a = None
  9         add = getitem + getitem_1;  getitem = getitem_1 = None
 10         return add
 11
```

Reviewed By: jamesr66a

Differential Revision: D30736160

fbshipit-source-id: 9646867aae0461b5131dfd4ba9ee77a8c2ea9c93
2021-11-16 13:28:11 -08:00

import torch.fx as fx


def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Sets a breakpoint in `gm`'s generated Python code. It drops into pdb when
    `gm` is run.

    Args:
        gm: graph module to insert the breakpoint into. It is recompiled so
            that the breakpoint takes effect.

    Returns:
        `gm`, with the breakpoint inserted.
    """
    def insert_pdb(body):
        return ["import pdb; pdb.set_trace()\n", *body]

    with gm.graph.on_generate_code(
        make_transformer=lambda cur_transform: (
            # new code transformer to register
            lambda body: (
                insert_pdb(
                    cur_transform(body) if cur_transform
                    else body
                )
            )
        )
    ):
        gm.recompile()

    return gm
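The nested lambdas passed as `make_transformer` can be hard to read. The sketch below rewrites them as named functions. It assumes only the list-of-lines contract that the test plan's `insert_pdb` relies on. `make_transformer`, `add_comment`, `PDB_LINE` and `TransformCodeFunc` are illustrative names, not part of `debug.py`. It also contrasts the composing form with the test plan's `lambda _: insert_pdb`, which ignores its argument and so drops any transformer that was already registered.

```python
from typing import Callable, List, Optional

# Hypothetical alias for a body transformer: source lines in, source lines out.
TransformCodeFunc = Callable[[List[str]], List[str]]

PDB_LINE = "import pdb; pdb.set_trace()\n"


def insert_pdb(body: List[str]) -> List[str]:
    # Prepend the breakpoint so it is the first statement in forward().
    return [PDB_LINE, *body]


def make_transformer(cur_transform: Optional[TransformCodeFunc]) -> TransformCodeFunc:
    """Named-function version of set_trace's nested lambdas."""

    def transform(body: List[str]) -> List[str]:
        # Run any previously registered transformer first, then add the
        # breakpoint on top of its output, so earlier customizations survive.
        # (set_trace tests truthiness; `is not None` behaves the same for
        # ordinary functions.)
        if cur_transform is not None:
            body = cur_transform(body)
        return insert_pdb(body)

    return transform


def add_comment(body: List[str]) -> List[str]:
    # Hypothetical pre-existing transformer, used only to show composition.
    return ["# instrumented\n", *body]


body = ["return x\n"]
composed = make_transformer(add_comment)(body)
replaced = (lambda _: insert_pdb)(add_comment)(body)  # the test plan's form
```

Here `composed` keeps both the comment and the breakpoint, while `replaced` loses the comment. With `torch.fx` available, the named `make_transformer` could be passed as `gm.graph.on_generate_code(make_transformer)` and followed by `gm.recompile()`, as `set_trace` does.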