Commit Graph

10 Commits

Author SHA1 Message Date
XiaobingSuper
1682722152 keep output type after calling SubgraphRewriter (#65453)
Summary:
For the JIT **SubgraphRewriter**, the output type is not preserved when the old graph is overwritten. For example, in profiling mode the old graph carries the old operator's shape info, but after the old operator is replaced with a new one by **SubgraphRewriter**, the tensor shape info is eliminated.

The motivation is that I want to replace the PyTorch convolution with a custom convolution. I first register **aten::_convolution** as a profiled node so that its input and output shapes are recorded, and then use the graph rewriter to replace it with **aten::conv2d**, which eliminates the tensors' shape info. I want to use the input sizes to do some pre-processing before replacing **aten::conv2d** with the custom convolution.

Before rewrite:
```
graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %x : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::_convolution(%x.1, %weight, %4, %3, %2, %3, %6, %2, %7, %6, %6, %5, %5), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%x, %z, %7) # jit_test.py:24:0
  return (%16)
```
After rewriting with **aten::conv2d**:
```
graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %18 : Tensor = aten::conv2d(%x.1, %weight, %4, %3, %2, %3, %7)
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%18, %z, %7) # jit_test.py:24:0
  return (%16)
```

Expected result after replacing **aten::_convolution** with **aten::conv2d**:

```
graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %18 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::conv2d(%x.1, %weight, %4, %3, %2, %3, %7)
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%18, %z, %7) # jit_test.py:24:0
  return (%16)
```
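
For reference, here is a minimal sketch (not part of this PR) of driving this kind of rewrite through the C++ `SubgraphRewriter` API; the IR strings are illustrative and assume the 13-argument `aten::_convolution` schema shown in the dumps above:

```
// Minimal sketch, assuming the aten::_convolution schema from the dumps above.
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

void rewriteConvolutionToConv2d(std::shared_ptr<torch::jit::Graph>& graph) {
  const std::string pattern = R"IR(
    graph(%input, %weight, %bias, %stride, %padding, %dilation, %transposed,
          %output_padding, %groups, %benchmark, %deterministic,
          %cudnn_enabled, %allow_tf32):
      %res = aten::_convolution(%input, %weight, %bias, %stride, %padding,
                                %dilation, %transposed, %output_padding,
                                %groups, %benchmark, %deterministic,
                                %cudnn_enabled, %allow_tf32)
      return (%res))IR";
  const std::string replacement = R"IR(
    graph(%input, %weight, %bias, %stride, %padding, %dilation, %transposed,
          %output_padding, %groups, %benchmark, %deterministic,
          %cudnn_enabled, %allow_tf32):
      %res = aten::conv2d(%input, %weight, %bias, %stride, %padding,
                          %dilation, %groups)
      return (%res))IR";
  // Replace every occurrence of the pattern; before this fix, the
  // replacement's output lost the profiled tensor type of the original.
  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  rewriter.runOnGraph(graph);
}
```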

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65453

Reviewed By: zdevito

Differential Revision: D31162489

Pulled By: ZolotukhinM

fbshipit-source-id: 0d1c1d607cb612df47c64f173d9f4c9e8b1d6c49
2021-09-24 11:07:40 -07:00
Nikita Shulga
a9b0a921d5 Disable avoid-non-const-global-variables lint check (#62008)
Summary:
The GoogleTest `TEST` macro is non-compliant with this check, as is `DEFINE_DISPATCH`.

All changes except the ones to `.clang-tidy` were generated with the following script:
```
for i in $(find . -type f -iname "*.c*" -or -iname "*.h" \
           | xargs grep cppcoreguidelines-avoid-non-const-global-variables \
           | cut -f1 -d: | sort | uniq); do
  sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" "$i"
done
```
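
For context (illustrative only, not the actual diff), disabling a check in `.clang-tidy` amounts to prefixing its name with `-` in the `Checks` list:

```
# .clang-tidy -- illustrative excerpt, not the actual file contents
Checks: '...,
  -cppcoreguidelines-avoid-non-const-global-variables,
  ...'
```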

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008

Reviewed By: driazati, r-barnes

Differential Revision: D29838584

Pulled By: malfet

fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
2021-07-22 18:04:40 -07:00
Nikita Shulga
4cb534f92e Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os

def get_compiled_files_list():
    import json
    with open("build/compile_commands.json") as f:
        data = json.load(f)
    files = [os.path.relpath(node['file']) for node in data]
    # Map generated build/*.DEFAULT.cpp entries back to their source files.
    for idx, fname in enumerate(files):
        if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
            files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
    return files

def run_clang_tidy(fname):
    # Run clang-tidy on a single file and commit any NOLINT stubs it adds.
    check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname, "-s"])
    changes = check_output(["git", "ls-files", "-m"])
    if len(changes) == 0:
        return
    check_call(["git", "commit", "--all", "-m", f"NOLINT stubs for {fname}"])

def main():
    git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
    compiled_files = get_compiled_files_list()
    for idx, fname in enumerate(git_files):
        if fname not in compiled_files:
            continue
        if fname.startswith("caffe2/contrib/aten/"):
            continue
        print(f"[{idx}/{len(git_files)}] Processing {fname}")
        run_clang_tidy(fname)

if __name__ == "__main__":
    main()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892

Reviewed By: H-Huang

Differential Revision: D27991944

Pulled By: malfet

fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 14:10:25 -07:00
Mikhail Zolotukhin
38a59a67f3 [JIT] Support multiple outputs in subgraph matcher. (#48992)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48992
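
In effect, the matcher (and thus the rewriter) now accepts patterns whose graphs return more than one value. A hypothetical sketch of such a pattern (`my::tanh_sigmoid` is an invented fused op, used purely for illustration):

```
// Hypothetical multi-output pattern; my::tanh_sigmoid is an invented op.
const std::string pattern = R"IR(
  graph(%a):
    %x = aten::tanh(%a)
    %y = aten::sigmoid(%a)
    return (%x, %y))IR";
const std::string replacement = R"IR(
  graph(%a):
    %x, %y = my::tanh_sigmoid(%a)
    return (%x, %y))IR";
// Registered as usual:
// rewriter.RegisterRewritePattern(pattern, replacement);
```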

Differential Revision: D25388100

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Pulled By: ZolotukhinM

fbshipit-source-id: d95713af2220cf4f99ac92f59f8e5b902f2f3822
2020-12-15 13:09:24 -08:00
Michael Suo
22401b850b port all JIT tests to gtest (#45264)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45264

Context for why we are porting to gtest is in https://github.com/pytorch/pytorch/pull/45018.

This PR completes the process of porting and removes unused files/macros.

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D23901392

Pulled By: suo

fbshipit-source-id: 89526890e1a49462f3f77718f4ee273c5bc578ba
2020-09-25 11:37:43 -07:00
Jerry Zhang
004aa089a6 [jit][subgraph_rewriter] Support list of filters (#39867)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39867

Support a list of filters in the subgraph rewriter; the rewrite executes only
when the match passes all filter checks. This is useful for letting different
matches share the same filter.
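
For illustration, a sketch of what composing filters looks like on the caller side; the `MatchFilter` signature follows `subgraph_rewrite.h`, while the filter names and bodies here are placeholders:

```
#include <memory>
#include <unordered_map>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

using torch::jit::Match;
using torch::jit::Value;

// Placeholder filters; each can veto the rewrite by returning false.
bool is_on_cpu(const Match& m,
               const std::unordered_map<std::string, Value*>& vmap) {
  return true;  // e.g. inspect matched values to check device placement
}

bool has_static_shapes(const Match& m,
                       const std::unordered_map<std::string, Value*>& vmap) {
  return true;  // e.g. check that matched tensors carry complete shape info
}

void run(torch::jit::SubgraphRewriter& rewriter,
         std::shared_ptr<torch::jit::Graph>& graph) {
  // The rewrite fires only when *all* filters pass for a given match.
  rewriter.runOnGraph(graph, {is_on_cpu, has_static_shapes});
}
```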

Test Plan: Imported from OSS

Differential Revision: D22009855

fbshipit-source-id: 67aab8d6326b2011a9061397699dc62ee9ad4e2d
2020-06-12 08:24:49 -07:00
Meghan Lele
6384c2d81b [JIT] clang-format JIT code (#35115)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35115

This commit runs the newly added tools/clang_format.py on the JIT
codebase and includes all of the formatting changes thus produced.

Testing:
Ran the script, CI.

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D20568523

Pulled By: SplitInfinity

fbshipit-source-id: e09bdb982ccf090eecfb7c7b461b8d0681eef82b
2020-03-26 11:24:51 -07:00
Michael Suo
c235be42dd [jit] kill script namespace (#34515)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34515

Once upon a time we thought this was necessary. In reality it is not, so
removing it.

For backcompat, our public interface (defined in `api/`) still has
typedefs to the old `script::` names.

There was only one collision: `Pass` as a `Stmt` and `Pass` as a graph
transform. I renamed one of them.
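
For illustration, such a backcompat shim is typically just a handful of aliases; a hypothetical sketch (the actual typedefs live in `api/` and may differ):

```
// Illustrative only: keep old script:: spellings compiling by aliasing
// the types that now live directly in torch::jit.
namespace torch {
namespace jit {
namespace script {
using Module = ::torch::jit::Module;
using CompilationUnit = ::torch::jit::CompilationUnit;
} // namespace script
} // namespace jit
} // namespace torch
```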

Test Plan: Imported from OSS

Differential Revision: D20353503

Pulled By: suo

fbshipit-source-id: 48bb911ce75120a8c9e0c6fb65262ef775dfba93
2020-03-11 23:32:48 -07:00
Michael Suo
dbe850af5b [jit] do the code reorg (#33851)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33851

Rationale and context described in #33828.

Script to reproduce the move:
https://gist.github.com/suo/16cbefaaeb67ca5a7c6caffd49b7f6e9
ghstack-source-id: 99079645

Test Plan: Make sure CI passes

Reviewed By: jamesr66a

Differential Revision: D20133869

fbshipit-source-id: 390e9241a9c85366d9005c492ac31f10aa96488e
2020-02-27 13:02:51 -08:00
Jerry Zhang
f29e0d70cb Add filter function to subgraph rewriter runGraph (#26223)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26223

Add a filter function to runGraph; if the function returns false for a given `Match`,
we skip the rewrite.
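
The later PR mentioned in the test plan filters on constant nodes; a sketch of what such a filter can look like (the pattern value name `"alpha"` is invented, and the header paths assume the current tree layout):

```
#include <unordered_map>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>

// Sketch: veto the rewrite unless the value bound to the pattern's %alpha
// is produced by a prim::Constant node. "alpha" is a made-up pattern name.
bool alpha_is_constant(
    const torch::jit::Match& match,
    const std::unordered_map<std::string, torch::jit::Value*>& vmap) {
  torch::jit::Value* v = vmap.at("alpha");
  return v->node()->kind() == c10::prim::Constant;
}

// Used as: rewriter.runOnGraph(graph, alpha_is_constant);
```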

Test Plan:
will test in later PR that adds extra filtering on Constant nodes

Imported from OSS

Differential Revision: D17462529

fbshipit-source-id: 52abe52cb3e729a3871f7a60eddd5275060af36a
2019-09-18 16:34:50 -07:00