Commit Graph

5394 Commits

Author SHA1 Message Date
Zhengxu Chen
12daa4f663 [jit][edge] Enable CALL instruction in lite interpreter. (#65964)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65964

ghstack-source-id: 141425519

Test Plan: buck run xplat/caffe2:test_lite_interpreter

Reviewed By: cccclai

Differential Revision: D31326149

fbshipit-source-id: 8a599d92f3fa4e6c125100adb36d89592e71e547
2021-10-25 14:44:33 -07:00
Jacob Szwejbka
a6d0339492 [Pytorch Edge] Extend runtime compatibility to custom classes (#66972)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66972

Add api to view how many custom classes we have and what their names are

Test Plan: unit test

Reviewed By: cccclai

Differential Revision: D31811337

fbshipit-source-id: 9f8ca1fc578a0a5360c9cd8f95475acc33f250e4
2021-10-25 13:42:26 -07:00
Zhengxu Chen
4dce051cb0 [jit][edge] Add control stack frame to lite interpreter (#65963)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65963

ghstack-source-id: 141425517

Test Plan: In next diff.

Reviewed By: qihqi, cccclai

Differential Revision: D31326150

fbshipit-source-id: dbbf65f2bf14846c45d0add71edc7d4dbfc6b92c
2021-10-25 12:15:16 -07:00
Mike Iovine
5d9ff8f30e [Static Runtime] Add static_runtime::fused_sigrid_transforms (#66659)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66659

Original message: We added and registered a new operator, static_runtime::fused_sigrid_transforms, and modified the original sigrid_transforms to handle the non-fused case only

Note: this diff was commandeered from a bootcamper. Some final touches were needed.

Test Plan: `buck test caffe2/benchmarks/static_runtime/...`

Reviewed By: swolchok

Differential Revision: D31550307

fbshipit-source-id: 287380be0cca20ee6e145bcc7217547bd58cf6d0
2021-10-25 10:44:46 -07:00
Pavithran Ramachandran
8d164a36fb Use at::native::is_nonzero in promoted ops to improve portability (#67097)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67097

All delegated models have `is_nonzero` ops by default; making the op native and consumable without dispatch eases the portability of such models
ghstack-source-id: 141375082

Test Plan:
`buck test caffe2/test/cpp/jit:jit -- BackendTest.TestComposite`

```
[~/fbsource/fbcode] cd ~/fbsource/fbcode/ && buck test caffe2/test:jit -- test_trace_arange
Parsing buck files: finished in 0.5 sec
Building: finished in 9.4 sec (100%) 16035/16035 jobs, 0/16035 updated
  Total time: 10.0 sec
More details at https://www.internalfb.com/intern/buck/build/1e55eea5-2adb-41d1-96ae-cbf4b446d6c6
BUILD SUCCEEDED
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: 46eedba2-ae17-4e88-b205-93bd1332665d
Trace available for this run at /tmp/tpx-20211015-113905.235421/trace.log
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/1970324912349177
    ✓ ListingSuccess: caffe2/test:jit - main (12.372)
    ✓ Pass: caffe2/test:jit - test_trace_arange (jit.test_tracer.TestTracer) (13.748)
    ✓ Pass: caffe2/test:jit - test_trace_arange_with_grad (jit.test_tracer.TestTracer) (13.892)
Summary
  Pass: 2
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/1970324912349177
```

Reviewed By: iseeyuan

Differential Revision: D31656842

fbshipit-source-id: c0e6c798478a2783c0e17e6e9100ba5ce044da78
2021-10-25 10:18:31 -07:00
Mike Iovine
a0495b3cdb [SR] Remove unused operator() overload (#67001)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67001

The overload of `operator()` taking `std::vector<at::Tensor>` was only used for testing. In a diff following this one, I will add a new overload that takes `std::vector<c10::IValue> args` and no `kwargs` so we can avoid default-constructing `kwargs` everywhere.

This new overload will probably take a forwarding reference, so to avoid problems with overloading on forwarding reference and simplify the interface, it's best to remove this unused one.

Test Plan:
`buck test caffe2/benchmarks/static_runtime/...`

`buck test caffe2/test:static_runtime`

Reviewed By: hlu1

Differential Revision: D31821990

fbshipit-source-id: 6d2e4a75ca4abe6e262651532eb96c3b274c6f4a
2021-10-25 08:18:58 -07:00
Mike Iovine
364645cd9d [SR] Factor operator() implementation into separate function (#67125)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67125

Using explicit template instantiations in D31659973 (f2582a59d0) was a bad idea. The problem is that the lvalue instantiation was for a `const` vector of `IValue`, meaning that if you tried to pass SR a non-const vector of arguments, the linker would fail to find the symbol.

The reason we didn't catch this in D31659973 (f2582a59d0) was that predictor always passes a `const` reference anyway. But we should fix this to prevent unexpected problems in the future.

Test Plan: `buck test caffe2/benchmarks/static_runtime/...`

Reviewed By: hlu1

Differential Revision: D31873406

fbshipit-source-id: 5ab5a03334bed925cec11facadcedf9bec9b90ad
2021-10-25 08:17:40 -07:00
Mike Iovine
dd81fa9027 [JIT] Freeze allows preservation of submodule attributes (#66102)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66102

This change allows the `preserved_attributes` parameter of `torch.jit.freeze` to accept attributes of submodules. Previously, only root-level attributes could be preserved. Example:

```
class SubModule(nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.a = 1
        self.b = 2

    def forward(self):
        return self.a + self.b

class Module(nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        self.sub = SubModule()

    def forward(self):
        return self.sub()

mod = torch.jit.script(Module())
mod.eval()
frozen_mod = torch.jit.freeze(mod, preserved_attrs = ['sub.a'])

frozen_mod.sub   # OK
frozen_mod.sub.a # OK
frozen_mod.sub.b # Error, not preserved
frozen_mod()     # = 3
frozen_mod.sub.a = 0
frozen_mod()     # = 2
```

Test Plan: `buck test caffe2/test:jit -- TestFreezing`

Reviewed By: eellison

Differential Revision: D31383868

fbshipit-source-id: 34a05ca9528d4e5f04f71ac2a339fd584a8fa305
2021-10-25 07:56:20 -07:00
Nikolay Korovaiko
a7ebf76a15 jit trace (#59949)
Summary:
Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59949

Reviewed By: ZolotukhinM

Differential Revision: D31366787

Pulled By: Krovatkin

fbshipit-source-id: 798cbcd97e8ecfba984f98cd70214954be9309af
2021-10-24 18:04:22 -07:00
Natalia Gimelshein
b6fa998892 Revert D31514095: Use kernel_func_name from aotCompiler
Test Plan: revert-hammer

Differential Revision:
D31514095 (7b55dc8340)

Original commit changeset: b70c8e2c7336

fbshipit-source-id: ad4d828f33506e612b51c276149fa0e12b0565d5
2021-10-23 17:17:53 -07:00
Priya Ramani
7b55dc8340 Use kernel_func_name from aotCompiler (#66337)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66337

Right now, assembly code generated for a given method from the model is named `wrapper` or `func` by default. The function name is then replaced with a proper kernel_func_name after target-specific assembly is generated.
This PR propagates the desired kernel_func_name from the aotCompiler API so that the generated function has the needed name from the start and doesn't need to be renamed later.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31514095

Pulled By: priyaramani

fbshipit-source-id: b70c8e2c733600a435cd4e8b32092d37b7bf7de5
2021-10-23 02:20:45 -07:00
BowenBao
1da628bdb7 [ONNX] Update slice process shape to support rank only inference (#65782) (#66149)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66149

Updated logic will be able to infer rank of slice output, when only rank is known for slice input. Enables cases where `ConstantValueMap::HasRank(input)` is `True`, while `ConstantValueMap::HasShape(input)` is `False`.

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D31423840

Pulled By: malfet

fbshipit-source-id: 17b2b24aa63435d5212ebe6bdf66ae3c348c4e3b

Co-authored-by: BowenBao <bowbao@microsoft.com>
2021-10-22 13:46:26 -07:00
Nikita Shulga
6f3f302d9f [ONNX] Deprecate fold_if pass (#65697) (#66145)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66145

Deprecate fold_if pass

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D31424097

fbshipit-source-id: 25b89679c756393a1065ca6aaa24d29db960cbd4

Co-authored-by: jiafatom <jiafa@microsoft.com>
2021-10-22 13:46:20 -07:00
Nikita Shulga
7a78f715a6 [ONNX] Add warning for inplace updates on tensor.shape in tracing mode (#63170) (#66142)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66142

* Add warning

* Lint and clang fixes

* Remove duplicate comments

* Added pitfalls section

* Modify sections

* Minor modifications

* Add underline to avoid doc build failures

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D31424092

fbshipit-source-id: c83195f3c66885ad1aecde13b3029c45dd171dbd
2021-10-22 13:46:14 -07:00
Nikita Shulga
53a163a015 [ONNX] Export nn.Module call as ONNX local function (#63589) (#66140)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66140

* Add new argument to export api to enable users specifying `nn.Module` classes that they wish to be exported as local function in ONNX model.
* Refactor `torch/csrc/jit/serialization/export.cpp`, and remove redundant `EncoderBase` class.
* ~~Contains changes from #63268~~
* Depends on #63716 to update onnx submodule.

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D31424098

fbshipit-source-id: c949d0b01c206c30b4182c2dd1a5b90e32b7a0d3

Co-authored-by: BowenBao <bowbao@microsoft.com>
2021-10-22 13:44:56 -07:00
Mike Iovine
f2582a59d0 [SR] Add rvalue overload for operator() (#66648)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66648

Currently, SR shallow-copies its `IValue` inputs when running inferences. We can avoid refcount bumps by `std::move`-ing the inputs into their slots. To achieve this, I've made the following changes:

1. Add an overload for `set_inputs` that takes a `std::vector<IValue>&&`.
2. Change the signatures of `StaticModule::operator()` and `StaticRuntime::operator()`.
Old:
```
operator()(const std::vector<IValue>& args, const std::unordered_map<std::string, IValue>& kwargs)
```
New:
```
template <class IValueList>
operator()(IValueList&& args, const std::unordered_map<std::string, IValue>& kwargs)
```

The implementations use perfect forwarding to invoke the correct overload of `set_inputs`.

Test Plan: Added a short new unit test to exercise the new code path. All other unit tests still pass.

Reviewed By: hlu1

Differential Revision: D31659973

fbshipit-source-id: b8c194405b54a5af1b418f8edaa1dd29a061deed
2021-10-22 10:51:47 -07:00
Aditya Pillai
40a8a50913 Add static_runtime::fused_equally_split (#2)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch-canary/pull/2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66881

Adds the `static_runtime::fused_equally_split` operator and removes the `is_fused` logic from the original operator. Modifies `FuseListUnpackV2` to map `fb::equally_split` to this new operator.

Test Plan:
```
adityapillai@5960 /data/sandcastle/boxes/fbsource/fbcode 1m 13s
❯ buck test //caffe2/benchmarks/static_runtime/fb:test_fb_operators
```
and sandcastle
strange_what_could_go_wrong

Reviewed By: mikeiovine

Differential Revision: D31742293

fbshipit-source-id: 60b35589c8817719b005d49811f575b6590d1c39
2021-10-22 10:26:49 -07:00
Mike Iovine
391eb1dbe3 [JIT] UseVariadicOp handles multiple lists (#66288)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66288

This change makes it so `UseVariadicOp` can transform ops with many Tensor list inputs.

Input pattern:
```
%output : Type = op(%list_1, %arg_1, %list_2, %list_3)
```
Output pattern:
```
%output : Type = variadic_op(%list_11, ..., %list_1N, %arg_1, %list_21, ..., %list_2M, %list_31, ..., %list_3K, N, M, K)
```
The length of each list is passed at the end of the variadic op so that the op implementation can process the inputs appropriately. This also frees us from needing to update `hasVarArgs` in static runtime each time we add a variadic op.

This diff also makes `UseVariadicOp` more robust. Before, `list_idx` was passed as an argument. Now, `VariadicUpdater` determines `list_idx` from the node's schema.

Test Plan:
Existing variadic ops do not break:
`buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest`

Reviewed By: d1jang

Differential Revision: D31450811

fbshipit-source-id: 808fcc3ae8940b9e602586f38f8cf9154c9a6462
2021-10-22 10:22:33 -07:00
Elias Ellison
6e6ede2e70 [JIT] Re-enable alias sensitive peepholes (#65860)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65860

Re-enable peepholes like `x + 0 == x`. These were at one point enabled, then disabled because they did not properly account for aliasing, and then re-enabled with the alias db reconstructed every time, which is slow (O(n^2)). I've added correctness conditions, and I've also made it so that we avoid using stale aliasing properties for either the input or output of nodes we optimize.
Some of the other code that we have written to avoid re-instantiating the alias db involves internally mutating it, however this is tricky to reason about and we probably have to add some extra invariants...

cc navahgar relevant to graph opts and d1jang alias analysis relevant here

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D31352382

Pulled By: eellison

fbshipit-source-id: 441a27f17dc623d6c24538d1d43cba0412c3c482
2021-10-22 09:45:57 -07:00
Don Jang
051ea5ccbf [Static Runtime] Bundle function & function_kind to carry them together (#66974)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66974

`D31591785 (67e003f09b)` started carrying a function object to be executed and `FunctionKind` for the type of the function *separately*, and this caused a bug fixed by D31783028 (79803b199f).

This change bundles them together, as swolchok did before, to reduce the chances of such a mistake in the future. They always need to be carried together, since `FunctionKind` identifies the type of the function object.

Note that `struct Function` is a POD type, so accessing its fields shouldn't cause any extra overhead in `ProcessedNode::run()`.

Test Plan:
Confirmed that the managed memory metrics remain the same before/after this diff on inline_cvr:

```
#AFTER
# inline_cvr/local
Total number of managed tensors: 2660
Total number of managed output tensors: 0
Total number of unmanaged values: 3041
Total memory managed: 1496896 bytes
Total number of reused tensors: 1183
Total number of 'out' variant nodes/total number of nodes: 2452/2469 (99.3115%)
# inline_cvr/local_ro
Total number of managed tensors: 1412
Total number of managed output tensors: 0
Total number of unmanaged values: 2679
Total memory managed: 39040 bytes
Total number of reused tensors: 959
Total number of 'out' variant nodes/total number of nodes: 1928/1939 (99.4327%)
# inline_cvr/remote_ro
First iter time: 12.0344 ms
Total number of managed tensors: 1293
Total number of managed output tensors: 0
Total number of unmanaged values: 14
Total memory managed: 5293824 bytes
Total number of reused tensors: 771
Total number of 'out' variant nodes/total number of nodes: 1298/1298 (100%)
```

```
#BEFORE
# inline_cvr/local
Total number of managed tensors: 2660
Total number of managed output tensors: 0
Total number of unmanaged values: 3041
Total memory managed: 1496896 bytes
Total number of reused tensors: 1183
Total number of 'out' variant nodes/total number of nodes: 2452/2469 (99.3115%)

# inline_cvr/local_ro
Total number of managed tensors: 1412
Total number of managed output tensors: 0
Total number of unmanaged values: 2679
Total memory managed: 39040 bytes
Total number of reused tensors: 959
Total number of 'out' variant nodes/total number of nodes: 1928/1939 (99.4327%)

# inline_cvr/remote_ro
Total number of managed tensors: 1293
Total number of managed output tensors: 0
Total number of unmanaged values: 14
Total memory managed: 5293824 bytes
Total number of reused tensors: 771
Total number of 'out' variant nodes/total number of nodes: 1298/1298 (100%)
```

Reviewed By: mikeiovine

Differential Revision: D31798419

fbshipit-source-id: fd4301b6731e402be0820729654735c791511aba
2021-10-22 08:57:49 -07:00
Chen Lai
5f58764d1d [PyTorch Edge][type] Add type support for NamedTuple custom class (import) (#63130)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63130

Extend `type_parser` to handle `NamedTuple` type. It can be extended to handle other types when needed. The custom type will follow the following format:
```
"qualified_name[
    NamedTuple, [
        [field_name_1, field_type_1],
        [field_name_2, field_type_2]
    ]
]"
```
For example:
```
"__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[
    NamedTuple, [
        [float_features, Tensor],
        [id_list_features, List[Tensor]],
        [label, Tensor],
        [weight, Tensor],
    ]
]"
```

For nested types, the order of type lists from type table should be:
```
std::string type_1 = "__torch__.C [
    NamedTuple, [
        [field_name_c_1, Tensor],
        [field_name_c_2, Tuple[Tensor, Tensor]],
    ]
]";

std::string type_2 = "__torch__.B [
    NamedTuple, [
        [field_name_b, __torch__.C ]
    ]
]";

std::string type_3 = "__torch__.A[
    NamedTuple, [
        [field_name_a, __torch__.B]
    ]
]";
std::vector<std::string> type_strs = {type_1, type_2, type_3};
std::vector<TypePtr> type_ptrs = c10::parseType(type_strs);
```

namedtuple from `collections` and NamedTuple from `typing` are both supported
```

from typing import NamedTuple
from collections import namedtuple
```

This change only adds the parser; the new runtime can now read the above format.
ghstack-source-id: 141293658

Test Plan:
```
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.CompatiblePrimitiveType'
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.CompatibleCustomType'

buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.InCompatiblePrimitiveType'
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.InCompatibleCustomType'
```

Reviewed By: iseeyuan

Differential Revision: D30261547

fbshipit-source-id: 68a9974338464e320b39a5c613dc048f6c5adeb5
2021-10-22 00:40:57 -07:00
Han Qi
fe102b9888 diff tool (#66854)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66854

diff tool and script to test correctness of flatbuffer format

Test Plan:
`./verify_flatbuffer.sh | pastry`
P463163180

Reviewed By: zhxchen17

Differential Revision: D31752696

fbshipit-source-id: bea00102b21e62c02367853c8bec2742b483fbda
2021-10-21 22:53:51 -07:00
Tal Ben-Nun
9d4549295d ONNX export: propagate node metadata across passes (#45256)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/45255

Mostly straightforward. The only downside in this PR is the lack of a more scalable way to check for all newly-created nodes in `callPySymbolicFunction`. The other options were:
* Create a scope within the node's scope and loop through all nodes that correspond to the scope. The code would still need to loop through all nodes.
* Add extra state to the graph (no good reason to do so).
* Add extra state to the ONNX exporter, since python calls go back to `g.op(...)` (no good reason to do so, also not very pythonic).

cc BowenBao neginraoof

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45256

Reviewed By: malfet, houseroad

Differential Revision: D31744281

Pulled By: msaroufim

fbshipit-source-id: 1b63f6e7f02ed61b3a9b7ac3d0be0a3a203c8ff6
2021-10-21 11:49:05 -07:00
Mike Iovine
53cf7e844f [SR] Fix bug in FuseListUnpackV2 (#67021)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67021

When applying the equally split optimization, we still need to delete the list unpack node.

I did an accuracy test yesterday but didn't catch this issue because my diffs were not properly synced between devservers (I use hlu1's devbig for testing and it had an old version of "Add FuseListUnpackV2"). But I did another test this morning and realized that there was an issue.

This is not affecting anything in prod right now since D31742293 has not landed.

Reviewed By: hlu1

Differential Revision: D31827278

fbshipit-source-id: c7b05e3d8ec942632adcff4bdfebb8c27c1a7a39
2021-10-21 11:08:04 -07:00
Bert Maher
bdb889aca1 [nnc] Use a descriptive name for fused kernels when profiling (#66990)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66990

NNC fusion groups currently show up as "TensorExpr" in the profiler,
which is true but not super useful since it obscures what's actually happening
in the fusion group.  This change will log them as `fused_XXX` where XXX is a
(length-limited) series of ops describing the subgraph, for instance
`fused_mul_add` to represent a group containing `aten::mul`, `aten::add`.

Test Plan: New unit test to check the output of autograd profiler.

Reviewed By: dzhulgakov

Differential Revision: D31762087

fbshipit-source-id: 3fadbdc67b054faa01aa42e5b6ea2c4a6bc3481f
2021-10-21 00:06:23 -07:00
Mike Iovine
ab1e4eac42 [Static Runtime] Add FuseListUnpackV2 (#66509)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66509

Like `FuseListUnpack`, but instead of adding arguments to the fused node's outputs, inserts a new fused op.

By using a new fused op, we can avoid runtime `is_fused` checks. This will make the op implementations significantly cleaner. Eventually, we will migrate all ops to `V2` and delete the old pass.

`FuseListUnpackV2` also fixes the bug described in T103159043.

Test Plan: I've made some changes to D31550307 locally and verified that everything works.

Reviewed By: hlu1

Differential Revision: D31492017

fbshipit-source-id: 4f90fcbc17e4c70a3d65985bee836fabf868a22c
2021-10-20 16:39:32 -07:00
Elias Ellison
17889ad26e Add support for cat in output stitching (#66098)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66098

`cat` is somewhat special-cased right now because currently we only have lists of Tensor inputs where the list is constructed in the JIT IR graph. While that is generally true for fusion (e.g. why we have ConstantChunk), it may not be true for shape analysis generally, so I'm waiting a bit to generalize.

Test Plan: Imported from OSS

Reviewed By: navahgar, anjali411

Differential Revision: D31797467

Pulled By: eellison

fbshipit-source-id: ca761e214dfd7f3bba8d189f3b3f42ffec064f63
2021-10-20 16:13:09 -07:00
Elias Ellison
2dd23ebfdb Add support for multi output nodes in partial eval graph stitching (#66097)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66097

Adding logic to generate runtime shapes for nodes with multiple outputs. It generalizes the existing flow of looking at a node, getting its shape graph, inlining it, and adding a mapping from the output to the new value in the stitched shape compute graph, to loop over multiple outputs.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31797468

Pulled By: eellison

fbshipit-source-id: 2c182b71a46b36d33f23ad35b89790a4a5d4471c
2021-10-20 16:13:07 -07:00
Elias Ellison
0196b984f3 Add Handling of Cat in Shape Analysis (#65575)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65575

This is needed for lowering an NNC model to mobile. It is also the last class of unhandled ops which NNC fuses, and we need to integrate this for computing output symbolic shapes.

The graph with two dynamic shape inputs produces:
```
graph(%x.1 : Tensor(SS(-2), 2, 3),
      %y.1 : Tensor(SS(-3), 2, 3)):
  %5 : int = prim::Constant[value=0]()
  %4 : Tensor[] = prim::ListConstruct(%x.1, %y.1)
  %6 : Tensor(SS(-4), 2, 3) = aten::cat(%4, %5) # /private/home/eellison/pytorch/test/jit/test_symbolic_shape_analysis.py:290:19
  return (%6)
```
With a partial eval graph of
```
Done with partial evaluation
graph(%129 : int[],
      %130 : int[],
      %dim.14 : int):
  %738 : int = prim::Constant[value=3]()
  %737 : int = prim::Constant[value=2]()
  %132 : int = prim::Constant[value=0]()
  %392 : int = aten::__getitem__(%129, %132) # <string>:339:44
  %417 : int = aten::__getitem__(%130, %132) # <string>:339:44
  %cat_dim_size.48 : int = aten::add(%392, %417) # <string>:339:29
  %result_size.5 : int[] = prim::ListConstruct(%cat_dim_size.48, %737, %738)
  return (%result_size.5)
```

To handle cat, I essentially make the cat shape op variadic,
replacing
```
torch.cat([x, y])
...
def cat_shape_op(tensors: List[List[int]], dim: int):
    ...
    op(tensors)
```
with
```
def cat_shape_op(x: List[int], y: List[int], dim: int):
    tensors = [x, y]
    op(tensors)
```
This reuses the existing input Tensor properties partial evaluation path and avoids having to add special handling to optimize out `len(tensors)` calls in the IR.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31797471

Pulled By: eellison

fbshipit-source-id: 62c794533d5fabfd3fad056d7e5fe3e8781b22c5
2021-10-20 16:13:05 -07:00
Elias Ellison
eaba976d49 Add x + 0 optimization (#65574)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65574

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31797470

Pulled By: eellison

fbshipit-source-id: bf9309fb43f164665335fed0d09697b0e2f67261
2021-10-20 16:13:03 -07:00
Elias Ellison
b059f035be Fix bug preventing optimization from firing (#65573)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65573

When we remove mutation on
```
x = [0, 1, 3, 4]
x[-2] = 4
```
we have a safety check that the normalized index will be in bounds of the list. In practice, this should always be the case; otherwise you would have a runtime error. Within that check (not within the actual adjustment), we were using the wrong length, which prevented the optimization from firing.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31797469

Pulled By: eellison

fbshipit-source-id: 02a1686b9f6016eb5aeb87ed342c043c203dcd0e
2021-10-20 16:13:01 -07:00
Elias Ellison
63b41e1f4d [JIT] Add partial evaluation graph stitching logic (#65377)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65377

When we run symbolic shape analysis on
```
conv = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
mod = nn.Sequential(conv1, max_pool)
...
graph(%self : __torch__.torch.nn.modules.container.___torch_mangle_0.Sequential,
      %input.1 : Tensor):
  %18 : bool = prim::Constant[value=0]()
  %30 : int[] = prim::Constant[value=[1, 1]]()
  %29 : int[] = prim::Constant[value=[3, 3]]()
  %28 : int[] = prim::Constant[value=[2, 2]]()
  %6 : int = prim::Constant[value=1]()
  %self.0.bias : NoneType = prim::Constant()
  %self.0.weight : Double(64, 3, 7, 7, strides=[147, 49, 7, 1], requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
  %input.5 : Tensor(SS(-2), 64, SS(-3), SS(-4)) = aten::conv2d(%input.1, %self.0.weight, %self.0.bias, %28, %29, %30, %6)
  %input.9 : Tensor(SS(-2), 64, SS(-5), SS(-6)) = aten::max_pool2d(%input.5, %29, %28, %30, %30, %18)
  return (%input.9)
```
we partially evaluate the shape compute graph of `conv2d`, whose output gets passed in and used to partially evaluate the shape compute graph of `max_pool2d`.

The remaining partially eval'd conv2d graph is [here](https://gist.github.com/eellison/0598bd224a422211efa1a45d2b7560b7), and the eval'd maxpool2d graph is [here](https://gist.github.com/eellison/625540b84f650ddbefd3ae5511ab8814). We can take the partially eval'd graphs of a series of operators and stitch them together, which allows us to
a) recover symbolic equivalences by CSE'ing & other optimizations
b) calculate shapes for a whole block of operators from just the input, such as when fusing the whole model to NNC with dynamic shapes and then passing along the computed symbolic shapes; the calculation also includes error handling
c) (future-looking) generate inputs on demand for straight-line networks composed just of aten operators

The combined graph of the two gives us compute for the unknown symbolic dimensions - `SS(-2), SS(-3), SS(-4), SS(-5), and SS(-6)`.
```
graph(%input.1 : int[]):
  %42 : bool = prim::Constant[value=0]() # <string>:152:17
  %15 : int = prim::Constant[value=3]()
  %input_batch_size_dim.1 : int = prim::Constant[value=0]() # <string>:417:41
  %13 : int = prim::Constant[value=1]() # <string>:426:61
  %12 : int = prim::Constant[value=4]() # <string>:437:32
  %11 : str = prim::Constant[value="AssertionError: "]()
  %9 : int = prim::Constant[value=2]()
  %8 : int = prim::Constant[value=6]()
  %7 : int = prim::Constant[value=7]()
  %16 : int = aten::len(%input.1) # <string>:438:17
  %17 : bool = aten::eq(%16, %12) # <string>:438:17
   = prim::If(%17) # <string>:438:10
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:438:10
      -> ()
  %18 : int = aten::__getitem__(%input.1, %13) # <string>:407:17
  %19 : bool = aten::eq(%18, %15) # <string>:407:17
   = prim::If(%19) # <string>:407:10
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:407:10
      -> ()
  %20 : int = aten::__getitem__(%input.1, %9) # <string>:411:20
  %21 : int = aten::add(%20, %8) # <string>:411:20
  %22 : bool = aten::ge(%21, %7) # <string>:411:20
   = prim::If(%22) # <string>:411:12
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:411:12
      -> ()
  %23 : int = aten::__getitem__(%input.1, %15) # <string>:411:20
  %24 : int = aten::add(%23, %8) # <string>:411:20
  %25 : bool = aten::ge(%24, %7) # <string>:411:20
   = prim::If(%25) # <string>:411:12
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:411:12
      -> ()
  %26 : int = aten::__getitem__(%input.1, %input_batch_size_dim.1) # <string>:422:29
  %27 : int = aten::sub(%20, %13) # <string>:428:32
  %28 : int = aten::floordiv(%27, %9) # <string>:428:32
  %29 : int = aten::add(%28, %13) # <string>:428:32
  %30 : int = aten::sub(%23, %13) # <string>:428:32
  %31 : int = aten::floordiv(%30, %9) # <string>:428:32
  %32 : int = aten::add(%31, %13) # <string>:428:32
  %48 : int = aten::floordiv(%28, %9) # <string>:133:17
  %outputSize.2 : int = aten::add(%48, %13) # <string>:136:23
  %51 : int = aten::floordiv(%31, %9) # <string>:133:17
  %outputSize.1 : int = aten::add(%51, %13) # <string>:136:23
  %53 : bool = aten::ne(%29, %input_batch_size_dim.1) # <string>:156:41
  %54 : bool = prim::If(%53) # <string>:157:64
    block0():
      %55 : bool = aten::ne(%32, %input_batch_size_dim.1) # <string>:157:93
      -> (%55)
    block1():
      -> (%42)
   = prim::If(%54) # <string>:157:10
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:157:10
      -> ()
  %56 : bool = aten::ge(%outputSize.1, %13) # <string>:160:17
  %57 : bool = prim::If(%56) # <string>:160:17
    block0():
      %58 : bool = aten::ge(%outputSize.2, %13) # <string>:160:38
      -> (%58)
    block1():
      -> (%42)
   = prim::If(%57) # <string>:160:10
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:160:10
      -> ()
  return (%26, %29, %32, %outputSize.2, %outputSize.1)
  ```

This PR runs shape analysis, retains the partially evaluated graphs, and then stitches them together, keeping track of what inputs in the partial eval graph correspond to what inputs in the encompassing graph IR and what outputs correspond to what symbolic shape. Adding NNC ppl as reviewers because it is relevant to dynamic shape fusion.

Question for reviewers: should I make this a separate file?

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31797472

Pulled By: eellison

fbshipit-source-id: a41ed31fad085d3563e71c815f49af0cd18aaeed
2021-10-20 16:12:58 -07:00
Elias Ellison
4ad6c144f6 [JIT][Easy] Shape cleanups (#65148)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65148

No functional changes, factoring out optimizations and renaming the `graph` in symbolic shape analysis to `shape_compute_graph` as ZolotukhinM suggested

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31797447

Pulled By: eellison

fbshipit-source-id: 60d322da040245dd7b47ee7c8996239572fd11c2
2021-10-20 16:11:24 -07:00
David Berard
e86d8323cb [JIT] Add special cases for batch_norm, instance_norm in alias_analysis (#66554)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66554

In native_functions.yaml, the schemas for batch_norm and instance_norm
are incorrect: the inputs `running_mean` and `running_var` are mutated,
but are not marked as such in the function schema. Since `(a!)?`
annotations are currently not working (see #65760), this instead adds a
special case to `alias_analysis.cpp`. If the value of `training` or
`use_input_stats` is known to be `false`, then `alias_analysis` will
mark the input as _not_ being written to.

Test Plan:
Removed the `skip` annotation on the following test, and added a special
exception in `check_alias_annotations`:
```
python test/test_ops.py -k test_variant_consistency_jit_nn_functional_batch_norm
```

Also:
```
./build/bin/test_jit --gtest_filter="*BatchAndInstanceNormFixture*"
```

Imported from OSS

Reviewed By: eellison

Differential Revision: D31612339

fbshipit-source-id: 12ca61b782b9e41e06883ba080a276209dc435bb
2021-10-20 10:22:10 -07:00
Hao Lu
79803b199f [Static Runtime] Make sure ProcessedNode::function_kind_ is copied over (#66917)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66917

The reported ratio of 'out' variant nodes to total nodes is now 100% for all the models, which obviously isn't true.

Reviewed By: swolchok, mikeiovine

Differential Revision: D31783028

fbshipit-source-id: e0bc2c6614aa3c3a235283c9125de1b339f42585
2021-10-20 00:21:35 -07:00
Michael Suo
ef15691a1e Revert D31732421: [JIT][Easy] Shape cleanups
Test Plan: revert-hammer

Differential Revision:
D31732421 (16d0896b69)

Original commit changeset: e934507d1795

fbshipit-source-id: 6b34815c556de64ee5c7ef8d41e4cb434ccd7098
2021-10-19 20:07:06 -07:00
Michael Suo
70c9eb130d Revert D31732419: [JIT] Add partial evaluation graph stitching logic
Test Plan: revert-hammer

Differential Revision:
D31732419 (5db7db667f)

Original commit changeset: 883a55cbeef0

fbshipit-source-id: f5faba69dfb6b54aeb29d1beaeec8c5b0373830f
2021-10-19 20:07:04 -07:00
Michael Suo
90b42452e2 Revert D31732417: Fix bug preventing optimization from firing
Test Plan: revert-hammer

Differential Revision:
D31732417 (853fc25fb0)

Original commit changeset: dd734254c021

fbshipit-source-id: 3da0663dac5b5d2117b3d7abdbcd45d96f98de33
2021-10-19 20:07:02 -07:00
Michael Suo
b8d58129bb Revert D31732420: Add x + 0 optimization
Test Plan: revert-hammer

Differential Revision:
D31732420 (66543f88de)

Original commit changeset: 0271e0dc0dda

fbshipit-source-id: c2beea1661e10c2f1a982b5d4a34b1041dcb1204
2021-10-19 20:07:00 -07:00
Michael Suo
e730752610 Revert D31732416: Add Handling of Cat in Shape Analysis
Test Plan: revert-hammer

Differential Revision:
D31732416 (cc7de1df3b)

Original commit changeset: 6d93ddf62c34

fbshipit-source-id: e2c9713177a7f783897e99dd71e631fb275c37da
2021-10-19 20:06:57 -07:00
Michael Suo
57fcea9e88 Revert D31732418: Add support for multi output nodes in partial eval graph stitching
Test Plan: revert-hammer

Differential Revision:
D31732418 (0fdc9b77a3)

Original commit changeset: 767698d031b1

fbshipit-source-id: f899eb155dcec67d57f53a658a71169d37b63b42
2021-10-19 20:06:55 -07:00
Michael Suo
4187d870df Revert D31732415: Add support for cat in output stitching
Test Plan: revert-hammer

Differential Revision:
D31732415 (b4db5174fe)

Original commit changeset: 7f513cea355f

fbshipit-source-id: a0d8f1512b13d51f6e50b5da58084effbaf0a0dc
2021-10-19 20:06:53 -07:00
Michael Suo
1bf0e1acb4 Revert D31732414: Add Initial NNC Dynamic Shapes Flow
Test Plan: revert-hammer

Differential Revision:
D31732414 (de4fe7a38c)

Original commit changeset: 290a94a667c2

fbshipit-source-id: 3021a1d7a8661967e37d4f9cfc86ed47cc4a7f3d
2021-10-19 20:05:29 -07:00
Elias Ellison
de4fe7a38c Add Initial NNC Dynamic Shapes Flow (#66136)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66136

FOR REVIEWERS: this is ready to review; the test failures come from somewhere else in the stack.

Takes in a TensorExprGraph of static shapes and generalizes the input shapes
to symbolic dimensions. Dimensions of value 1 will be preserved; otherwise,
dimensions with the same value will be bucketed to the same symbolic shape.

E.g. `Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)`
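
The bucketing rule above can be sketched in a few lines (hypothetical helper name; the real pass operates on JIT `TensorType`s, with symbolic dims shown as `SS(-n)`):

```python
def generalize_shapes(shapes):
    """Sketch of the generalization rule: dims of value 1 are preserved;
    equal dim values are bucketed to the same negative symbolic id, and
    each fresh value gets a fresh id."""
    sym = {}          # concrete dim value -> symbolic id
    next_id = [-1]    # symbolic ids count down: -1, -2, ...
    def to_sym(d):
        if d == 1:
            return 1  # dims of value 1 stay concrete
        if d not in sym:
            sym[d] = next_id[0]
            next_id[0] -= 1
        return sym[d]
    return [[to_sym(d) for d in s] for s in shapes]

# Matches the example: Tensor(5, 3), Tensor(3, 1)
#   -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
assert generalize_shapes([[5, 3], [3, 1]]) == [[-1, -2], [-2, 1]]
```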

From there, runs symbolic shape inference on the graph and creates a
versioning if in the graph, with prim::TensorExprDynamicGuard checking whether
the runtime inputs match the generalized symbolic shapes that are inputs
to the TE kernel. The computation of all symbolic dimensions is
inlined into the if block with the TE kernel. All symbolic dimension Value*s are
appended to the end of the TE kernel Graph/Node inputs, and the Node is
augmented with an integer list attr `symbolic_shape_inputs` that gives the
mapping from Value* -> symbolic shape int64_t value. For lengthier IR
examples and a walkthrough, look at ShapeAnalysisTest.DynamicShapesFusion in
`test_shape_analysis`. Returns True on success, False on failure; it can fail if
shape propagation fails to propagate the number of dims or if complete shapes
on the inputs are not set.

Example transformation
```
graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
  %3 : Tensor = prim::TensorExprGroup_0(%x_inp, %y_inp, %z_inp)
  return ()
with prim::TensorExprGroup_0 = graph(%x.1 : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %y.1 : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
  %3 : int = prim::Constant[value=0]()
  %4 : Tensor = aten::tanh(%x.1)
  %5 : Tensor = aten::erf(%4)
  %6 : Tensor = aten::relu(%y.1)
  %7 : Tensor[] = prim::ListConstruct(%5, %6)
  %8 : Tensor = aten::cat(%7, %3)
  %9 : Tensor = aten::hardswish(%8)
  %10 : Tensor = aten::mul(%9, %z)
  return (%9)
```
->

```
  graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
  %4 : bool = prim::TensorExprDynamicGuard[types=[Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)]](%x_inp, %y_inp, %z_inp)
  %5 : Tensor = prim::If(%4)
    block0():
      %15 : int[] = aten::size(%x_inp)
      %16 : int[] = aten::size(%y_inp)
      %17 : int = prim::Constant[value=1]()
      %18 : int = prim::Constant[value=0]()
      %elem.3 : int = aten::__getitem__(%15, %18) # <string>:40:10
      %elem.5 : int = aten::__getitem__(%15, %17) # <string>:40:10
      %elem.11 : int = aten::__getitem__(%16, %18) # <string>:40:10
      %cat_dim_size.48 : int = aten::add(%elem.3, %elem.11) # <string>:321:29
      %3 : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-5, -4, -3, -2]](%x_inp, %y_inp, %z_inp, %cat_dim_size.48, %elem.11, %elem.5, %elem.3)
      -> (%3)
    block1():
      %14 : Tensor = prim::FallbackGraph_1(%x_inp, %y_inp, %z_inp)
      -> (%14)
  return ()
  with prim::TensorExprGroup_0 = graph(%x.1 : Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
        %y.1 : Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
        %z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
        %SS_5 : int,
        %SS_4 : int,
        %SS_3 : int,
        %SS_2 : int):
    %3 : int = prim::Constant[value=0]()
    %4 : Tensor(SS(-2), SS(-3)) = aten::tanh(%x.1)
    %5 : Tensor(SS(-2), SS(-3)) = aten::erf(%4)
    %6 : Tensor(SS(-4), SS(-3)) = aten::relu(%y.1)
    %7 : Tensor[] = prim::ListConstruct(%5, %6)
    %8 : Tensor(SS(-5), SS(-3)) = aten::cat(%7, %3)
    %9 : Tensor(SS(-5), SS(-3)) = aten::hardswish(%8)
    %10 : Tensor(SS(-5), SS(-3)) = aten::mul(%9, %z)
    return (%9)
```

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31732414

Pulled By: eellison

fbshipit-source-id: 290a94a667c20467717202a43c60e4f9ca4c00e2
2021-10-19 16:41:49 -07:00
Elias Ellison
b4db5174fe Add support for cat in output stitching (#66098)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66098

`cat` is somewhat special-cased right now because currently we only have lists of Tensor inputs where the list is constructed in the JIT IR graph. While that is generally true for fusion (e.g. why we have ConstantChunk), it may not be true for shape analysis generally, so I'm waiting a bit to generalize.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31732415

Pulled By: eellison

fbshipit-source-id: 7f513cea355f1e4c1d2ca7c32c06690a9bdcb050
2021-10-19 16:41:44 -07:00
Elias Ellison
0fdc9b77a3 Add support for multi output nodes in partial eval graph stitching (#66097)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66097

Adds logic to generate runtime shapes for nodes with multiple outputs. It generalizes the existing flow of looking at a node, getting its shape graph, inlining it, and adding a mapping from the output to the new value in the stitched shape compute graph, so that it loops over multiple outputs.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31732418

Pulled By: eellison

fbshipit-source-id: 767698d031b1daf002678a025b270e0ede429061
2021-10-19 16:41:39 -07:00
Elias Ellison
cc7de1df3b Add Handling of Cat in Shape Analysis (#65575)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65575

This is needed for lowering an NNC model to mobile. It is also the last class of unhandled ops which NNC fuses, and we need to integrate this for computing output symbolic shapes.

The graph with two dynamic shape inputs produces:
```
graph(%x.1 : Tensor(SS(-2), 2, 3),
      %y.1 : Tensor(SS(-3), 2, 3)):
  %5 : int = prim::Constant[value=0]()
  %4 : Tensor[] = prim::ListConstruct(%x.1, %y.1)
  %6 : Tensor(SS(-4), 2, 3) = aten::cat(%4, %5) # /private/home/eellison/pytorch/test/jit/test_symbolic_shape_analysis.py:290:19
  return (%6)
```
With a partial eval graph of
```
Done with partial evaluation
graph(%129 : int[],
      %130 : int[],
      %dim.14 : int):
  %738 : int = prim::Constant[value=3]()
  %737 : int = prim::Constant[value=2]()
  %132 : int = prim::Constant[value=0]()
  %392 : int = aten::__getitem__(%129, %132) # <string>:339:44
  %417 : int = aten::__getitem__(%130, %132) # <string>:339:44
  %cat_dim_size.48 : int = aten::add(%392, %417) # <string>:339:29
  %result_size.5 : int[] = prim::ListConstruct(%cat_dim_size.48, %737, %738)
  return (%result_size.5)
```

To handle cat, I essentially make the cat shape op variadic,
replacing
```
torch.cat([x, y]
...
def cat_shape_op(tensors: List[List[int]], dim: int):
    ...
    op(tensors)
```
with
```
def cat_shape_op(x: List[int], y: List[int], dim: int):
    tensors = [x, y]
    op(tensors)
```
This reuses the existing input Tensor properties partial evaluation path and avoids having to add special handling to optimize out `len(tensors)` calls in the IR.
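
The two forms of the shape function can be contrasted with a minimal sketch (illustrative names, not the actual shape-registry entries; only the `dim`-axis sum is shown):

```python
from typing import List

def cat_shape_list(tensors: List[List[int]], dim: int) -> List[int]:
    """List-taking form: receives all input shapes as one list argument."""
    out = list(tensors[0])
    out[dim] = sum(t[dim] for t in tensors)  # concat axis sizes add up
    return out

def cat_shape_variadic(x: List[int], y: List[int], dim: int) -> List[int]:
    """Variadic form: one argument per input tensor. The list is rebuilt
    internally, so each input shape flows through the existing per-tensor
    partial evaluation path, and no `len(tensors)` call survives in the IR."""
    tensors = [x, y]
    return cat_shape_list(tensors, dim)

# cat([5x2x3, 4x2x3], dim=0) -> 9x2x3
assert cat_shape_variadic([5, 2, 3], [4, 2, 3], 0) == [9, 2, 3]
```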

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31732416

Pulled By: eellison

fbshipit-source-id: 6d93ddf62c34846ec238159f75229632515530b7
2021-10-19 16:41:34 -07:00
Elias Ellison
66543f88de Add x + 0 optimization (#65574)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65574

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31732420

Pulled By: eellison

fbshipit-source-id: 0271e0dc0ddab06220048ed5bf4236fc85f3318c
2021-10-19 16:41:29 -07:00
Elias Ellison
853fc25fb0 Fix bug preventing optimization from firing (#65573)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65573

When we remove mutation on
```
x = [0, 1, 3, 4]
x[-2] = 4
```
we have a safety check that the new index will be in bounds of the old list; in practice, this should always be the case, otherwise you would have a runtime error. Within that check (not within the actual adjustment) we were using the wrong list length, preventing the optimization from firing.
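
The index normalization plus bounds check can be sketched as (hypothetical helper, not the actual remove-mutation pass code):

```python
def normalized_index(idx: int, list_len: int):
    """Sketch of the safety check: a negative index is adjusted by the
    length of the mutated list, and must land in bounds for the
    optimization to fire. Returns None when out of bounds."""
    norm = idx if idx >= 0 else idx + list_len
    return norm if 0 <= norm < list_len else None

# x = [0, 1, 3, 4]; x[-2] = 4  ->  the index normalizes to 2, in bounds,
# so the mutation can be rewritten away. Using the wrong length here
# would make this check fail spuriously.
assert normalized_index(-2, 4) == 2
```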

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31732417

Pulled By: eellison

fbshipit-source-id: dd734254c0212ca459c1c135da262974de5299be
2021-10-19 16:41:24 -07:00
Elias Ellison
5db7db667f [JIT] Add partial evaluation graph stitching logic (#65377)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65377

When we run symbolic shape analysis on
```
conv = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
mod = nn.Sequential(conv1, max_pool)
...
graph(%self : __torch__.torch.nn.modules.container.___torch_mangle_0.Sequential,
      %input.1 : Tensor):
  %18 : bool = prim::Constant[value=0]()
  %30 : int[] = prim::Constant[value=[1, 1]]()
  %29 : int[] = prim::Constant[value=[3, 3]]()
  %28 : int[] = prim::Constant[value=[2, 2]]()
  %6 : int = prim::Constant[value=1]()
  %self.0.bias : NoneType = prim::Constant()
  %self.0.weight : Double(64, 3, 7, 7, strides=[147, 49, 7, 1], requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
  %input.5 : Tensor(SS(-2), 64, SS(-3), SS(-4)) = aten::conv2d(%input.1, %self.0.weight, %self.0.bias, %28, %29, %30, %6)
  %input.9 : Tensor(SS(-2), 64, SS(-5), SS(-6)) = aten::max_pool2d(%input.5, %29, %28, %30, %30, %18)
  return (%input.9)
```
we partially evaluate the shape compute graph of `conv2d`, whose output gets passed in and used to partially evaluate the shape compute graph of `max_pool2d`.

The conv2d remaining partially eval'd graph is [here](https://gist.github.com/eellison/0598bd224a422211efa1a45d2b7560b7), and the maxpool2d eval'd graph is [here](https://gist.github.com/eellison/625540b84f650ddbefd3ae5511ab8814). We can take the partially eval'd graphs of a series of operators and stitch them together, which allows us to
a) recover symbolic equivalences by CSE'ing & other optimizations
b) calculate shapes for a whole block of operators just from the input, such as for fusing the whole model to nnc with dynamic shapes and then passing along the computed symbolic shapes. The calculation also includes error handling.
c) (future-looking) generate inputs on demand for straight-line networks that are composed just of aten operators

The combined graph of the two gives us compute for the unknown symbolic dimensions - `SS(-2), SS(-3), SS(-4), SS(-5), and SS(-6)`.
```
graph(%input.1 : int[]):
  %42 : bool = prim::Constant[value=0]() # <string>:152:17
  %15 : int = prim::Constant[value=3]()
  %input_batch_size_dim.1 : int = prim::Constant[value=0]() # <string>:417:41
  %13 : int = prim::Constant[value=1]() # <string>:426:61
  %12 : int = prim::Constant[value=4]() # <string>:437:32
  %11 : str = prim::Constant[value="AssertionError: "]()
  %9 : int = prim::Constant[value=2]()
  %8 : int = prim::Constant[value=6]()
  %7 : int = prim::Constant[value=7]()
  %16 : int = aten::len(%input.1) # <string>:438:17
  %17 : bool = aten::eq(%16, %12) # <string>:438:17
   = prim::If(%17) # <string>:438:10
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:438:10
      -> ()
  %18 : int = aten::__getitem__(%input.1, %13) # <string>:407:17
  %19 : bool = aten::eq(%18, %15) # <string>:407:17
   = prim::If(%19) # <string>:407:10
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:407:10
      -> ()
  %20 : int = aten::__getitem__(%input.1, %9) # <string>:411:20
  %21 : int = aten::add(%20, %8) # <string>:411:20
  %22 : bool = aten::ge(%21, %7) # <string>:411:20
   = prim::If(%22) # <string>:411:12
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:411:12
      -> ()
  %23 : int = aten::__getitem__(%input.1, %15) # <string>:411:20
  %24 : int = aten::add(%23, %8) # <string>:411:20
  %25 : bool = aten::ge(%24, %7) # <string>:411:20
   = prim::If(%25) # <string>:411:12
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:411:12
      -> ()
  %26 : int = aten::__getitem__(%input.1, %input_batch_size_dim.1) # <string>:422:29
  %27 : int = aten::sub(%20, %13) # <string>:428:32
  %28 : int = aten::floordiv(%27, %9) # <string>:428:32
  %29 : int = aten::add(%28, %13) # <string>:428:32
  %30 : int = aten::sub(%23, %13) # <string>:428:32
  %31 : int = aten::floordiv(%30, %9) # <string>:428:32
  %32 : int = aten::add(%31, %13) # <string>:428:32
  %48 : int = aten::floordiv(%28, %9) # <string>:133:17
  %outputSize.2 : int = aten::add(%48, %13) # <string>:136:23
  %51 : int = aten::floordiv(%31, %9) # <string>:133:17
  %outputSize.1 : int = aten::add(%51, %13) # <string>:136:23
  %53 : bool = aten::ne(%29, %input_batch_size_dim.1) # <string>:156:41
  %54 : bool = prim::If(%53) # <string>:157:64
    block0():
      %55 : bool = aten::ne(%32, %input_batch_size_dim.1) # <string>:157:93
      -> (%55)
    block1():
      -> (%42)
   = prim::If(%54) # <string>:157:10
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:157:10
      -> ()
  %56 : bool = aten::ge(%outputSize.1, %13) # <string>:160:17
  %57 : bool = prim::If(%56) # <string>:160:17
    block0():
      %58 : bool = aten::ge(%outputSize.2, %13) # <string>:160:38
      -> (%58)
    block1():
      -> (%42)
   = prim::If(%57) # <string>:160:10
    block0():
      -> ()
    block1():
       = prim::RaiseException(%11) # <string>:160:10
      -> ()
  return (%26, %29, %32, %outputSize.2, %outputSize.1)
  ```

This PR runs shape analysis, retains the partially evaluated graphs, and then stitches them together, keeping track of what inputs in the partial eval graph correspond to what inputs in the encompassing graph IR and what outputs correspond to what symbolic shape. Adding NNC ppl as reviewers because it is relevant to dynamic shape fusion.

Question for reviewers: should I make this a separate file?

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31732419

Pulled By: eellison

fbshipit-source-id: 883a55cbeef0fd5a6068a779ffa89b6f537245b3
2021-10-19 16:41:19 -07:00