This PR adds a parametrized test for cond. It tests that cond can be traced with valid inputs. Specifically, the valid inputs are combinations of:
- pred (python boolean, boolean tensor, int tensor, scalar tensor)
- true_fn/false_fn (func, obj, nn_module)
- operands (0 or more tensor inputs), tested with 0 and 2
- closures (0 or more tensor closures), tested with 0 and 2
- nested_level (no nesting or level-2 nested cond)
What this test doesn't cover:
- pred: symbolic boolean expression as predicate
- true_fn/false_fn: functions that mutate intermediate tensors
- operands: non-tensor operands such as float, int
- closures: nn_module attribute closures, python constant closures
- nested_level: 3+
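For illustration, a minimal sketch of one point in the parametrization grid (a hypothetical test body, not the actual test code; it assumes a boolean-tensor pred, plain functions for the branches, two tensor operands, and one tensor closure):
```python
import torch
from functorch.experimental.control_flow import cond

closure = torch.randn(3, 4)  # tensor closure captured by both branches

def true_fn(x, y):
    return x + y + closure

def false_fn(x, y):
    return x - y - closure

def fn(pred, x, y):
    # pred here is a boolean tensor; other parametrizations use a python bool,
    # an int tensor, or a scalar tensor.
    return cond(pred, true_fn, false_fn, [x, y])

compiled = torch.compile(fn, backend="eager", fullgraph=True)
out = compiled(torch.tensor(True), torch.randn(3, 4), torch.randn(3, 4))
```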
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110727
Approved by: https://github.com/zou3519
This is a PoC of AOTDispatch support. This PR actually works on basic examples, and I'm working on testing it out on `DTensor` (with @wanchaol), `SemiStructuredSparsityTensor` (with @jcaip), and `FP8Tensor`.
There are some design decisions baked into the PR that I think we need consensus on though - so I'm planning on writing a larger design doc to go over the changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104483
Approved by: https://github.com/ezyang
The first reland broke internal (failing diff: D49617462).
The major error looks like it's because there's an internal-only higher order op that needs a new functionalization rule. I'm going to land an internal diff for that and confirm tests pass before relanding this PR.
Also confirmed that the issue from https://github.com/pytorch/pytorch/issues/110121 is fixed, and added a test.
This reverts commit 1b90f07f5a.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110079
Approved by: https://github.com/ezyang
A resubmit of https://github.com/pytorch/pytorch/pull/108447. Copy over the descriptions:
This is a follow-up of the discussion in https://github.com/pytorch/pytorch/pull/108356, where we want to replace source_fn with source_fn_stack.
Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The captured graph is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_
        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]); l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

class GraphModule(torch.nn.Module):
    def forward(self, l_pred2_, l_x_, l_y_):
        add = l_x_ + l_y_; l_x_ = l_y_ = None
        return add

class GraphModule(torch.nn.Module):
    def forward(self, l_pred2_, l_x_, l_y_):
        cond_true_0 = self.cond_true_0
        cond_false_0 = self.cond_false_0
        cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]); l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
        return cond

class GraphModule(torch.nn.Module):
    def forward(self, l_x_, l_y_):
        sin = l_x_.sin(); l_x_ = None
        cos = l_y_.cos(); l_y_ = None
        sub = sin - cos; sin = cos = None
        return sub

class GraphModule(torch.nn.Module):
    def forward(self, l_x_, l_y_):
        cos = l_x_.cos(); l_x_ = None
        sin = l_y_.sin(); l_y_ = None
        sub = cos - sin; cos = sin = None
        return sub
```
the source_fn for the inner cond, sin, cos, and sub nodes will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub', <built-in function sub>)
```
After this PR, the source_fn_stack will be a list of (name, target) tuples. The bottom of the stack is at the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```
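For illustration, a hedged sketch of how the new metadata could be inspected on the recorded graphs (it assumes `EagerAndRecordGraphs` comes from `torch._dynamo.testing` and that `cond_f` above has been called once so `backend.graphs` is populated):
```python
import torch

# Assumes the cond_f example above has already run under torch.compile.
for gm in backend.graphs:
    for mod in gm.modules():  # outer graph plus the nested cond subgraphs
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            # After this PR, each call node carries a stack of (name, target) pairs.
            stack = node.meta.get("source_fn_stack")
            if stack is not None:
                print(node.name, [name for name, _ in stack])
```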
Test Plan:
See the added tests in test_higher_order_ops.py and the modified existing tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108595
Approved by: https://github.com/angelayi, https://github.com/zou3519
I'm pretty sure this is fixed, but I'll run inductor and trunk CI. The previously failing test in trunk was due to the recently landed selective activation checkpointing (SAC) code, which assumes that it can detect whether or not AOTAutograd is running by checking whether the inputs to SAC are C++ `FunctionalTensorWrapper`s.
The previous land broke some inductor trunk tests.
This reverts commit 629a628cc8.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109906
Approved by: https://github.com/ezyang
Now that `FunctionalTensor` and `FunctionalTensorMode` are lower down in this stack, the changes in this PR are more mechanical: everywhere in AOTAutograd that I used to use the C++ functionalization API, I now use the Python functionalization API.
Note that this doesn't actually cause functionalization to run underneath torch_dispatch. I'm saving that re-ordering for later in the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106406
Approved by: https://github.com/ezyang
ghstack dependencies: #108654, #109662, #109632, #109023
We now have two types of functionalization: C++ functionalization (through the `Functionalize` dispatch key) and Python functionalization (through the `FunctionalTensorMode` torch_dispatch mode).
This means that all higher order ops need custom functionalization rules for the Python variant too. I added them here, as well as a helper function `dispatch_functionalize()` - equivalent to `torch.func.functionalize()`, except that it uses `FunctionalTensorMode`.
In theory we could have secretly switched `torch.func.functionalize` to use `FunctionalTensorMode`. This would be BC-breaking, though, since `FunctionalTensorMode` isn't composable with the other functorch transforms (the functorch layer-mode stack doesn't know how to re-order torch_dispatch modes arbitrarily).
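For context, a hedged sketch of the behavior both entry points are meant to provide; only the `torch.func.functionalize` call below is a known public API, and the commented-out `dispatch_functionalize` import path is an assumption:
```python
import torch

def f(x):
    y = x.clone()
    y.add_(1)  # in-place op that functionalization should turn into out-of-place
    return y

x = torch.randn(3)

# C++ functionalization via the public functorch API (known to exist):
out = torch.func.functionalize(f)(x)
assert torch.allclose(out, x + 1)

# Hypothetical usage of the new helper; the import path is an assumption:
# from torch._subclasses.functional_tensor import dispatch_functionalize
# out_py = dispatch_functionalize(f)(x)
```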
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108656
Approved by: https://github.com/zou3519
ghstack dependencies: #109024, #109248
This fixes numerous tests which were xfailing. For instance, the `_segment_reduce.lengths` OpInfo test previously relied on the fallback kernel to determine the shape of the meta tensor. The fallback kernel would fail with
`segment_reduce(): Expected all rows of lengths along axis to sum to data.size(lengths.dim()-1) when !unsafe.`
as it was trying to read the values of a meta tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109359
Approved by: https://github.com/ezyang
Before this PR, if we run the following code:
```python
def true_fn(x):
    return x - x.cos()

def false_fn(x):
    return x + x.sin()

def foo(x):
    return cond(x.shape[0] == 4, true_fn, false_fn, [x])

gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(3, 4))
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
```
we'll get the following error:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/make_fx.py", line 16, in <module>
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 841, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
File "/home/yidi/local/pytorch/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 461, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/_symbolic_trace.py", line 817, in trace
(self.create_arg(fn(*args)),),
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 497, in wrapped
out = f(*tensors)
File "/home/yidi/local/pytorch/make_fx.py", line 13, in foo
return control_flow.cond(x.shape[0] == 4, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 151, in cond
return torch.compile(cond_op, backend="eager", fullgraph=True)(
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 545, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 561, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 483, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 432, in transform
tracer = InstructionTranslator(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2032, in __init__
self.symbolic_locals = collections.OrderedDict(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2035, in <genexpr>
VariableBuilder(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
output = [
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
output = [
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1040, in wrap_tensor
tensor_variable = wrap_fx_proxy(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1267, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1382, in wrap_fx_proxy_cls
example_value = wrap_to_fake_tensor_and_record(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1652, in wrap_to_fake_tensor_and_record
dynamic_dims, constraint_dims = _automatic_dynamic(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1550, in _automatic_dynamic
if dim is not None and e.size()[i] != dim:
File "/home/yidi/local/pytorch/torch/__init__.py", line 352, in __bool__
return self.node.bool_()
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1019, in bool_
return self.guard_bool("", 0)
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1001, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/home/yidi/local/pytorch/torch/fx/experimental/recording.py", line 227, in wrapper
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3793, in evaluate_expr
assert orig_expr == hint, f"{orig_expr} != {hint}"
AssertionError: False != True
from user code:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
This is because we record the SymInt in the frame state in _automatic_dynamic the first time we compile the function. The second time, when we're given a SymInt-sized input with different hints, the comparison fails.
Implementation:
This PR derives shape dynamism from the dynamism of the inputs: if a dimension is a SymInt, we return DYNAMIC; otherwise we return STATIC.
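A minimal sketch of the idea (a hypothetical helper, not the actual `_automatic_dynamic` code): for each dimension, if the example size is already a SymInt, treat it as dynamic; otherwise treat it as static.
```python
import torch
from enum import Enum, auto

class DimDynamism(Enum):
    STATIC = auto()
    DYNAMIC = auto()

def dynamism_from_input(t: torch.Tensor):
    # If a size is already symbolic (a SymInt), don't compare its hint against
    # the recorded frame state; mark that dimension dynamic instead.
    return [
        DimDynamism.DYNAMIC if isinstance(s, torch.SymInt) else DimDynamism.STATIC
        for s in t.size()
    ]
```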
Test Plan:
Added a test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109331
Approved by: https://github.com/ezyang
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```
def f(x: torch.Tensor):
    pred = x.size(0) == 3

torch.compile(f)(pred, x)
make_fx(f, tracing_mode="symbolic")(x)
```
The idea of this PR (credit to @ezyang) is to support SymBool by re-using the infrastructure we already have for SymInt, so that we don't need to replicate a lot of code.
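For illustration, a small hedged sketch of where a SymBool shows up during symbolic tracing (the printed type is what I'd expect once sizes are symbolic):
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    pred = x.size(0) == 3
    print(type(pred))  # expected: <class 'torch.SymBool'> under symbolic tracing
    return x.sin()

make_fx(f, tracing_mode="symbolic")(torch.ones(3, 4))
```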
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
## Exception in cond
For the code below:
```python
import torch
import functorch.experimental.control_flow as control_flow

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x, x

def f(x, y):
    return control_flow.cond(y, true_fn, false_fn, [x])

f(torch.ones(3, 4), torch.tensor(False))
```
The original exception stack trace is:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test_exc.py", line 33, in <module>
f(torch.ones(3, 4), torch.tensor(False))
File "/home/yidi/local/pytorch/test_exc.py", line 31, in f
return control_flow.cond(y, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 154, in cond
return torch.compile(cond_op, backend="eager", fullgraph=True)(
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 365, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 513, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 560, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 197, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 449, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2083, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1164, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 418, in call_function
(false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 410, in speculate_branch
raise UncapturedHigherOrderOpError(
torch._dynamo.exc.UncapturedHigherOrderOpError: Expected branch to return a single tensor
from user code:
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
After this PR we get:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 50, in graph_break_as_hard_error
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 429, in call_function
(false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 421, in speculate_branch
unimplemented(
File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 187, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: Expected branch to return a single tensor
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test_exc.py", line 33, in <module>
f(torch.ones(3, 4), torch.tensor(False))
File "/home/yidi/local/pytorch/test_exc.py", line 31, in f
return control_flow.cond(y, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 154, in cond
return torch.compile(cond_op, backend="eager", fullgraph=True)(
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1159, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 53, in graph_break_as_hard_error
raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.
from user code:
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
## Exception while speculating branches
The example code below has an in-place buffer mutation error:
```python
import torch
import functorch.experimental.control_flow as control_flow

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.ones(6, 4))

    def forward(self, x):
        def true_fn(x):
            self.buffer += 1
            return self.buffer.sum() + x.sum()

        def false_fn(x):
            return (x - 1).sum()

        return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])

mod_for_compile = torch.compile(Foo(), backend="eager", dynamic=True)
mod_for_compile(torch.ones(3, 4))
```
Before this PR the exception looks like:
```python
[2023-09-08 15:20:03,332] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-09-08 15:20:03,332] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Can't inplace modify module params/buffers inside HigherOrderOp
Traceback (most recent call last):
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 163, in speculate_subgraph
output = f.call_function(tx, args, sub_kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 606, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2200, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2316, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1219, in STORE_ATTR
.call_function(self, [obj, ConstantVariable(inst.argval), val], {})
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 618, in call_function
result = handler(tx, *args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 1169, in call_setattr
raise AttributeMutationError(
torch._dynamo.exc.AttributeMutationError: Can't inplace modify module params/buffers inside HigherOrderOp
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 394, in speculate_branch
ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 222, in speculate_subgraph
raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: Can't inplace modify module params/buffers inside HigherOrderOp
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test_exc.py", line 20, in <module>
mod_for_compile(torch.ones(3, 4))
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 365, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 513, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 632, in _convert_frame
result = inner_convert(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 560, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 197, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 449, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2083, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1124, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
return super().call_function(tx, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 606, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2200, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2316, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1124, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 415, in call_function
(true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 405, in speculate_branch
raise UncapturedHigherOrderOpError(
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile
from user code:
File "/home/yidi/local/pytorch/test_exc.py", line 16, in forward
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 127, in cond
return cond_op(pred, true_fn, false_fn, operands)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
After this PR, the only difference is that the error message of UncapturedHigherOrderOpError changes from `Cond doesn't work unless it is captured completely with torch.compile` to `Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break`.
```python
[2023-09-08 15:17:02,052] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-09-08 15:17:02,052] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Can't inplace modify module params/buffers inside HigherOrderOp
Traceback (most recent call last):
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 177, in speculate_subgraph
output = f.call_function(tx, args, sub_kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 601, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2300, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1214, in STORE_ATTR
.call_function(self, [obj, ConstantVariable(inst.argval), val], {})
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 618, in call_function
result = handler(tx, *args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 1169, in call_setattr
raise AttributeMutationError(
torch._dynamo.exc.AttributeMutationError: Can't inplace modify module params/buffers inside HigherOrderOp
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 50, in graph_break_as_hard_error
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 426, in call_function
(true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 410, in speculate_branch
ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 236, in speculate_subgraph
raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: Can't inplace modify module params/buffers inside HigherOrderOp
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test_exc.py", line 20, in <module>
mod_for_compile(torch.ones(3, 4))
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 634, in _convert_frame
result = inner_convert(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
return super().call_function(tx, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 601, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2300, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 53, in graph_break_as_hard_error
raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.
from user code:
File "/home/yidi/local/pytorch/test_exc.py", line 16, in forward
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 127, in cond
return cond_op(pred, true_fn, false_fn, operands)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108817
Approved by: https://github.com/zou3519
PoC demonstrating vmap + NT based on the [design doc](https://docs.google.com/document/d/1dVVk6TOqz93PLTIneU2T3xaxCs9qZ0MaJyCvOAp_bC0). This PR:
* Allows `BatchedTensorImpl`s to contain NTs
* Introduces a `BatchedNestedTensor` dispatch key for NT-specific batching rules
* Provides a batching rule fallback that unbinds the NTs -> performs the computation on the constituents -> rebinds the results into an NT
Restrictions:
* Only supports one level of vmap
* Only supports vmapping over dim=0 for NTs
* For operations with mixed NT / dense inputs, support is also limited to dim=0 for the dense inputs
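A hedged usage sketch of what the PoC targets (vmapping over dim=0 of an NT; the exact API surface and supported ops may differ from this sketch):
```python
import torch

# A nested tensor built from variable-length constituents.
nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])

def per_sample(t):
    # Operates on a single dense constituent.
    return t.sin().sum(dim=-1)

# Conceptually equivalent to the fallback described above:
# unbind the NT -> run per_sample on each constituent -> rebind into an NT.
out = torch.vmap(per_sample)(nt)
```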
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106786
Approved by: https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/105327. The problem is that `lift_fresh_copy()`'s functionalization implementation currently assumes that the input is never a functional tensor. This is apparently too limiting: you can end up with "user" code like this (which can potentially come from exporting a model and then running compile on the resulting graph):
```
tensor_constant0 = torch.tensor(2)
lift_fresh = torch.ops.aten.lift_fresh_copy.default(tensor_constant0)
```
When we run this through AOTAutograd, the first call (torch.tensor(2)) will **already** be lifted into a functional tensor wrapper - so the `lift_fresh_copy` call doesn't need to do any "lifting" anymore - it just needs to do a clone.
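For illustration, a hedged sketch of how a freshly lifted constant can meet AOTAutograd's functionalization (a simplified stand-in, not the repro from the linked issue):
```python
import torch
from functorch.compile import aot_function, nop

def g(x):
    # torch.tensor(...) inside a traced function becomes a lifted constant
    # (lift_fresh/lift_fresh_copy) by the time functionalization sees it.
    c = torch.tensor(2)
    return x + c

out = aot_function(g, fw_compiler=nop)(torch.randn(3))
```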
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108243
Approved by: https://github.com/albanD
ghstack dependencies: #108081, #108235
**Motivation:**
Currently, for the following code that exports the cond operator:
```python
import torch
from functorch.experimental.control_flow import cond

class MySubModule(torch.nn.Module):
    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)

class CondBranchClassMethod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

from torch._export import capture_pre_autograd_graph

example_inputs = (torch.randn(1, 3, 3, 3),)
m = CondBranchClassMethod()
m.eval()
gm = capture_pre_autograd_graph(m, example_inputs)
print(gm)

# source_fn for original cond op, getattr submodule op are all cond op
for n in gm.graph.nodes:
    print("n:", n.format_node(), n.meta)

print("\n\n\n")

# source_fn for submodule nodes are all cond op
# Expected: ideally this should be the real ops, e.g. torch.sin, aten.cos, etc
for n in gm.submodule_0.graph.nodes:
    print("n:", n.format_node(), n.meta)
```
The output looks like this:
```
GraphModule(
  (submodule_0): GraphModule()
  (submodule_1): GraphModule()
)

def forward(self, arg_0):
    arg0_1, = fx_pytree.tree_flatten_spec([arg_0], self._in_spec)
    submodule_0 = self.submodule_0
    submodule_1 = self.submodule_1
    cond = torch.ops.higher_order.cond(True, submodule_0, submodule_1, [arg0_1]); submodule_0 = submodule_1 = arg0_1 = None
    return pytree.tree_unflatten((cond,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
n: %arg0_1 : [num_users=1] = placeholder[target=arg0_1] {'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None, 'is_torch_exported': True, 'stack_trace': 'NoneType: None\n'}
n: %submodule_0 : [num_users=1] = get_attr[target=submodule_0] {'stack_trace': 'NoneType: None\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('conditional', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>)], 'seq_nr': -1}
n: %submodule_1 : [num_users=1] = get_attr[target=submodule_1] {'stack_trace': 'NoneType: None\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('conditional', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>)], 'seq_nr': -1}
n: %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (True, %submodule_0, %submodule_1, [%arg0_1]), kwargs = {}) {'stack_trace': 'NoneType: None\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('conditional', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>)], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None, 'is_torch_exported': True}
n: return (cond,) {'stack_trace': 'NoneType: None\n', 'from_node': [('output', 'output')], 'seq_nr': -1, 'is_torch_exported': True, 'val': (FakeTensor(..., size=(1, 3, 3, 3)),), 'tensor_meta': (None,)}
n: %arg0_1 : [num_users=1] = placeholder[target=arg0_1] {'stack_trace': ' File "<ipython-input-9-2a8c7c0498ed>", line 36, in forward\n return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('arg0_1', 'arg0_1')], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None}
n: %cos_default : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%arg0_1,), kwargs = {}) {'stack_trace': ' File "<ipython-input-9-2a8c7c0498ed>", line 36, in forward\n return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': <OpOverload(op='aten.cos', overload='default')>, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cos', <OpOverload(op='aten.cos', overload='default')>), ('cos_default', <OpOverload(op='aten.cos', overload='default')>)], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None}
n: return cos_default {'stack_trace': ' File "<ipython-input-9-2a8c7c0498ed>", line 36, in forward\n return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('output', 'output')], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None}
```
As we can see, the meta of the nodes in the subgraphs is overridden with the cond node's metadata. This is because the function _set_current_meta is only invoked at the top-level graph module in the interpreter. When we call into cond and deal with the submodules here, we don't set the current_meta to the meta of the subgraph's nodes properly.
**Implementation:**
This PR fixes it as follows: in trace_cond, we optionally use an fx.Interpreter to interpret the subgraphs so that the node metadata is preserved (a sketch follows this list). This only happens when both of the following conditions are satisfied:
- The subgraphs are GraphModules: this is necessary for using the fx.Interpreter.
- The current make_fx has preserve_node_meta turned on (as is the case for capture_pre_autograd_graph).
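A minimal sketch of the interpretation trick under those two conditions (hypothetical helper name; the real change lives inside trace_cond):
```python
import torch.fx as fx
import torch.fx.traceback as fx_traceback

def run_branch_preserving_meta(branch, *operands):
    # Hypothetical sketch: when the branch is a GraphModule and make_fx has
    # node-meta preservation on, interpret it node-by-node so each node's own
    # meta becomes the "current meta" instead of the outer cond's.
    if isinstance(branch, fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        return fx.Interpreter(branch).run(*operands)
    return branch(*operands)
```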
**Test Plan:**
See added tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108356
Approved by: https://github.com/SherlockNoMad
torch.empty_strided is able to create a new tensor based on the metadata. For boolean tensors, we call clone directly; however, if the input is a functional tensor we'll get a functional tensor back, and that functional tensor won't be tracked by the tracer's tensor_tracker after dispatching, so it becomes a tensor_constant in the graph during create_arg. So we manually unwrap the functional tensor before calling clone.
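A hedged sketch of the unwrap-before-clone idea (the `torch._is_functional_tensor` / `torch._from_functional_tensor` helpers are internal bindings, used here as assumptions):
```python
import torch

def clone_bool_tensor_for_tracing(t: torch.Tensor) -> torch.Tensor:
    # If t is wrapped by C++ functionalization, clone the underlying tensor
    # instead, so the result stays visible to the tracer's tensor_tracker
    # rather than being baked in as a tensor_constant.
    if torch._is_functional_tensor(t):  # assumption: internal binding
        t = torch._from_functional_tensor(t)
    return t.clone()
```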
Test Plan:
See added test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108289
Approved by: https://github.com/angelayi
Manually enable `capture_func_transforms` for testing, as the plan is to default `capture_func_transforms` to False in 2.1 (enable it so that we still test the support on the release branch).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107122
Approved by: https://github.com/zou3519
This PR allows dynamo to fakify a FunctionalTensorWrapper by unwrapping it, replacing the inner tensor, and wrapping it again, so that a FunctionalTensorWrapper can be passed as an input to dynamo.optimize. This lets us support something like this:
```python
ff = torch.func.functionalize(f)
torch.compile(ff)(x)
```
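A hedged sketch of the unwrap/replace/re-wrap step (the internal `torch._from_functional_tensor` / `torch._to_functional_tensor` bindings and the placement inside dynamo's fakification path are assumptions):
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def fakify_functional_wrapper(t: torch.Tensor, fake_mode: FakeTensorMode) -> torch.Tensor:
    # Unwrap the C++ FunctionalTensorWrapper, fakify the inner tensor,
    # then wrap the fake tensor back into a functional tensor.
    assert torch._is_functional_tensor(t)  # assumption: internal binding
    inner = torch._from_functional_tensor(t)
    fake_inner = fake_mode.from_tensor(inner)
    return torch._to_functional_tensor(fake_inner)
```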
This PR doesn't follow the `__tensor_flatten__` and `__tensor_unflatten__` protocol right now because we're not sure about the plan for doing that for FunctionalTensorWrapper (it's implemented in C++).
**Test Plan:**
Add a new test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107062
Approved by: https://github.com/zou3519
ghstack dependencies: #107042