Reland - the previous PR was reverted internally with this error:
```
File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/buck-out/v2/gen/fbcode/363cd7e240f5d021/caffe2/torch/fb/trainer/data_modules/tests/__test_dataloader__/test_dataloader#link-tree/torch/__init__.py", line 29, in <module>
from ._utils_internal import _functionalize_sync as _sync
ImportError: cannot import name '_functionalize_sync' from 'torch._utils_internal'
```
I couldn't figure out why internal was unhappy with the import. One potential reason: there is a build rule for *another* `_utils_internal.py` in the fb folder ([link](https://www.internalfb.com/code/fbsource/[30ed85cd88409af98b7490be137aaa5dfd7afd01]/fbcode/caffe2/TARGETS?lines=444)).
Rather than burn more time investigating, I confirmed internally that the error goes away if I move the util from `torch/_utils_internal.py` to `torch/_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109518
Approved by: https://github.com/albanD
Before this PR, if we run the following code:
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x - x.cos()

def false_fn(x):
    return x + x.sin()

def foo(x):
    return cond(x.shape[0] == 4, true_fn, false_fn, [x])

gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(3, 4))
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
```
we get the following error:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/make_fx.py", line 16, in <module>
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 841, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
File "/home/yidi/local/pytorch/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 461, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/_symbolic_trace.py", line 817, in trace
(self.create_arg(fn(*args)),),
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 497, in wrapped
out = f(*tensors)
File "/home/yidi/local/pytorch/make_fx.py", line 13, in foo
return control_flow.cond(x.shape[0] == 4, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 151, in cond
return torch.compile(cond_op, backend="eager", fullgraph=True)(
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 545, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 561, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 483, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 432, in transform
tracer = InstructionTranslator(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2032, in __init__
self.symbolic_locals = collections.OrderedDict(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2035, in <genexpr>
VariableBuilder(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
output = [
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
output = [
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1040, in wrap_tensor
tensor_variable = wrap_fx_proxy(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1267, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1382, in wrap_fx_proxy_cls
example_value = wrap_to_fake_tensor_and_record(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1652, in wrap_to_fake_tensor_and_record
dynamic_dims, constraint_dims = _automatic_dynamic(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1550, in _automatic_dynamic
if dim is not None and e.size()[i] != dim:
File "/home/yidi/local/pytorch/torch/__init__.py", line 352, in __bool__
return self.node.bool_()
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1019, in bool_
return self.guard_bool("", 0)
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1001, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/home/yidi/local/pytorch/torch/fx/experimental/recording.py", line 227, in wrapper
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3793, in evaluate_expr
assert orig_expr == hint, f"{orig_expr} != {hint}"
AssertionError: False != True
from user code:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
This happens because we record the SymInt in the frame state in `_automatic_dynamic` the first time we compile the function. The second time, when we are given a SymInt-sized input with a different hint, the comparison fails.
**Implementation:**
This PR derives shape dynamism from the dynamism of the inputs: if a dimension is a SymInt, return DYNAMIC, otherwise return STATIC.
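A minimal sketch of that decision, assuming the per-dimension dynamism enum is `DimDynamic` from `torch.fx.experimental.symbolic_shapes` and using a made-up helper name (this is an illustration, not the actual PR diff):

```python
import torch
from torch.fx.experimental.symbolic_shapes import DimDynamic

def dynamism_from_input(e: torch.Tensor):
    # If the input already has symbolic sizes, keep those dims dynamic instead of
    # comparing their hints against sizes recorded in the frame state.
    return [
        DimDynamic.DYNAMIC if isinstance(s, torch.SymInt) else DimDynamic.STATIC
        for s in e.size()
    ]
```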
**Test Plan:**
Add a test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109331
Approved by: https://github.com/ezyang
Added two new utils to help with turning python functionalization on in AOTAutograd (next PR):
(1) Updated `torch._sync()`. Previously, this API could only handle `torch.Tensor` instances that had a `FunctionalTensorWrapper` TensorImpl. It now needs to handle python `FunctionalTensor`s as well. In theory I could probably break BC and change this API (since it's private?), but I decided not to do that in this PR stack to minimize the chance of reverts. Instead of updating that API directly (which is in C++), I just added a python shim that first tries to unwrap the python `FunctionalTensor` if there is one, then calls the existing C++ logic.
(2) `mirror_autograd_meta` is now a standalone API that tries to mirror the `requires_grad` and `is_leaf` autograd metadata from one tensor to another. Previously this was hardcoded into `torch._to_functional_tensor()`. But I now need to use it in a more standalone way: later in AOTAutograd, when we unwrap and re-wrap a tensor subclass, we need to manually mirror the autograd metadata from the original to the updated version of the subclass. Both helpers are sketched below.
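A minimal sketch of both helpers, under assumptions: `FunctionalTensor`'s import path and its `.elem` field are guesses, `cpp_sync` is a hypothetical stand-in for the pre-existing C++ sync entry point, and the `mirror_autograd_meta` body is reconstructed from the description above rather than copied from the PR.

```python
import torch
from typing import Callable
from torch._subclasses.functional_tensor import FunctionalTensor  # assumed location

def _functionalize_sync(t: torch.Tensor, cpp_sync: Callable[[torch.Tensor], None]) -> None:
    # (1) python shim: unwrap a python FunctionalTensor first, then defer to the
    # existing C++ logic, which only understands FunctionalTensorWrapper.
    if isinstance(t, FunctionalTensor):
        t = t.elem  # assumed field holding the inner, C++-functionalized tensor
    cpp_sync(t)

def mirror_autograd_meta(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    # (2) copy requires_grad and leaf-ness from src onto dst.
    dst.requires_grad_(src.requires_grad)
    if src.requires_grad and not src.is_leaf:
        dst = dst.view_as(dst)  # record an autograd edge so dst becomes a non-leaf too
    return dst
```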
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107917
Approved by: https://github.com/ezyang
ghstack dependencies: #106404
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```
def f(x: torch.Tensor):
    pred = x.size(0) == 3
    torch.compile(f)(pred, x)

make_fx(f, tracing_mode="symbolic")(x)
```
The idea of this PR (credit to @ezyang) is to support SymBool by reusing the infra we already have for SymInt so that we don't need to replicate a lot of stuff.
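For illustration only (not code from the PR): under `make_fx(..., tracing_mode="symbolic")`, a size comparison like the one above yields a `torch.SymBool` rather than a plain `bool`, which is exactly the kind of value this PR lets `torch.compile` accept as an input.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def g(x):
    pred = x.size(0) == 3                 # under symbolic tracing this is a torch.SymBool
    return x.sin() if pred else x.cos()   # branching on it installs a shape guard

gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 4))
```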
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:
(1) Fakeifying tensor subclasses didn't properly set autograd metadata (`requires_grad`, `is_leaf`) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my AOTAutograd tests.
(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.
(3) `_make_wrapper_subclass()`: marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second does not support kwargs). I think that later we probably want to consolidate, or at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement). A minimal example of the kind of wrapper subclass involved is sketched below.
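For context only, here is a generic `__torch_dispatch__` wrapper subclass of the kind this plumbing targets, using the standard subclass protocol (`_make_wrapper_subclass`, `__tensor_flatten__` / `__tensor_unflatten__`). This is an illustrative sketch, not the PR's test subclass.

```python
import torch
from torch.utils._pytree import tree_map_only

class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner):
        # First overload of _make_wrapper_subclass: mirror the inner tensor's metadata.
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner):
        self.inner = inner  # inner tensor field; its dynamic-shape sources need tracking

    def __tensor_flatten__(self):
        return ["inner"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return WrapperTensor(inner_tensors["inner"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap, run the op on the inner tensors, and re-wrap tensor outputs.
        args, kwargs = tree_map_only(WrapperTensor, lambda t: t.inner, (args, kwargs))
        out = func(*args, **kwargs)
        return tree_map_only(torch.Tensor, WrapperTensor, out)
```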
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
**Motivation:**
When an input FakeTensor to torch.compile has SymInt sizes (e.g. under `make_fx(opt_f, tracing_mode="symbolic")`):
1. We cannot create a FakeTensor from that input in dynamo because of the SymInts.
2. We cannot check input tensors in the guard check function and will abort, because the tensor check calls sizes/strides.
For 1, we specialize the FakeTensor's SymInts using their hints. This is mostly safe since inputs mostly have concrete shapes and are not computed from DynamicOutputShape ops. We throw a data-dependent error if the SymInt is unbacked.
For 2, we replace size/stride calls with the `sym_*` variants in TENSOR_CHECK guards' check function.
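A hedged sketch of the specialization in point 1, with a made-up helper name; it assumes the hint is reachable via `SymInt.node.hint` / `has_hint()`, and is not the actual PR code:

```python
import torch

def specialize_sizes(t: torch.Tensor):
    # Replace each SymInt size with its concrete hint; unbacked symbols have no
    # hint, so raise a data-dependent error instead of guessing.
    sizes = []
    for s in t.size():
        if isinstance(s, torch.SymInt):
            if not s.node.has_hint():
                raise RuntimeError("unbacked SymInt size: data-dependent, cannot specialize")
            sizes.append(int(s.node.hint))
        else:
            sizes.append(s)
    return tuple(sizes)
```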
**Test Plan:**
See added tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107662
Approved by: https://github.com/ezyang
This PR allows dynamo to fakeify a `FunctionalTensorWrapper` by unwrapping it, replacing the inner tensor with a fake one, and wrapping it again, so that a `FunctionalTensorWrapper` can be passed as an input to `dynamo.optimize` and we can support something like this:
```python
ff = torch.func.functionalize(f)
torch.compile(ff)(x)
```
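A hedged sketch of the unwrap/replace/re-wrap idea (illustrative only, not the PR's actual code path in dynamo); it uses the existing `torch._from_functional_tensor` / `torch._to_functional_tensor` helpers and `FakeTensorMode.from_tensor`:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def fakeify_functional(t: torch.Tensor, fake_mode: FakeTensorMode) -> torch.Tensor:
    assert torch._is_functional_tensor(t)
    torch._sync(t)                                  # flush pending functionalization updates
    inner = torch._from_functional_tensor(t)        # unwrap the FunctionalTensorWrapper
    fake_inner = fake_mode.from_tensor(inner)       # replace the inner tensor with a FakeTensor
    return torch._to_functional_tensor(fake_inner)  # wrap again as a FunctionalTensorWrapper
```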
This PR doesn't follow the `__tensor_flatten__` and `__tensor_unflatten__` protocol for now, because we're not sure of the plan for doing that for `FunctionalTensorWrapper` (it's implemented in C++).
**Test Plan:**
Add a new test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107062
Approved by: https://github.com/zou3519
ghstack dependencies: #107042