Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67139

This diff enables setting a breakpoint in the graph module's generated python code. See the test plan for usage.

To support this functionality, and other similar customizations of the generated code, a code transformer hook is added to `fx.Graph`. This allows flexible customization of `fx.Graph`'s code-gen behavior, in composable and functional ways. See the test plan for its usage.

Test Plan:

### Use of `fx.experimental.debug.set_trace`

```
In [2]: from torch.fx.experimental.debug import set_trace

In [3]: set_trace(ttop)
Out[3]:
top(
  (a): Sub()
)

In [4]: ttop(1)
> /data/users/kefeilu/fbsource33/fbcode/buck-out/dev/gen/caffe2/torch/fb/fx2trt/<eval_with_key>.10(6)forward()
(Pdb) l
  1
  2
  3
  4     def forward(self, x):
  5         import pdb; pdb.set_trace()
  6  ->     a = self.a(x); x = None
  7         getitem = a[0]
  8         getitem_1 = a[0]; a = None
  9         add = getitem + getitem_1; getitem = getitem_1 = None
 10         return add
 11
(Pdb)
```

### Use of `on_generate_code`

```
In [1]: def insert_pdb(body):
   ...:     return ['import pdb; pdb.set_trace()\n', *body]
   ...:

In [8]: type(ttop)
Out[8]: torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl

In [10]: with ttop.graph.on_generate_code(lambda _: insert_pdb):
    ...:     ttop.recompile()
    ...:     print(f"== _on_generate_code should not be None: { ttop.graph._on_generate_code }")
    ...:     print(ttop.code)
    ...:
== _on_generate_code should not be None: <function insert_pdb at 0x7fc9895ddd30>

def forward(self, x):
    import pdb; pdb.set_trace()
    a = self.a(x); x = None
    getitem = a[0]
    getitem_1 = a[0]; a = None
    add = getitem + getitem_1; getitem = getitem_1 = None
    return add

In [11]: ttop.graph._on_generate_code  # restored to None

In [12]: ttop(1)  # this should drop into pdb
> /data/users/kefeilu/fbsource33/fbcode/buck-out/dev/gen/caffe2/torch/fb/fx2trt/<eval_with_key>.6(6)forward()
(Pdb) l
  1
  2
  3
  4     def forward(self, x):
  5         import pdb; pdb.set_trace()
  6  ->     a = self.a(x); x = None
  7         getitem = a[0]
  8         getitem_1 = a[0]; a = None
  9         add = getitem + getitem_1; getitem = getitem_1 = None
 10         return add
 11
```

Reviewed By: jamesr66a

Differential Revision: D30736160

fbshipit-source-id: 9646867aae0461b5131dfd4ba9ee77a8c2ea9c93
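Because `make_transformer` receives the currently registered transformer (or `None`), transformers can be chained rather than clobbered. A minimal sketch of that composable pattern, assuming a PyTorch build that includes this diff; the toy module `M` and the `add_banner` transformer are illustrative, not part of the change:

```
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = fx.symbolic_trace(M())


def add_banner(body):
    # Prepend a comment line to the generated forward() body.
    return ["# transformed by on_generate_code\n", *body]


# `make_transformer` gets the current transformer (or None) and returns the
# new one, so any previously registered transformer keeps running first.
with gm.graph.on_generate_code(
    make_transformer=lambda cur: (
        lambda body: add_banner(cur(body) if cur else body)
    )
):
    gm.recompile()
    print(gm.code)  # generated body now starts with the banner comment

# On context exit the previous transformer is restored, as In [11] above shows.
```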
32 lines | 805 B | Python
import torch.fx as fx


def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
    `gm` gets run.

    Args:
        gm: graph module to insert a breakpoint into. It is then recompiled
            for the breakpoint to take effect.

    Returns:
        the `gm` with the breakpoint inserted.
    """

    def insert_pdb(body):
        return ["import pdb; pdb.set_trace()\n", *body]

    with gm.graph.on_generate_code(
        make_transformer=lambda cur_transform: (
            # new code transformer to register
            lambda body: (
                insert_pdb(
                    cur_transform(body) if cur_transform
                    else body
                )
            )
        )
    ):
        gm.recompile()

    return gm
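A quick usage sketch for the helper above, mirroring the test plan; the toy module `M` is illustrative:

```
import torch
import torch.fx as fx

from torch.fx.experimental.debug import set_trace


class M(torch.nn.Module):
    def forward(self, x):
        return x * 2


gm = fx.symbolic_trace(M())
set_trace(gm)  # recompiles gm with `import pdb; pdb.set_trace()` in forward()

# Invoking the module now drops into pdb inside the generated forward():
# gm(torch.randn(3))
```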