Commit Graph

547 Commits

Author SHA1 Message Date
ydwu4
3062e267b1 [cond] Add more tests for valid inputs of cond (#110727)
This PR adds a parametrized test for cond. It tests that cond can be traced with valid inputs. Specifically, the valid inputs are combinations of the following (a minimal sketch of one combination follows the lists below):
- pred (python boolean, boolean tensor, int tensor, scalar tensor)
- true_fn/false_fn (func, obj, nn_module)
- operands (0 or more tensor inputs), tested with 0 and 2
- closures (0 or more tensor closures), tested with 0 and 2
- nested_level (no nesting or level-2 nested cond)

What this test doesn't cover:
- pred: symbolic boolean expression as predicate
- true_fn/false_fn: functions that mutate intermediate tensors
- operands: non-tensor operands such as float, int
- closures: nn_module attribute closures, python constant closures
- nested_level: 3+
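
For reference, a minimal sketch of one tested combination (boolean-tensor predicate, plain functions as branches, two tensor operands, no closures, no nesting); the tensor shapes here are illustrative assumptions:
```python
import torch
from functorch.experimental.control_flow import cond

def true_fn(x, y):
    return x + y

def false_fn(x, y):
    return x - y

x, y = torch.randn(3, 4), torch.randn(3, 4)
pred = torch.tensor(True)                      # boolean tensor predicate
out = cond(pred, true_fn, false_fn, [x, y])    # two tensor operands, no closures
```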

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110727
Approved by: https://github.com/zou3519
2023-10-11 15:56:13 +00:00
Brian Hirsh
ba86dfcd83 AOTDispatch subclass (#104483)
This is a PoC of AOTDispatch support. This PR actually works on basic examples, and I'm working on testing it out on `DTensor` (with @wanchaol), `SemiStructuredSparsityTensor` (with @jcaip), and `FP8Tensor`.

There are some design decisions baked into the PR that I think we need consensus on though - so I'm planning on writing a larger design doc to go over the changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104483
Approved by: https://github.com/ezyang
2023-10-10 16:13:16 +00:00
Guilherme Leobas
0a580da582 Add batch decomposition for torch.linalg.eigh (#110640)
Closes https://github.com/pytorch/pytorch/issues/108481
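
A hedged usage sketch of what the decomposition enables (the batch size and the symmetrization step are illustrative assumptions):
```python
import torch

A = torch.randn(8, 4, 4)
A = A + A.mT                                  # symmetrize each matrix in the batch
eigvals, eigvecs = torch.vmap(torch.linalg.eigh)(A)
print(eigvals.shape, eigvecs.shape)           # torch.Size([8, 4]) torch.Size([8, 4, 4])
```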

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110640
Approved by: https://github.com/kshitij12345, https://github.com/zou3519
2023-10-09 21:36:49 +00:00
chilli
c596db762f refactor aotautograd to set requires_grad on info rather than a separate array (#110720)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110720
Approved by: https://github.com/bdhirsh
2023-10-09 20:18:19 +00:00
vfdev-5
d2a2a67fa4 Added new test sample to interpolate op in OpInfo (#104181)
Description:
- Added new test sample to interpolate op in OpInfo
- Fixed silent issue with zero tensor test sample for uint8 dtype

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104181
Approved by: https://github.com/pmeier, https://github.com/lezcano
2023-10-09 10:55:56 +00:00
Kazuaki Ishizaki
a603dcc307 Fix typo under test directory (#110826)
This PR fixes the typo `the the` in comments in files under the `test` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110826
Approved by: https://github.com/Skylion007
2023-10-08 20:52:38 +00:00
kshitij12345
b8a3998c23 add batch rule for missing inplace ops (#110692)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110692
Approved by: https://github.com/ezyang
2023-10-06 20:53:28 +00:00
kshitij12345
371d8ba599 vmap: decompose real and imag instead of registering batch rule (#110508)
Clean-up
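
A hedged sketch of the vmap call path this clean-up keeps working (shapes are illustrative):
```python
import torch

z = torch.randn(4, 3, dtype=torch.complex64)
re = torch.vmap(torch.real)(z)   # served by the decomposition rather than a dedicated batch rule
im = torch.vmap(torch.imag)(z)
```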

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110508
Approved by: https://github.com/zou3519
2023-10-06 06:01:12 +00:00
Brian Hirsh
b457e3f79a Reland attempt 2 of "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)" (#109906)" (#110079)
The first reland broke internal (failing diff: D49617462).

The major error looks like it's because there's an internal-only higher order op that needs a new functionalization rule. I'm going to land an internal diff for that and confirm tests pass before relanding this PR.

Also confirmed that the issue from https://github.com/pytorch/pytorch/issues/110121 is fixed, and added a test.

This reverts commit 1b90f07f5a.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110079
Approved by: https://github.com/ezyang
2023-10-03 18:50:25 +00:00
PyTorch MergeBot
df3ab70dde Revert "Added new test sample to interpolate op in OpInfo (#104181)"
This reverts commit 87f8bc65f8.

Reverted https://github.com/pytorch/pytorch/pull/104181 on behalf of https://github.com/peterbell10 due to Causing OOM in slow-gradcheck ([comment](https://github.com/pytorch/pytorch/pull/104181#issuecomment-1745472323))
2023-10-03 18:07:02 +00:00
vfdev-5
d9fe1713c3 Enabled batch rule decompositions for upsample*.vec ops (#110333)
Follow-up PR to https://github.com/pytorch/pytorch/pull/110172
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110333
Approved by: https://github.com/zou3519
2023-10-03 06:58:18 +00:00
vfdev-5
87f8bc65f8 Added new test sample to interpolate op in OpInfo (#104181)
Description:
- Added new test sample to interpolate op in OpInfo
- Fixed silent issue with zero tensor test sample for uint8 dtype

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104181
Approved by: https://github.com/pmeier, https://github.com/lezcano
2023-10-02 15:35:48 +00:00
ydwu4
5f7eff0adb Replace node.meta source_fn with source_fn_stack (#108595)
A resubmit of https://github.com/pytorch/pytorch/pull/108447. Copy over the descriptions:

This is a follow-up of the discussion in https://github.com/pytorch/pytorch/pull/108356, where we want to replace source_fn with source_fn_stack.

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub', <built-in function sub>)
```

After this PR, the source_fn_stack will be a list of (name, target) tuples. The bottom of the stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```

Test Plan:
See the added tests in test_higher_order_ops.py and the modifications to existing tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108595
Approved by: https://github.com/angelayi, https://github.com/zou3519
2023-09-28 18:18:36 +00:00
vfdev-5
c62be12061 Added batch rules for _upsample_bi*2d_aa and _upsample_bi*2d_aa_backward (#110172)
Description:
- Added batch rules for `_upsample_bi*2d_aa` and `_upsample_bi*2d_aa_backward` (a usage sketch follows)
- Added a few more test cases to `sample_inputs_upsample_aten`
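
A hedged sketch of the kind of call these batch rules serve; the per-sample wrapper and the sizes are illustrative assumptions:
```python
import torch
import torch.nn.functional as F

def resize(img):                      # img: (C, H, W) for a single sample
    return F.interpolate(img.unsqueeze(0), size=(16, 16),
                         mode="bilinear", antialias=True).squeeze(0)

batch = torch.randn(8, 3, 32, 32)
out = torch.vmap(resize)(batch)       # hits _upsample_bilinear2d_aa under vmap
```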

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110172
Approved by: https://github.com/kshitij12345, https://github.com/zou3519
2023-09-28 17:42:48 +00:00
Edward Z. Yang
f7c9ef88f5 Add masked_select abstract impl (#110103)
Fixes https://github.com/pytorch/pytorch/issues/109871

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110103
Approved by: https://github.com/bdhirsh
2023-09-27 04:07:58 +00:00
PyTorch MergeBot
194d9aa0f2 Revert "[Dynamo] Match closures by code ID (#109427)"
This reverts commit 3de0857503.

Reverted https://github.com/pytorch/pytorch/pull/109427 on behalf of https://github.com/voznesenskym due to Fails test `PYTORCH_TEST_WITH_DYNAMO=1 python test_ops.py -k test_out_warning__refs_cat_cpu ([comment](https://github.com/pytorch/pytorch/pull/109427#issuecomment-1736101561))
2023-09-26 18:54:36 +00:00
PyTorch MergeBot
1b90f07f5a Revert "Reland "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)" (#109906)"
This reverts commit d0fe8fa5db.

Reverted https://github.com/pytorch/pytorch/pull/109906 on behalf of https://github.com/atalman due to Breaks internal tests ([comment](https://github.com/pytorch/pytorch/pull/109906#issuecomment-1735416852))
2023-09-26 12:10:25 +00:00
Li-Huai (Allan) Lin
129f535778 [VMAP] Add linspace and logspace batch rules (#105451)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105451
Approved by: https://github.com/zou3519
ghstack dependencies: #107958, #104889
2023-09-26 04:08:24 +00:00
Ken Jin
3de0857503 [Dynamo] Match closures by code ID (#109427)
Closes https://github.com/pytorch/pytorch/issues/107866

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109427
Approved by: https://github.com/ezyang, https://github.com/jansel
2023-09-25 19:10:35 +00:00
Brian Hirsh
d0fe8fa5db Reland "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)" (#109906)
I'm pretty sure this is fixed, but I'll run inductor and trunk CI. The previously failing test in trunk was due to the recently landed selective activation checkpointing (SAC) code, which assumes it can detect whether AOTAutograd is running by checking whether the inputs to SAC are C++ `FunctionalTensorWrapper`s.

The previous land broke some inductor trunk tests.

This reverts commit 629a628cc8.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109906
Approved by: https://github.com/ezyang
2023-09-25 14:53:54 +00:00
jjsjann123
0d3db1048a remove nvfuser test in upstream pytorch (#109918)
Removing nvfuser related tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109918
Approved by: https://github.com/msaroufim
2023-09-24 13:49:37 +00:00
PyTorch MergeBot
629a628cc8 Revert "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)"
This reverts commit b5d6e831a9.

Reverted https://github.com/pytorch/pytorch/pull/106406 on behalf of https://github.com/malfet due to Broke lots of tests on trunk ([comment](https://github.com/pytorch/pytorch/pull/106406#issuecomment-1731524917))
2023-09-22 14:32:34 +00:00
Brian Hirsh
b5d6e831a9 Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)
Now that FunctionalTensor and `FunctionalTensorMode` are lower down in this stack, the changes in this PR are more mechanical: Everywhere in AOTAutograd that I used to use the C++ functionalization API, I now use the python functionalization API.

Note that this doesn't actually cause functionalization to run underneath torch_dispatch. I'm saving that re-ordering for later in the stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106406
Approved by: https://github.com/ezyang
ghstack dependencies: #108654, #109662, #109632, #109023
2023-09-22 07:09:04 +00:00
Brian Hirsh
238fb66085 python functionalization: support higher order ops (#108656)
We now have two types of functionalization, C++ Functionalization (through the `Functionalize` dispatch key), and python functionalization (through the `FunctionalTensorMode` torch_dispatch mode).

This means that all higher order ops need custom functionalization rules for the python variant too. I added them here, as well as a helper function `dispatch_functionalize()` - equivalent to `torch.func.functionalize()`, except that it uses `FunctionalTensorMode`.

In theory we could have secretly switched `torch.func.functionalize` to use `FunctionalTensorMode`. This would be BC-breaking, though, since `FunctionalTensorMode` isn't composable with the other functorch transforms (the functorch layer-mode stack doesn't know how to re-order torch_dispatch modes arbitrarily).
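
For context, a minimal sketch of what a functionalization pass does, using the public `torch.func.functionalize()` entry point mentioned above (not the new `dispatch_functionalize()` helper, whose signature isn't shown in this message):
```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x.clone()
    y.add_(1)          # in-place mutation
    return y

gm = make_fx(functionalize(f))(torch.randn(3))
print(gm.code)         # the traced graph contains only out-of-place ops
```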

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108656
Approved by: https://github.com/zou3519
ghstack dependencies: #109024, #109248
2023-09-20 04:37:31 +00:00
FFFrog
70f2adaec3 Setup_context does not contain default values of forward() (#108561)
Fixes #108529

As the title shows.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108561
Approved by: https://github.com/soulitzer
2023-09-19 16:23:52 +00:00
Jez Ng
7f3885137f Add meta function for _segment_reduce (#109359)
This fixes numerous tests which were xfailing. For instance, the
`_segment_reduce.lengths` OpInfo test, which was previously relying on
the fallback kernel to determine the shape of the meta tensor. The
fallback kernel would fail with

    segment_reduce(): Expected all rows of lengths along axis to sum to data.size(lengths.dim()-1) when !unsafe.

as it was trying to read the values of a meta tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109359
Approved by: https://github.com/ezyang
2023-09-16 13:31:03 +00:00
ydwu4
706d8e2230 [dynamo] Respect shape dynamism of SymInt sized tensor (#109331)
Before this PR, if we run the following code:
```python
def true_fn(x):
    return x - x.cos()

def false_fn(x):
    return x + x.sin()

def foo(x):
    return cond(x.shape[0] == 4, true_fn, false_fn, [x])
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(3, 4))
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
```
we'll have the following error:
```python
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/make_fx.py", line 16, in <module>
    gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
  File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 841, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/yidi/local/pytorch/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 461, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/fx/_symbolic_trace.py", line 817, in trace
    (self.create_arg(fn(*args)),),
  File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 497, in wrapped
    out = f(*tensors)
  File "/home/yidi/local/pytorch/make_fx.py", line 13, in foo
    return control_flow.cond(x.shape[0] == 4, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 151, in cond
    return torch.compile(cond_op, backend="eager", fullgraph=True)(
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 545, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 561, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 483, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 432, in transform
    tracer = InstructionTranslator(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2032, in __init__
    self.symbolic_locals = collections.OrderedDict(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2035, in <genexpr>
    VariableBuilder(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
    vt = self._wrap(value).clone(**self.options())
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
    return type_dispatch(self, value)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
    output = [
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
    VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
    vt = self._wrap(value).clone(**self.options())
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
    return type_dispatch(self, value)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
    output = [
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
    VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
    vt = self._wrap(value).clone(**self.options())
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
    return type_dispatch(self, value)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1040, in wrap_tensor
    tensor_variable = wrap_fx_proxy(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1267, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1382, in wrap_fx_proxy_cls
    example_value = wrap_to_fake_tensor_and_record(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1652, in wrap_to_fake_tensor_and_record
    dynamic_dims, constraint_dims = _automatic_dynamic(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1550, in _automatic_dynamic
    if dim is not None and e.size()[i] != dim:
  File "/home/yidi/local/pytorch/torch/__init__.py", line 352, in __bool__
    return self.node.bool_()
  File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1019, in bool_
    return self.guard_bool("", 0)
  File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1001, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/home/yidi/local/pytorch/torch/fx/experimental/recording.py", line 227, in wrapper
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3793, in evaluate_expr
    assert orig_expr == hint, f"{orig_expr} != {hint}"
AssertionError: False != True

from user code:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```

This is because we record the SymInt in the frame state in _automatic_dynamic the first time we compile the function. The second time, when we are given a SymInt-sized input with different hints, the comparison fails.

Implementation:
This PR derives shape dynamism from the dynamism of the inputs: if a dimension is a SymInt, return DYNAMIC; otherwise return STATIC.

Test Plan:
Add a test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109331
Approved by: https://github.com/ezyang
2023-09-16 02:56:53 +00:00
ydwu4
1aba61e977 Allow cond to have more dynamo cache beyond limit (#109318)
This is a short-term workaround for https://github.com/pytorch/pytorch/issues/108500. In the long term, we should have separate caches when cond appears at different places in user code, or a per-true_fn/false_fn cache.

Test Plan:
See the added test. It checks that cond can go beyond the cache limit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109318
Approved by: https://github.com/ezyang
2023-09-15 15:33:36 +00:00
ydwu4
2bf7a283cb Remove expected test failures for cond (#108709)
Remove the expected failure in def test_control_flow_tracing(self) by changing the error message to `Expected pred to be bool or tensor, but got Proxy\(eq\)`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108709
Approved by: https://github.com/ezyang, https://github.com/zou3519
ghstack dependencies: #107662, #107850
2023-09-14 21:34:31 +00:00
ydwu4
6140facf00 Support SymBool input to torch.compile (#107850)
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```
def f(x:torch.Tensor):
  pred = x.size(0) == 3
  torch.compile(f)(pred, x)

make_fx(f, tracing_mode="symbolic")(x)
```

The idea of this PR (credit to @ezyang) is to support SymBool by reusing the infrastructure we already have for SymInt, so that we don't need to replicate a lot of machinery.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
2023-09-14 21:34:31 +00:00
PyTorch MergeBot
47f79e9a2b Revert "Support SymBool input to torch.compile (#107850)"
This reverts commit 9f6d70b2fd.

Reverted https://github.com/pytorch/pytorch/pull/107850 on behalf of https://github.com/huydhn due to Sorry for reverting this, but test_export_with_symbool_inputs is failing in trunk a08e1370ef ([comment](https://github.com/pytorch/pytorch/pull/107850#issuecomment-1718675877))
2023-09-14 02:53:36 +00:00
PyTorch MergeBot
de76c88d90 Revert "Remove expected test failures for cond (#108709)"
This reverts commit a08e1370ef.

Reverted https://github.com/pytorch/pytorch/pull/108709 on behalf of https://github.com/huydhn due to Sorry for reverting this, but test_export_with_symbool_inputs is failing in trunk a08e1370ef ([comment](https://github.com/pytorch/pytorch/pull/108709#issuecomment-1718669964))
2023-09-14 02:47:28 +00:00
ydwu4
a08e1370ef Remove expected test failures for cond (#108709)
Remove the expected failure in def test_control_flow_tracing(self) by changing the error message to `Expected pred to be bool or tensor, but got Proxy\(eq\)`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108709
Approved by: https://github.com/ezyang, https://github.com/zou3519
ghstack dependencies: #107662, #107850
2023-09-14 01:16:29 +00:00
ydwu4
9f6d70b2fd Support SymBool input to torch.compile (#107850)
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```
def f(x:torch.Tensor):
  pred = x.size(0) == 3
  torch.compile(f)(pred, x)

make_fx(f, tracing_mode="symbolic")(x)
```

The idea of this PR (credit to @ezyang) is to support SymBool by reusing the infrastructure we already have for SymInt, so that we don't need to replicate a lot of machinery.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
2023-09-14 01:16:29 +00:00
Guilherme Leobas
49e3d76684 Add SymInt support to torch.take_along_dim (#108879)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108879
Approved by: https://github.com/Skylion007, https://github.com/lezcano, https://github.com/Chillee
2023-09-13 23:13:09 +00:00
ydwu4
33c94b8b16 Better error handling for cond (#108817)
## Exception in cond:
For code below:
```python
import torch
import functorch.experimental.control_flow as control_flow
def true_fn(x):
    return x.sin()

def false_fn(x):
    return x, x

def f(x, y):
    return control_flow.cond(y, true_fn, false_fn, [x])

f(torch.ones(3, 4), torch.tensor(False))
```
The original exception stack trace is:
```python
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 33, in <module>
    f(torch.ones(3, 4), torch.tensor(False))
  File "/home/yidi/local/pytorch/test_exc.py", line 31, in f
    return control_flow.cond(y, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 154, in cond
    return torch.compile(cond_op, backend="eager", fullgraph=True)(
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 365, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 513, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 560, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 197, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 449, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2083, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1164, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 418, in call_function
    (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 410, in speculate_branch
    raise UncapturedHigherOrderOpError(
torch._dynamo.exc.UncapturedHigherOrderOpError: Expected branch to return a single tensor

from user code:
   File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
After this PR we get:
```python
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 50, in graph_break_as_hard_error
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 429, in call_function
    (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 421, in speculate_branch
    unimplemented(
  File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 187, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: Expected branch to return a single tensor

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 33, in <module>
    f(torch.ones(3, 4), torch.tensor(False))
  File "/home/yidi/local/pytorch/test_exc.py", line 31, in f
    return control_flow.cond(y, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 154, in cond
    return torch.compile(cond_op, backend="eager", fullgraph=True)(
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1159, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 53, in graph_break_as_hard_error
    raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
## Exception during speculating branches
The example code below has an in-place buffer mutation error:
```python
import torch
import functorch.experimental.control_flow as control_flow

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.ones(6, 4))

    def forward(self, x):
        def true_fn(x):
            self.buffer += 1
            return self.buffer.sum() + x.sum()

        def false_fn(x):
            return (x - 1).sum()

        return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])

mod_for_compile = torch.compile(Foo(), backend="eager", dynamic=True)
mod_for_compile(torch.ones(3, 4))
```

Before this PR the exception looks like:
```python
[2023-09-08 15:20:03,332] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-09-08 15:20:03,332] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Can't inplace modify module params/buffers inside HigherOrderOp
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 163, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 606, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2200, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2316, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1219, in STORE_ATTR
    .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 618, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 1169, in call_setattr
    raise AttributeMutationError(
torch._dynamo.exc.AttributeMutationError: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 394, in speculate_branch
    ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 222, in speculate_subgraph
    raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 20, in <module>
    mod_for_compile(torch.ones(3, 4))
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 365, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 513, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 632, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 560, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 197, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 449, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2083, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1124, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 606, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2200, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2316, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1124, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 415, in call_function
    (true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 405, in speculate_branch
    raise UncapturedHigherOrderOpError(
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile

from user code:
   File "/home/yidi/local/pytorch/test_exc.py", line 16, in forward
    return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 127, in cond
    return cond_op(pred, true_fn, false_fn, operands)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```

After this PR, the only difference is that the error message of UncapturedHigherOrderOpError changes from `Cond doesn't work unless it is captured completely with torch.compile` to `Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break`.

```python
[2023-09-08 15:17:02,052] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-09-08 15:17:02,052] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Can't inplace modify module params/buffers inside HigherOrderOp
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 177, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 601, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2300, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1214, in STORE_ATTR
    .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 618, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 1169, in call_setattr
    raise AttributeMutationError(
torch._dynamo.exc.AttributeMutationError: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 50, in graph_break_as_hard_error
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 426, in call_function
    (true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 410, in speculate_branch
    ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 236, in speculate_subgraph
    raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 20, in <module>
    mod_for_compile(torch.ones(3, 4))
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 634, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 601, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2300, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 53, in graph_break_as_hard_error
    raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "/home/yidi/local/pytorch/test_exc.py", line 16, in forward
    return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 127, in cond
    return cond_op(pred, true_fn, false_fn, operands)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108817
Approved by: https://github.com/zou3519
2023-09-13 23:03:59 +00:00
Li-Huai (Allan) Lin
b2cba439b4 Introduce Tensor overload to linspace and logspace (#104889)
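A hedged usage sketch of the overload the title describes (0-dim tensor start/end; the values are illustrative):
```python
import torch

start, end = torch.tensor(0.0), torch.tensor(1.0)   # 0-dim tensors instead of Python floats
vals = torch.linspace(start, end, steps=5)
logs = torch.logspace(start, end, steps=5)
```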
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104889
Approved by: https://github.com/zou3519
ghstack dependencies: #107958
2023-09-11 23:30:40 +00:00
PyTorch MergeBot
a7f5abeade Revert "Introduce Tensor overload to linspace and logspace (#104889)"
This reverts commit 57e5239321.

Reverted https://github.com/pytorch/pytorch/pull/104889 on behalf of https://github.com/clee2000 due to sorry have to revert this to revert https://github.com/pytorch/pytorch/pull/107958 ([comment](https://github.com/pytorch/pytorch/pull/104889#issuecomment-1714305768))
2023-09-11 17:33:48 +00:00
Li-Huai (Allan) Lin
57e5239321 Introduce Tensor overload to linspace and logspace (#104889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104889
Approved by: https://github.com/zou3519
ghstack dependencies: #107958
2023-09-11 15:29:39 +00:00
Joel Schlosser
b928e08f3d Initial vmap + NT support with unbind fallback (#106786)
PoC demonstrating vmap + NT based on the [design doc](https://docs.google.com/document/d/1dVVk6TOqz93PLTIneU2T3xaxCs9qZ0MaJyCvOAp_bC0). This PR:
* Allows `BatchedTensorImpl`s to contain NTs
* Introduces a `BatchedNestedTensor` dispatch key for NT-specific batching rules
* Provides a batching rule fallback that unbinds the NTs -> performs the computation on each constituent -> rebinds the results into an NT (see the sketch after the restrictions below)

Restrictions:
* Only supports one level of vmap
* Only supports vmapping over dim=0 for NTs
    * For operations with mixed NT / dense inputs, support is also limited to dim=0 for the dense inputs
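
A hedged sketch of the supported pattern described above (a single level of vmap over dim=0 of an NT, served by the unbind fallback); the exact call pattern is an assumption:
```python
import torch

nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])
out = torch.vmap(torch.nn.functional.relu)(nt)   # unbind -> relu per constituent -> rebind
```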
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106786
Approved by: https://github.com/zou3519
2023-09-07 13:53:20 +00:00
Brian Hirsh
fac7a1f730 fix issue with lift_fresh_copy when using export + compile (#108243)
Fixes https://github.com/pytorch/pytorch/issues/105327. The problem is that `lift_fresh_copy()`'s functionalization implementation currently assumes that the input is never a functional tensor. This is apparently too limiting: when you have "user" code like this (which can potentially come from exporting a model and then running compile on the resulting graph):
```
tensor_constant0 = torch.tensor(2)
lift_fresh = torch.ops.aten.lift_fresh_copy.default(tensor_constant0)
```

When we run this through AOTAutograd, the first call (torch.tensor(2)) will **already** be lifted into a functional tensor wrapper - so the `lift_fresh_copy` call doesn't need to do any "lifting" anymore - it just needs to do a clone.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108243
Approved by: https://github.com/albanD
ghstack dependencies: #108081, #108235
2023-09-05 20:02:35 +00:00
ydwu4
e3933609d4 Make make_fx cond preserve node meta (#108356)
**Motivation:**
Currently, for the following code that exports the cond operator:
```python
import torch
from functorch.experimental.control_flow import cond

class MySubModule(torch.nn.Module):
    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)

class CondBranchClassMethod(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

from torch._export import capture_pre_autograd_graph

example_inputs = (torch.randn(1, 3, 3, 3),)
m = CondBranchClassMethod()
m.eval()
gm = capture_pre_autograd_graph(m, example_inputs)
print(gm)

# source_fn for original cond op, getattr submodule op are all cond op
for n in gm.graph.nodes:
    print("n:", n.format_node(), n.meta)

print("\n\n\n")
# source_fn for submodule nodes are all cond op
# Expected: ideally this should be the real ops, e.g. torch.sin, aten.cos, etc
for n in gm.submodule_0.graph.nodes:
    print("n:", n.format_node(), n.meta)
```

Output is like below:
```
GraphModule(
  (submodule_0): GraphModule()
  (submodule_1): GraphModule()
)

def forward(self, arg_0):
    arg0_1, = fx_pytree.tree_flatten_spec([arg_0], self._in_spec)
    submodule_0 = self.submodule_0
    submodule_1 = self.submodule_1
    cond = torch.ops.higher_order.cond(True, submodule_0, submodule_1, [arg0_1]);  submodule_0 = submodule_1 = arg0_1 = None
    return pytree.tree_unflatten((cond,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
n: %arg0_1 : [num_users=1] = placeholder[target=arg0_1] {'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None, 'is_torch_exported': True, 'stack_trace': 'NoneType: None\n'}
n: %submodule_0 : [num_users=1] = get_attr[target=submodule_0] {'stack_trace': 'NoneType: None\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('conditional', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>)], 'seq_nr': -1}
n: %submodule_1 : [num_users=1] = get_attr[target=submodule_1] {'stack_trace': 'NoneType: None\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('conditional', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>)], 'seq_nr': -1}
n: %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (True, %submodule_0, %submodule_1, [%arg0_1]), kwargs = {}) {'stack_trace': 'NoneType: None\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('conditional', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>)], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None, 'is_torch_exported': True}
n: return (cond,) {'stack_trace': 'NoneType: None\n', 'from_node': [('output', 'output')], 'seq_nr': -1, 'is_torch_exported': True, 'val': (FakeTensor(..., size=(1, 3, 3, 3)),), 'tensor_meta': (None,)}

n: %arg0_1 : [num_users=1] = placeholder[target=arg0_1] {'stack_trace': '  File "<ipython-input-9-2a8c7c0498ed>", line 36, in forward\n    return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('arg0_1', 'arg0_1')], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None}
n: %cos_default : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%arg0_1,), kwargs = {}) {'stack_trace': '  File "<ipython-input-9-2a8c7c0498ed>", line 36, in forward\n    return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': <OpOverload(op='aten.cos', overload='default')>, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cos', <OpOverload(op='aten.cos', overload='default')>), ('cos_default', <OpOverload(op='aten.cos', overload='default')>)], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None}
n: return cos_default {'stack_trace': '  File "<ipython-input-9-2a8c7c0498ed>", line 36, in forward\n    return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('output', 'output')], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None}
```

As we can see, the meta of the nodes in the subgraphs is overridden with the cond node's metadata. This is because _set_current_meta is only invoked on the top-level graph module in the interpreter. When we call into cond and deal with the submodules, we don't set current_meta to the meta of the subgraph's nodes properly.

**Implementation:**
This PR fixes it by having trace_cond optionally use an fx.Interpreter to interpret the subgraphs, so that the node metadata is preserved when both of the following conditions hold (a minimal sketch follows the list):
- The subgraph is a GraphModule: this is required in order to run it through fx.Interpreter.
- The current make_fx call has preserve_node_meta turned on (as is the case for capture_pre_autograd_graph).
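
A minimal sketch of the idea, using a hypothetical helper name (the actual change lives inside trace_cond):
```python
import torch.fx as fx

def call_branch_preserving_meta(branch, operands):
    # Interpreting a GraphModule node by node lets fx propagate each inner
    # node's own metadata instead of stamping everything with cond's meta.
    if isinstance(branch, fx.GraphModule):
        return fx.Interpreter(branch).run(*operands)
    # Plain callables keep the previous behavior.
    return branch(*operands)
```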

**Test Plan**
See added tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108356
Approved by: https://github.com/SherlockNoMad
2023-09-01 22:43:55 +00:00
ydwu4
f8c93df2d1 Fix boolean tensor for map (#108289)
torch.empty_strided can create a new tensor based on the metadata. For boolean tensors we call clone directly; however, if the input is a functional tensor we get back a functional tensor, which won't be tracked by the tracer's tensor_tracker after dispatching, so it becomes a tensor\_constant in the graph during create_arg. So we manually unwrap the functional tensor before calling clone.
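
A hedged sketch of that unwrapping step (illustrative only; the functions used here are private PyTorch APIs, and the real fix lives in the map tracing code):
```python
import torch

def clone_boolean_example(t: torch.Tensor) -> torch.Tensor:
    # Unwrap a functional tensor first so the resulting clone stays visible
    # to the tracer's tensor_tracker instead of becoming a tensor_constant.
    if torch._is_functional_tensor(t):
        torch._sync(t)
        t = torch._from_functional_tensor(t)
    return t.clone()
```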

Test Plan:
See added test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108289
Approved by: https://github.com/angelayi
2023-08-31 19:17:28 +00:00
ydwu4
49e964cad6 Automatically turn on dynamo in cond (#108028)
A replacement of https://github.com/pytorch/pytorch/pull/107932.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108028
Approved by: https://github.com/zou3519
ghstack dependencies: #108025, #108026, #108027
2023-08-28 10:16:41 +00:00
David Watson
598babf017 Added normal op decomposition for specializations of the normal op (#106792)
This fixes running normal with the meta key.

```
import torch

t = torch.tensor(4.0, device='meta')
torch.normal(0.5, t)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106792
Approved by: https://github.com/lezcano
2023-08-25 16:18:28 +00:00
ydwu4
a408920817 Reland fakify FunctionalTensor (#107569)
Try to rebase and reland https://github.com/pytorch/pytorch/pull/107062. One difference compared with the previous attempt is that the DTensor logic in _clone_input is kept the same as before.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107569
Approved by: https://github.com/zou3519
2023-08-22 15:46:25 +00:00
Kshiteej K
977a77ca2c Manually enable capture_func_transforms for testing (#107122)
Manually enable `capture_func_transforms` for testing, as the plan is to default `capture_func_transforms` to False in 2.1 (enable it so that we still test the support on the release branch).
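
A minimal usage sketch, assuming the flag lives in `torch._dynamo.config` as it did around the 2.1 release:
```python
import torch
import torch._dynamo.config as dynamo_config

# Opt back in so the func-transform capture path is still exercised in tests.
dynamo_config.capture_func_transforms = True

@torch.compile
def f(x):
    return torch.func.vmap(torch.sin)(x)
```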

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107122
Approved by: https://github.com/zou3519
2023-08-21 20:38:33 +00:00
PyTorch MergeBot
96c5be8bc4 Revert "Fakify leaf of FunctionalTensor (#107062)"
This reverts commit 3349725766.

Reverted https://github.com/pytorch/pytorch/pull/107062 on behalf of https://github.com/ydwu4 due to This appears to have broken the test TestDTensorCompile.test_dtensor_fullgraph. Probably a land race ([comment](https://github.com/pytorch/pytorch/pull/107062#issuecomment-1685447747))
2023-08-21 00:30:16 +00:00
Liao, Xuan
71632d4d24 [cpu] add sdpa choice and UT (#105131)
Feature RFC: https://github.com/pytorch/rfcs/pull/56.

This adds an SDPA selection function for CPU that automatically chooses one SDPA implementation among several. There are two CPU implementations to choose from: the unfused SDPA and flash attention. In general, flash attention has a higher priority than the unfused SDPA. In cases where flash attention is not applicable, such as when flash attention is manually disabled or the inputs are not 4-dimensional, the unfused SDPA is chosen.
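
A hedged sketch of the selection logic just described (illustrative only; the real dispatcher lives in ATen and checks more preconditions):
```python
import torch

def choose_cpu_sdpa(query: torch.Tensor, flash_enabled: bool = True) -> str:
    # Flash attention has priority; fall back to the unfused path when it is
    # disabled or the inputs are not 4-dimensional (batch, heads, seq, head_dim).
    if flash_enabled and query.dim() == 4:
        return "flash_attention"
    return "unfused"
```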

## Performance of the stack

### NanoGPT's SDPA kernel
Using the benchmark [repo](https://github.com/mingfeima/bench_sdpa/blob/main/README.md), with one socket.
Shape: Batch size 1, Sequence length 1024, Head number 25, Head size 64.
Machine: SPR.

| Dtype    | Causal   | Mode      | SDPA            | Time (ms per iter) | Speedup |
| -------- | -------- | -------   | -------         | -------            | ------- |
| float32  | FALSE    | Inference | Unfused         | 3.081              |         |
|          |          |           | Flash attention | 1.665              | **1.85045** |
| float32  | TRUE     | Inference | Unfused         | 3.463              |         |
|          |          |           | Flash attention | 1.662              | **2.083634**|
| bfloat16 | FALSE    | Inference | Unfused         | 1.203              |         |
|          |          |           | Flash attention | 1.154              | **1.042461**|
| bfloat16 | TRUE     | Inference | Unfused         | 1.543              |         |
|          |          |           | Flash attention | 1.154              | **1.337088**|
| float32  | FALSE    | Training  | Unfused         | 54.938             |         |
|          |          |           | Flash attention | 23.029             | **2.385601**|
| float32  | TRUE     | Training  | Unfused         | 58.266             |         |
|          |          |           | Flash attention | 17.835             | **3.266947**|
| bfloat16 | FALSE    | Training  | Unfused         | 18.924             |         |
|          |          |           | Flash attention | 18.886             | **1.002012**|
| bfloat16 | TRUE     | Training  | Unfused         | 21.08              |         |
|          |          |           | Flash attention | 14.172             | **1.48744** |

### Stable Diffusion
Following the model's [BKM](https://github.com/intel-innersource/frameworks.ai.models.intel-models/blob/develop/quickstart/diffusion/pytorch/stable_diffusion/inference/cpu/README.md).
Mode: Inference; Machine: SPR.

| Dtype    | SDPA                    | Throughput (fps) | Speedup SDPA | Total Time (ms) | Speedup |
| -------- | --------                | -------          | -------      | -------         | ------- |
| float32  | Unfused                 | 1.63             |              | 1139            |         |
|          | Flash attention         | 1.983            | 1.216564     | 547.488         | **2.080411**|
| bfloat16 | Flash attention in IPEX | 4.784            |              | 429.051         |         |
|          | Flash attention         | 4.857            | 1.015259     | 408.823         | **1.049479**|

### LLM models of Torchbench

Dtype: float32; Mode: Inference, single socket; Machine: CPX.
Model name | SDPA | Inductor_new | Inductor_old | Inductor Ratio (old/new)
-- | -- | -- | -- | --
hf_Albert | Unfused -> Flash attention | 0.048629309 | 0.05591545 | **1.14983024**
hf_Bert | Unfused -> Flash attention | 0.053156243 | 0.060732115 | **1.142520841**
hf_Bert_large | Unfused -> Flash attention | 0.141089502 | 0.155190077 | **1.099940636**
llama | Unfused -> Flash attention | 0.033250106 | 0.033720745 | **1.01415451**

Dtype: bfloat16; Mode: Inference, single socket; Machine: SPR.
Model name | SDPA | Inductor_new | Inductor_old | Inductor Ratio (old/new)
-- | -- | -- | -- | --
hf_Albert | Unfused -> Flash attention | 0.020681298 | 0.020718282 | **1.001788324**
hf_Bert | Unfused -> Flash attention | 0.019932816 | 0.019935424 | **1.000130842**
hf_Bert_large | Unfused -> Flash attention | 0.047949174 | 0.048312502 | **1.007577355**
llama | Unfused -> Flash attention | 0.018528057 | 0.01861126 | **1.0044907**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105131
Approved by: https://github.com/drisspg
ghstack dependencies: #104583, #104584, #103826, #104693, #104863, #107128
2023-08-20 08:56:21 +00:00
ydwu4
3349725766 Fakify leaf of FunctionalTensor (#107062)
This PR allows dynamo to fakify FunctionalTensorWrapper by unwrapping it, replacing the inner tensor, and wrapping it again, so that a FunctionalTensorWrapper can be passed as an input to dynamo.optimize and we can support something like this:
```python
ff = torch.func.functionalize(f)
torch.compile(ff)(x)
```
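
A hedged sketch of the unwrap/replace/re-wrap step described above (illustrative only; the real logic lives in dynamo's input handling and uses private APIs):
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def fakify_functional(t: torch.Tensor, fake_mode: FakeTensorMode) -> torch.Tensor:
    assert torch._is_functional_tensor(t)
    torch._sync(t)
    inner = torch._from_functional_tensor(t)         # unwrap
    fake_inner = fake_mode.from_tensor(inner)        # replace with a fake tensor
    return torch._to_functional_tensor(fake_inner)   # wrap again
```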

This PR doesn't follow the \_\_tensor_flatten\_\_ and \_\_tensor_unflatten\_\_ protocol for now because we're not sure about the plan for doing that with FunctionalTensorWrapper (it's implemented in C++).

**Test Plan:**
Add a new test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107062
Approved by: https://github.com/zou3519
ghstack dependencies: #107042
2023-08-19 17:33:42 +00:00