Commit Graph

374 Commits

Author SHA1 Message Date
Yukio Siraichi
ffb526a2e4 Value range refinement using uni-variate expressions. (#97963)
This PR introduces value range refinement of shape symbols by symbolically evaluating the
value range of the involved guards. This should help `_maybe_evaluate_static` to eliminate
more guards.

This is a stack of PRs created from the discussion on: #96616.

In summary, this PR:
- simplifies `FloorDiv` nodes on the left-hand side of an expression so as to isolate a
symbol in the numerator
- tries to match the expression against the form: `<symbol> <relop> <expr>`
- uses the matched expression for refining the value range of `<symbol>` using the range
of `<expr>`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97963
Approved by: https://github.com/ezyang
2023-06-30 01:32:22 +00:00
Yukio Siraichi
e311bed2a8 Turn translation validation on for tests and accuracy runs by default. (#103611)
This PR turns translation validation on by default for tests and accuracy benchmark
runs. It also installs Z3 on CI.

The main changes are:

- Add `--no-translation-validation` as an option in _test/run_tests.py_
    - Set `PYTORCH_TEST_WITH_TV` environment variable
- Add `TEST_WITH_TV` variable in _torch/testing/_internal/common_utils.py_
- Turn translation validation on for accuracy benchmarks in _benchmarks/dynamo/common.py_
- Add Z3 installation on CI scripts

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103611
Approved by: https://github.com/ezyang
2023-06-30 01:32:21 +00:00
Brian Hirsh
875f60399e pre_dispatch tracing: support autocast and no_grad/enable_grad ctx managers, add a pre_dispatch_eager dynamo backend (#103024)
This PR adds support for `enable_grad`/`no_grad`/`autocast` context managers getting properly traced in `pre_dispatch` tracing. The stuff in this PR includes:
- I added a torch function mode that runs during make_fx pre_dispatch tracing, `ProxyTorchFunctionMode`. It directly intercepts the torch ops that run during the above context managers, and adds them to the current graph instead of executing them
- `enable_grad` and `no_grad` currently desugar into `torch._C.set_grad_enabled(bool)`, but this API isn't currently overrideable by torch function so I added the ability to interpose there
- the `torch.amp` context managers don't currently have a nice equivalent, like `set_autocast_enabled(state)`, so I ended up adding two new API's: `torch.amp._set_autocast_enabled` and `torch.amp._set_autocast_disabled`. If you look at how the context manager is implemented, it ends up calling several different state-changing functions, some of which depend on the backend - so I figured that it would be cleaner just to add a new API (that should probably only be used by tracing) - but open to feedback
- I added a new dynamo backend, `compile(backend="pre_dispatch_eager")`. When pre_dispatch tracing becomes always-on in inductor, it will be another potential surface for bugs. I also added a test file for it (`test/dynamo/test_pre_dispatch.py`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103024
Approved by: https://github.com/ezyang
2023-06-29 14:17:42 +00:00
Nikita Karetnikov
e9705c52ac [pt2] add metas for _pdist_forward and _pdist_backward (#103817)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103817
Approved by: https://github.com/ezyang
2023-06-22 11:18:05 +00:00
Nikita Karetnikov
e48851033a [pt2] add metas for pad ops (#103815)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103815
Approved by: https://github.com/ezyang
2023-06-22 11:18:05 +00:00
Brian Hirsh
c3c03e7cb8 Reland of https://github.com/pytorch/pytorch/pull/101818 (#103888)
Original PR broke internal

This reverts commit 5ed618132f.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103888
Approved by: https://github.com/albanD
2023-06-21 21:00:56 +00:00
Peter Bell
8b418f197c [decomp] Add decomposition for torch.renorm (#103858)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103858
Approved by: https://github.com/ezyang, https://github.com/nkaretnikov
2023-06-21 20:57:43 +00:00
Peter Bell
a61096fb94 [decomp] Decompose logaddexp2 (#103765)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103765
Approved by: https://github.com/Chillee
2023-06-21 20:16:24 +00:00
PyTorch MergeBot
7b6dc72ffa Revert "[decomp] Decompose logaddexp2 (#103765)"
This reverts commit bab21d20eb.

Reverted https://github.com/pytorch/pytorch/pull/103765 on behalf of https://github.com/ezyang due to looks like land race ([comment](https://github.com/pytorch/pytorch/pull/103765#issuecomment-1599030496))
2023-06-20 15:35:02 +00:00
Peter Bell
bab21d20eb [decomp] Decompose logaddexp2 (#103765)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103765
Approved by: https://github.com/Chillee
2023-06-20 09:24:21 +00:00
Richard Zou
27a67d8699 Refactor and improve make_fx testing (#103196)
This is in preparation for the custom_op_compile_check utility, which
will call the newly refactored function.

This PR:
- splits off code into helper functions
- adds clearer error messages
- stops updating the inputs destructively (leading to slightly slower
tests)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103196
Approved by: https://github.com/bdhirsh, https://github.com/soulitzer
2023-06-14 14:00:12 +00:00
Nikita Karetnikov
d38b651d51 [pt2] add SymInt support for cosine_similarity (#103400)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103400
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2023-06-13 21:23:48 +00:00
Nikita Karetnikov
c07634436e [pt2] add SymInt support for bilinear (#103396)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103396
Approved by: https://github.com/ezyang
2023-06-13 21:23:48 +00:00
Nikita Karetnikov
4a76fb49f3 [pt2] add metas for avg_pool3d and avg_pool3d_backward (#103392)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103392
Approved by: https://github.com/ezyang
2023-06-13 21:23:46 +00:00
PyTorch MergeBot
5ed618132f Revert "change pre_autograd to pre_dispatch tracing (#101818)"
This reverts commit b0392de2c3.

Reverted https://github.com/pytorch/pytorch/pull/101818 on behalf of https://github.com/izaitsevfb due to Breaks internal builds see D46629736 TypeError: wrap_key() got an unexpected keyword argument pre_autograd ([comment](https://github.com/pytorch/pytorch/pull/101818#issuecomment-1587837667))
2023-06-12 18:16:37 +00:00
Nikita Karetnikov
2b3d955ffd [pt2] add meta and SymInt support for linalg_matrix_exp (#102945)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102945
Approved by: https://github.com/lezcano
2023-06-09 22:45:16 +00:00
Brian Hirsh
b0392de2c3 change pre_autograd to pre_dispatch tracing (#101818)
We discussed in a composability meeting a few weeks ago that `pre_autograd` should probably be renamed to `pre_dispatch`.

One question in this PR was: should I re-use a dispatch key? Or should I create a new dispatch key (that yet again corresponds to "top of the dispatcher")?

~~For now, I ended up sticking our proxy mode on the mode stack corresponding to `PythonTLSSnapshot`, because it was simple and it works. It looks like one of the functorch dispatch keys has higher priority though, so it's possible that functorch will end up running first. Open to options, but we can consider adding a new dispatch key later if that becomes a problem~~

Update: I added a dedicated dispatch key, `PreDispatch`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101818
Approved by: https://github.com/ezyang, https://github.com/Neilblaze, https://github.com/albanD, https://github.com/zou3519
2023-06-09 17:30:15 +00:00
Nikita Karetnikov
1fcc67fd8c [pt2] add SymInt support for linalg.tensorsolve (#102466)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102466
Approved by: https://github.com/Skylion007, https://github.com/lezcano
2023-06-06 08:06:55 +00:00
Nikita Karetnikov
ec0aa965da [pt2] add meta for _linalg_solve_ex (#102454)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102454
Approved by: https://github.com/lezcano
2023-06-06 08:06:55 +00:00
Nikita Karetnikov
4bda4a7e4d [pt2] add meta for lu_unpack (#102937)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102937
Approved by: https://github.com/lezcano
2023-06-06 08:06:53 +00:00
Nikita Karetnikov
6ac3352a37 [pt2] add meta for _linalg_slogdet (#102464)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102464
Approved by: https://github.com/ezyang
2023-06-05 03:17:08 +00:00
Nikita Karetnikov
757791d1e3 [pt2] add SymInt support for linalg.vander (#102469)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102469
Approved by: https://github.com/Skylion007, https://github.com/lezcano
2023-06-04 09:58:02 +00:00
Edward Z. Yang
8bbef821c3 Add some unit tests from cm3leon involving repeat_interleave (#102733)
These actually were fixed by https://github.com/pytorch/pytorch/pull/102570
but that PR doesn't test guard-freeness, so here you go.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102733
Approved by: https://github.com/zou3519
2023-06-02 15:35:35 +00:00
PyTorch MergeBot
95cdd58c8f Revert "[pt2] add SymInt support for linalg.tensorsolve (#102466)"
This reverts commit b1b76f614d.

Reverted https://github.com/pytorch/pytorch/pull/102466 on behalf of https://github.com/clee2000 due to reverting b/c stack https://github.com/pytorch/pytorch/pull/102469#issuecomment-1569041604, i think this is the one that actually causes the test to fail ([comment](https://github.com/pytorch/pytorch/pull/102466#issuecomment-1569045123))
2023-05-30 20:26:46 +00:00
PyTorch MergeBot
463df86ce8 Revert "[pt2] add SymInt support for linalg.vander (#102469)"
This reverts commit 05717895aa.

Reverted https://github.com/pytorch/pytorch/pull/102469 on behalf of https://github.com/clee2000 due to broke test_aotdispatch on linux ex 05717895aa https://github.com/pytorch/pytorch/actions/runs/5125654882/jobs/9219389448, shows up as green on pr due to bug with keep-going flag and reruns ([comment](https://github.com/pytorch/pytorch/pull/102469#issuecomment-1569041604))
2023-05-30 20:24:26 +00:00
Nikita Karetnikov
05717895aa [pt2] add SymInt support for linalg.vander (#102469)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102469
Approved by: https://github.com/Skylion007, https://github.com/lezcano
2023-05-30 19:50:16 +00:00
Nikita Karetnikov
b1b76f614d [pt2] add SymInt support for linalg.tensorsolve (#102466)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102466
Approved by: https://github.com/Skylion007, https://github.com/lezcano
2023-05-30 19:50:15 +00:00
Nikita Karetnikov
0ba81ce8fe [pt2] add SymInt support for linalg.tensorinv (#102465)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102465
Approved by: https://github.com/Skylion007, https://github.com/lezcano
2023-05-30 19:50:14 +00:00
Nikita Karetnikov
995ac703cd [pt2] add SymInt support for linalg.pinv (#102367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102367
Approved by: https://github.com/lezcano
2023-05-27 11:10:47 +00:00
PyTorch MergeBot
da3aba1e46 Revert "[pt2] add SymInt support for linalg.pinv (#102367)"
This reverts commit 0d5b74da0c.

Reverted https://github.com/pytorch/pytorch/pull/102367 on behalf of https://github.com/kit1980 due to Broke slow tests https://github.com/pytorch/pytorch/actions/runs/5095190248/jobs/9160028124 ([comment](https://github.com/pytorch/pytorch/pull/102367#issuecomment-1565104562))
2023-05-27 00:33:42 +00:00
Nikita Karetnikov
0d5b74da0c [pt2] add SymInt support for linalg.pinv (#102367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102367
Approved by: https://github.com/lezcano
2023-05-26 15:20:34 +00:00
vfdev-5
e3d97b6213 [inductor] Added smooth_l1_loss refs (#102077)
Added `smooth_l1_loss` to refs + tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102077
Approved by: https://github.com/lezcano, https://github.com/ngimel
2023-05-24 15:07:08 +00:00
Edward Z. Yang
3318a832b3 Tighten FakeTensor reentrancy asserts, add debugging (#102091)
When investigating failures in https://github.com/pytorch/pytorch/pull/100017 I realized that we were reentering FakeTensorMode even though there was already one on the stack. Although we have attempted assert for these cases in the past, e.g., as in https://github.com/pytorch/pytorch/pull/97186 it seems that the existing protections were insufficient.

In this particular case, the reapplication of FakeTensorMode was due to an interaction with NotImplemented multiple dispatch handling. If proxy tensor mode detects an unrecognized tensor type (this includes FakeTensor, if it is not tracked with a proxy), it will return NotImplemented to give this tensor a chance to unpack itself into proxyable operation. However, this is never the right thing for FakeTensor, where no unpacking is possible. However, today, FakeTensor attempts to reapply the FakeTensorMode, resulting in FakeTensorMode being twice on the stack.

This PR does a number of things:

* It adds an assert in `FakeTensorMode.__torch_dispatch__` that you must not already have this mode on the stack, this is ALWAYS an error
* It modifies `FakeTensor.__torch_dispatch__` to return `NotImplemented` if the mode is already active. This prevents us from readding the mode on the stack
* It adds a new logging artifact `not_implemented` which you can use to get debug logs about all of the times a `__torch_dispatch__` handler returned NotImplemented and why it did so. Your subclass has to manually opt into this logging, but I inserted the necessary logs for ProxyTensorMode and FakeTensor(Mode)
* `with fake_mode` now no-ops if the fake mode is already on the stack, which is what users want anyway
* I am BREAKING pre-autograd tracing, because it is currently doing something weird with the original C++ mode stack. Brian is going to follow up with a fix next week.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102091
Approved by: https://github.com/thiagocrepaldi, https://github.com/eellison, https://github.com/wanchaol, https://github.com/bdhirsh
2023-05-24 05:37:51 +00:00
Nikita Karetnikov
e79d9b9938 [pt2] add SymInt support for linalg.matrix_power (#101940)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101940
Approved by: https://github.com/lezcano, https://github.com/ezyang
2023-05-24 00:21:52 +00:00
Nikita Karetnikov
42b974e8f7 [pt2] add meta for linalg_lu_solve (#101836)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101836
Approved by: https://github.com/lezcano
2023-05-24 00:21:50 +00:00
Khushi
1aaf0396eb [reland][opinfo] empty_strided (#101782)
Follows #100223

Previous PR: #100890

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101782
Approved by: https://github.com/ezyang
2023-05-19 03:06:29 +00:00
drisspg
6f13d6892a Add meta support for multinomial (#101324)
# Summary
Found this when trying to compile the text gen loop of nanogpt here: b33289942b/torchbenchmark/models/nanogpt_generate/model.py (L322)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101324
Approved by: https://github.com/ngimel
2023-05-19 00:04:26 +00:00
Angela Yi
72a73ef67b Add aten.searchsorted.Tensor meta kernel (#101637)
Test Plan: CI

Differential Revision: D45933187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101637
Approved by: https://github.com/ezyang
2023-05-18 06:55:11 +00:00
PyTorch MergeBot
dfac4364c4 Revert "[opinfo] empty_strided (#100890)"
This reverts commit 01c7106580.

Reverted https://github.com/pytorch/pytorch/pull/100890 on behalf of https://github.com/PaliC due to broke test_ops.py slow test ([comment](https://github.com/pytorch/pytorch/pull/100890#issuecomment-1551903975))
2023-05-17 19:00:15 +00:00
ydwu4
326a4cc815 Support map autograd and pytree in/out. (#101633)
Rebased https://github.com/pytorch/pytorch/pull/100494 and added dummy AOTConfig.

This PR adds autograd and pytree support for map operator.

Implementation-wise:

1. We temporarily make two HigherOrderOperators, "map" and "map_impl":
- "map" is user-facing. Currently, it unwraps the pytrees in inputs and create a flat_fn for it. Dynamo currently cannot deal with pytree.tree_flatten and pytree.tree_unflatten, we therefore make it a HigherOrderOperator to trigger dynamo logic of handling HigherOrderOperators.
- "map_impl" is the actual operator that works with the rest of torch subsystems such as functionalization, make_fx. It accepts flattend arguments, and a num_mapped_args integer denoting how many of the flattend arguments need to mapped i.e. their first dimension will be unstacked.

2. We create the forward and backward graph in autograd key and call torch.autograd.Function. Currently, the backward graph is recomputation-based and we need to partition the joint graph in the future to be more efficient.

Example traced graphs for map operators:
### Case 1: simple f and autograd
```python
def f(x, y):
    return x + y

def g(xs, y):
    out = control_flow.map(f, xs, y)
    return torch.autograd.grad(out, (xs, y), torch.ones_like(out))

gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True))
# gm.print_readable() produces following:
class g(torch.nn.Module):
    def forward(self, xs_1: f32[3, s1, s2], y_1: f32[s2]):
        # No stacktrace found for following nodes
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.map_impl(body_graph_0, 1, xs_1, y_1);  body_graph_0 = None
        getitem: f32[3, s1, s2] = map_impl[0];  map_impl = None
        ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False)
        is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like);  getitem = None
        body_graph_1 = self.body_graph_1
        map_impl_1 = torch.ops.map_impl(body_graph_1, 2, xs_1, ones_like, y_1);  body_graph_1 = xs_1 = ones_like = None
        getitem_1 = map_impl_1[0]
        getitem_2: f32[3, s1, s2] = map_impl_1[1]
        getitem_3: f32[3, s2] = map_impl_1[2];  map_impl_1 = None
        sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_3, [0], True);  getitem_3 = None
        sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0);  y_1 = None
        view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
        return (getitem_2, view)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s2]):
            # No stacktrace found for following nodes
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg2_1);  arg1_1 = arg2_1 = None
            return [add]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]):
            # No stacktrace found for following nodes
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg3_1);  arg1_1 = None
            is_same_size = torch.ops.aten.is_same_size.default(add, arg2_1);  add = None
            sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg2_1, [0], True)
            sym_size: Sym(s2) = torch.ops.aten.sym_size(arg3_1, 0);  arg3_1 = None
            view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
            return [None, arg2_1, view]
```
### Case 2: list input/output f and autograd
```python
def f(x, y):
    return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]

def g(xs, y):
    out = control_flow.map(f, xs, y)
    flat_out, _ = pytree.tree_flatten(out)
    flat_inp, _ = pytree.tree_flatten((xs, y))
    requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
    return torch.autograd.grad(flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out])

gm = make_fx(g, tracing_mode="symbolic")(
    [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
    torch.ones(5, requires_grad=True))

# gm.print_readable() produces following:
class g(torch.nn.Module):
    def forward(self, xs, y):
        xs_1: f32[3, s1, s2], xs_2: f32[3, s1, s2], y_1: f32[s2], = fx_pytree.tree_flatten_spec([xs, y], self._in_spec)
        # No stacktrace found for following nodes
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.map_impl(body_graph_0, 2, xs_1, xs_2, y_1);  body_graph_0 = None
        getitem: f32[3, s1, s2] = map_impl[0]
        getitem_1: f32[3, s1, s2] = map_impl[1];  map_impl = None
        ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False)
        ones_like_1: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem_1, pin_memory = False)
        is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like);  getitem = None
        is_same_size_1 = torch.ops.aten.is_same_size.default(getitem_1, ones_like_1);  getitem_1 = None
        body_graph_1 = self.body_graph_1
        map_impl_1 = torch.ops.map_impl(body_graph_1, 4, xs_1, xs_2, ones_like, ones_like_1, y_1);  body_graph_1 = xs_1 = xs_2 = ones_like = ones_like_1 = None
        getitem_2 = map_impl_1[0]
        getitem_3 = map_impl_1[1]
        getitem_4: f32[3, s1, s2] = map_impl_1[2]
        getitem_5: f32[3, s2] = map_impl_1[3];  map_impl_1 = None
        sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_5, [0], True);  getitem_5 = None
        sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0);  y_1 = None
        view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
        return pytree.tree_unflatten([getitem_4, view], self._out_spec)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]):
            # No stacktrace found for following nodes
            cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
            sin: f32[s2] = torch.ops.aten.sin.default(arg3_1)
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1);  arg2_1 = None
            cos_1: f32[s2] = torch.ops.aten.cos.default(arg3_1);  arg3_1 = None
            mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1);  sin_1 = cos_1 = None
            return [add, mul]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s1, s2], arg4_1: f32[s1, s2], arg5_1: f32[s2]):
            # No stacktrace found for following nodes
            cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
            sin: f32[s2] = torch.ops.aten.sin.default(arg5_1)
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1)
            cos_1: f32[s2] = torch.ops.aten.cos.default(arg5_1)
            mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1)
            is_same_size = torch.ops.aten.is_same_size.default(add, arg3_1);  add = None
            is_same_size_1 = torch.ops.aten.is_same_size.default(mul, arg4_1);  mul = None
            mul_1: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, sin_1);  sin_1 = None
            mul_2: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, cos_1);  arg4_1 = cos_1 = None
            sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(mul_1, [0], True);  mul_1 = None
            sym_size: Sym(s2) = torch.ops.aten.sym_size(arg5_1, 0)
            view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = None

            #
            sin_2: f32[s2] = torch.ops.aten.sin.default(arg5_1)
            neg: f32[s2] = torch.ops.aten.neg.default(sin_2);  sin_2 = None
            mul_3: f32[s2] = torch.ops.aten.mul.Tensor(view, neg);  view = neg = None
            cos_2: f32[s1, s2] = torch.ops.aten.cos.default(arg2_1);  arg2_1 = None
            mul_4: f32[s1, s2] = torch.ops.aten.mul.Tensor(mul_2, cos_2);  mul_2 = cos_2 = None
            sum_2: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg3_1, [0], True);  arg3_1 = None
            view_1: f32[s2] = torch.ops.aten.view.default(sum_2, [sym_size]);  sum_2 = sym_size = None
            cos_3: f32[s2] = torch.ops.aten.cos.default(arg5_1);  arg5_1 = None
            mul_5: f32[s2] = torch.ops.aten.mul.Tensor(view_1, cos_3);  view_1 = cos_3 = None
            add_1: f32[s2] = torch.ops.aten.add.Tensor(mul_3, mul_5);  mul_3 = mul_5 = None
            return [None, None, mul_4, add_1]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101633
Approved by: https://github.com/zou3519
2023-05-17 16:52:26 +00:00
PyTorch MergeBot
e69198b043 Revert "Support map autograd and pytree in/out (#100494)"
This reverts commit b8fa41be9d.

Reverted https://github.com/pytorch/pytorch/pull/100494 on behalf of https://github.com/PaliC due to breaking tests on trunk, please check hud.pytorch.org for the broken tests ([comment](https://github.com/pytorch/pytorch/pull/100494#issuecomment-1550454835))
2023-05-16 22:50:18 +00:00
ydwu4
b8fa41be9d Support map autograd and pytree in/out (#100494)
This PR adds autograd and pytree support for map operator.

Implementation-wise:

1. We temporarily make two HigherOrderOperators, "map" and "map_impl":
- "map" is user-facing. Currently, it unwraps the pytrees in inputs and create a flat_fn for it. Dynamo currently cannot deal with pytree.tree_flatten and pytree.tree_unflatten, we therefore make it a HigherOrderOperator to trigger dynamo logic of handling HigherOrderOperators.
- "map_impl" is the actual operator that works with the rest of torch subsystems such as functionalization, make_fx. It accepts flattend arguments, and a num_mapped_args integer denoting how many of the flattend arguments need to mapped i.e. their first dimension will be unstacked.

2. We create the forward and backward graph in autograd key and call torch.autograd.Function. Currently, the backward graph is recomputation-based and we need to partition the joint graph in the future to be more efficient.

Example traced graphs for map operators:
### Case 1: simple f and autograd
```python
def f(x, y):
    return x + y

def g(xs, y):
    out = control_flow.map(f, xs, y)
    return torch.autograd.grad(out, (xs, y), torch.ones_like(out))

gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True))
# gm.print_readable() produces following:
class g(torch.nn.Module):
    def forward(self, xs_1: f32[3, s1, s2], y_1: f32[s2]):
        # No stacktrace found for following nodes
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.map_impl(body_graph_0, 1, xs_1, y_1);  body_graph_0 = None
        getitem: f32[3, s1, s2] = map_impl[0];  map_impl = None
        ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False)
        is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like);  getitem = None
        body_graph_1 = self.body_graph_1
        map_impl_1 = torch.ops.map_impl(body_graph_1, 2, xs_1, ones_like, y_1);  body_graph_1 = xs_1 = ones_like = None
        getitem_1 = map_impl_1[0]
        getitem_2: f32[3, s1, s2] = map_impl_1[1]
        getitem_3: f32[3, s2] = map_impl_1[2];  map_impl_1 = None
        sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_3, [0], True);  getitem_3 = None
        sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0);  y_1 = None
        view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
        return (getitem_2, view)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s2]):
            # No stacktrace found for following nodes
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg2_1);  arg1_1 = arg2_1 = None
            return [add]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]):
            # No stacktrace found for following nodes
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg3_1);  arg1_1 = None
            is_same_size = torch.ops.aten.is_same_size.default(add, arg2_1);  add = None
            sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg2_1, [0], True)
            sym_size: Sym(s2) = torch.ops.aten.sym_size(arg3_1, 0);  arg3_1 = None
            view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
            return [None, arg2_1, view]
```
### Case 2: list input/output f and autograd
```python
def f(x, y):
    return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]

def g(xs, y):
    out = control_flow.map(f, xs, y)
    flat_out, _ = pytree.tree_flatten(out)
    flat_inp, _ = pytree.tree_flatten((xs, y))
    requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
    return torch.autograd.grad(flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out])

gm = make_fx(g, tracing_mode="symbolic")(
    [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
    torch.ones(5, requires_grad=True))

# gm.print_readable() produces following:
class g(torch.nn.Module):
    def forward(self, xs, y):
        xs_1: f32[3, s1, s2], xs_2: f32[3, s1, s2], y_1: f32[s2], = fx_pytree.tree_flatten_spec([xs, y], self._in_spec)
        # No stacktrace found for following nodes
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.map_impl(body_graph_0, 2, xs_1, xs_2, y_1);  body_graph_0 = None
        getitem: f32[3, s1, s2] = map_impl[0]
        getitem_1: f32[3, s1, s2] = map_impl[1];  map_impl = None
        ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False)
        ones_like_1: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem_1, pin_memory = False)
        is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like);  getitem = None
        is_same_size_1 = torch.ops.aten.is_same_size.default(getitem_1, ones_like_1);  getitem_1 = None
        body_graph_1 = self.body_graph_1
        map_impl_1 = torch.ops.map_impl(body_graph_1, 4, xs_1, xs_2, ones_like, ones_like_1, y_1);  body_graph_1 = xs_1 = xs_2 = ones_like = ones_like_1 = None
        getitem_2 = map_impl_1[0]
        getitem_3 = map_impl_1[1]
        getitem_4: f32[3, s1, s2] = map_impl_1[2]
        getitem_5: f32[3, s2] = map_impl_1[3];  map_impl_1 = None
        sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_5, [0], True);  getitem_5 = None
        sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0);  y_1 = None
        view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = sym_size = None
        return pytree.tree_unflatten([getitem_4, view], self._out_spec)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]):
            # No stacktrace found for following nodes
            cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
            sin: f32[s2] = torch.ops.aten.sin.default(arg3_1)
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1);  arg2_1 = None
            cos_1: f32[s2] = torch.ops.aten.cos.default(arg3_1);  arg3_1 = None
            mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1);  sin_1 = cos_1 = None
            return [add, mul]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s1, s2], arg4_1: f32[s1, s2], arg5_1: f32[s2]):
            # No stacktrace found for following nodes
            cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
            sin: f32[s2] = torch.ops.aten.sin.default(arg5_1)
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1)
            cos_1: f32[s2] = torch.ops.aten.cos.default(arg5_1)
            mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1)
            is_same_size = torch.ops.aten.is_same_size.default(add, arg3_1);  add = None
            is_same_size_1 = torch.ops.aten.is_same_size.default(mul, arg4_1);  mul = None
            mul_1: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, sin_1);  sin_1 = None
            mul_2: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, cos_1);  arg4_1 = cos_1 = None
            sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(mul_1, [0], True);  mul_1 = None
            sym_size: Sym(s2) = torch.ops.aten.sym_size(arg5_1, 0)
            view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]);  sum_1 = None

            #
            sin_2: f32[s2] = torch.ops.aten.sin.default(arg5_1)
            neg: f32[s2] = torch.ops.aten.neg.default(sin_2);  sin_2 = None
            mul_3: f32[s2] = torch.ops.aten.mul.Tensor(view, neg);  view = neg = None
            cos_2: f32[s1, s2] = torch.ops.aten.cos.default(arg2_1);  arg2_1 = None
            mul_4: f32[s1, s2] = torch.ops.aten.mul.Tensor(mul_2, cos_2);  mul_2 = cos_2 = None
            sum_2: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg3_1, [0], True);  arg3_1 = None
            view_1: f32[s2] = torch.ops.aten.view.default(sum_2, [sym_size]);  sum_2 = sym_size = None
            cos_3: f32[s2] = torch.ops.aten.cos.default(arg5_1);  arg5_1 = None
            mul_5: f32[s2] = torch.ops.aten.mul.Tensor(view_1, cos_3);  view_1 = cos_3 = None
            add_1: f32[s2] = torch.ops.aten.add.Tensor(mul_3, mul_5);  mul_3 = mul_5 = None
            return [None, None, mul_4, add_1]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100494
Approved by: https://github.com/zou3519
2023-05-16 22:05:11 +00:00
Nikita Karetnikov
42e65a2587 [pt2] add meta for linalg_lu_factor_ex (#101375)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101375
Approved by: https://github.com/lezcano
2023-05-16 20:56:54 +00:00
Khushi
01c7106580 [opinfo] empty_strided (#100890)
Follows: #100223

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100890
Approved by: https://github.com/ezyang
2023-05-15 23:39:39 +00:00
Nikita Karetnikov
9eb1748b2b [pt2] add meta and SymInt support for linalg_lu (#101372)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101372
Approved by: https://github.com/lezcano, https://github.com/albanD
2023-05-15 20:25:00 +00:00
Nikita Karetnikov
ac4cc63ae2 [pt2] add meta for linalg_ldl_solve (#101367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101367
Approved by: https://github.com/lezcano
2023-05-15 20:25:00 +00:00
Nikita Karetnikov
7dd8e08817 [pt2] add meta for linalg_ldl_factor_ex (#101362)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101362
Approved by: https://github.com/lezcano
2023-05-15 02:56:49 +00:00
Nikita Karetnikov
a8964d6377 [pt2] add meta and SymInt support for linalg_householder_product (#101315)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101315
Approved by: https://github.com/lezcano
2023-05-15 02:56:49 +00:00
Nikita Karetnikov
6abde61f8e [pt2] add meta function for _linalg_eigh (#100964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100964
Approved by: https://github.com/ezyang
2023-05-10 15:45:15 +00:00
Khushi
51fe53e619 [opinfo] item (#100313)
Follows #100223

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100313
Approved by: https://github.com/ezyang
2023-05-10 11:32:45 +00:00