Commit Graph

18 Commits

Author SHA1 Message Date
Brian Hirsh
25e81f19f3 reland "python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)" (#109518)
Reland - the previous PR was reverted by internal with this error:
```
  File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/buck-out/v2/gen/fbcode/363cd7e240f5d021/caffe2/torch/fb/trainer/data_modules/tests/__test_dataloader__/test_dataloader#link-tree/torch/__init__.py", line 29, in <module>
    from ._utils_internal import _functionalize_sync as _sync
ImportError: cannot import name '_functionalize_sync' from 'torch._utils_internal'
```

I couldn't figure out why internal was unhappy with the import. One potential reason is that I see a build rule for *another* `_utils_internal.py` in the fb folder here ([link](https://www.internalfb.com/code/fbsource/[30ed85cd88409af98b7490be137aaa5dfd7afd01]/fbcode/caffe2/TARGETS?lines=444)).

Rather than burn more time investigating, I confirmed internally that the error goes away if I move the util from `torch/_utils_internal.py` to `torch/_utils.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109518
Approved by: https://github.com/albanD
2023-09-19 13:25:24 +00:00
PyTorch MergeBot
49b18ae546 Revert "python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)"
This reverts commit 0ad595954a.

Reverted https://github.com/pytorch/pytorch/pull/107917 on behalf of https://github.com/clee2000 due to breaking internal builds D49346637 ([comment](https://github.com/pytorch/pytorch/pull/107917#issuecomment-1722566885))
2023-09-17 20:57:41 +00:00
ydwu4
706d8e2230 [dynamo] Respect shape dynamism of SymInt sized tensor (#109331)
Before this PR, if we run the following code:
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x - x.cos()

def false_fn(x):
    return x + x.sin()

def foo(x):
    return cond(x.shape[0] == 4, true_fn, false_fn, [x])

gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(3, 4))
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
```
we'll have the following error:
```python
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/make_fx.py", line 16, in <module>
    gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
  File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 841, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/yidi/local/pytorch/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 461, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/fx/_symbolic_trace.py", line 817, in trace
    (self.create_arg(fn(*args)),),
  File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 497, in wrapped
    out = f(*tensors)
  File "/home/yidi/local/pytorch/make_fx.py", line 13, in foo
    return control_flow.cond(x.shape[0] == 4, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 151, in cond
    return torch.compile(cond_op, backend="eager", fullgraph=True)(
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 545, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 561, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 483, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 432, in transform
    tracer = InstructionTranslator(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2032, in __init__
    self.symbolic_locals = collections.OrderedDict(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2035, in <genexpr>
    VariableBuilder(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
    vt = self._wrap(value).clone(**self.options())
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
    return type_dispatch(self, value)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
    output = [
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
    VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
    vt = self._wrap(value).clone(**self.options())
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
    return type_dispatch(self, value)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
    output = [
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
    VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
    vt = self._wrap(value).clone(**self.options())
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
    return type_dispatch(self, value)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1040, in wrap_tensor
    tensor_variable = wrap_fx_proxy(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1267, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1382, in wrap_fx_proxy_cls
    example_value = wrap_to_fake_tensor_and_record(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1652, in wrap_to_fake_tensor_and_record
    dynamic_dims, constraint_dims = _automatic_dynamic(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1550, in _automatic_dynamic
    if dim is not None and e.size()[i] != dim:
  File "/home/yidi/local/pytorch/torch/__init__.py", line 352, in __bool__
    return self.node.bool_()
  File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1019, in bool_
    return self.guard_bool("", 0)
  File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1001, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/home/yidi/local/pytorch/torch/fx/experimental/recording.py", line 227, in wrapper
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3793, in evaluate_expr
    assert orig_expr == hint, f"{orig_expr} != {hint}"
AssertionError: False != True

from user code:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```

It's because we record the SymInt in the frame state in `_automatic_dynamic` the first time we compile the function. The second time, when we're given a SymInt-sized input with a different hint, the comparison fails.

Implementation:
This PR derives shape dynamism from the dynamism of the inputs: if a dimension is a SymInt, return DYNAMIC, else return STATIC.
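A rough sketch of that decision, assuming the `DimDynamic` enum from `torch.fx.experimental.symbolic_shapes` (the real `_automatic_dynamic` does much more bookkeeping):
```python
import torch
from torch.fx.experimental.symbolic_shapes import DimDynamic

def dynamism_from_input(t):
    # Sketch only: mark a dimension dynamic iff its size is already a SymInt
    # on the input; plain int sizes stay static.
    return [
        DimDynamic.DYNAMIC if isinstance(s, torch.SymInt) else DimDynamic.STATIC
        for s in t.size()
    ]
```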

Test Plan:
Add a test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109331
Approved by: https://github.com/ezyang
2023-09-16 02:56:53 +00:00
Brian Hirsh
0ad595954a python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)
Added two new utils to help with turning python functionalization on in AOTAutograd (next PR):

(1) updated `torch._sync()`. Previously, this API could only handle `torch.Tensor` instances that had a `FunctionalTensorWrapper` TensorImpl. It now needs to handle python `FunctionalTensor`s as well. In theory I could probably break BC and change this API (since it's private?), but I decided not to do that in this PR stack to minimize the chance of reverts. Instead of updating that API directly (which is in C++), I just added a python shim that first tries to unwrap the python `FunctionalTensor` if there is one, then calls the existing C++ logic.
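A minimal sketch of what such a shim can look like (the import path and the `.elem` attribute are assumptions for illustration, not the PR's exact code; per the reland higher up in this log, the real helper ends up in `torch/_utils.py`):
```python
import torch

def _functionalize_sync(t):
    # Sketch: unwrap a python-level FunctionalTensor if present, then defer to
    # the existing C++ sync logic that understands FunctionalTensorWrapper.
    from torch._subclasses.functional_tensor import FunctionalTensor  # assumed location
    if isinstance(t, FunctionalTensor):
        t = t.elem  # illustrative attribute name for the wrapped tensor
    torch._sync(t)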

(2) `mirror_autograd_meta` is now a standalone API that tries to mirror the `requires_grad` and `is_leaf` autograd metadata from one tensor to another. Previously this was hardcoded into `torch._to_functional_tensor()`. But I now need to use it in a more standalone way: later in AOTAutograd, when we unwrap and re-wrap a tensor subclass, we need to manually mirror the autograd metadata from the original to the updated version of the subclass.
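A rough sketch of the mirroring idea (illustrative only; the real helper may handle more cases):
```python
import torch

def mirror_autograd_meta(source, dest):
    # Sketch: copy requires_grad, then approximate is_leaf by routing a
    # non-leaf source through a no-op view so dest also picks up a grad_fn.
    dest.requires_grad_(source.requires_grad)
    if source.requires_grad and not source.is_leaf:
        dest = dest.view_as(dest)
    return dest
```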

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107917
Approved by: https://github.com/ezyang
ghstack dependencies: #106404
2023-09-15 20:19:25 +00:00
ydwu4
6140facf00 Support SymBool input to torch.compile (#107850)
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```python
def f(x: torch.Tensor):
    pred = x.size(0) == 3
    torch.compile(f)(pred, x)

make_fx(f, tracing_mode="symbolic")(x)
```

The idea of this PR (credit to @ezyang) is to support SymBool by reusing the infra we already have for SymInt, so that we don't need to replicate a lot of machinery.
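A rough usage sketch of the scenario this enables (illustrative only, not the PR's test; `inner` and the eager backend are made up for the example, and it also leans on the FakeTensor-input support from #107662, this PR's ghstack dependency):
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def inner(pred, x):
    # Under symbolic tracing, `pred` arrives as a SymBool input to the compiled fn
    return x.sin() if pred else x.cos()

def f(x):
    pred = x.size(0) == 3
    return torch.compile(inner, backend="eager")(pred, x)

gm = make_fx(f, tracing_mode="symbolic")(torch.ones(3, 4))
```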

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
2023-09-14 21:34:31 +00:00
PyTorch MergeBot
47f79e9a2b Revert "Support SymBool input to torch.compile (#107850)"
This reverts commit 9f6d70b2fd.

Reverted https://github.com/pytorch/pytorch/pull/107850 on behalf of https://github.com/huydhn due to Sorry for reverting this, but test_export_with_symbool_inputs is failing in trunk a08e1370ef ([comment](https://github.com/pytorch/pytorch/pull/107850#issuecomment-1718675877))
2023-09-14 02:53:36 +00:00
ydwu4
9f6d70b2fd Support SymBool input to torch.compile (#107850)
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```python
def f(x: torch.Tensor):
    pred = x.size(0) == 3
    torch.compile(f)(pred, x)

make_fx(f, tracing_mode="symbolic")(x)
```

The idea of this PR (credit to @ezyang) is to support SymBool by reusing the infra we already have for SymInt, so that we don't need to replicate a lot of machinery.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
2023-09-14 01:16:29 +00:00
Brian Hirsh
5efd63b1b8 better support for fakeifying and dynamoing through torch_dispatch subclasses (with dynamic shapes) (#107415)
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:

(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests.

(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic and again in symbolic_shapes.py to generate proper guards.

(3) `_make_wrapper_subclass()`: marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second does not support kwargs). I think that later we probably want to consolidate them, or at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).
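For context, a minimal `__torch_dispatch__` wrapper subclass of the kind this PR plumbs through dynamo might look like the sketch below (illustrative only, not the PR's test; it uses the first `_make_wrapper_subclass` overload):
```python
import torch
from torch.utils._pytree import tree_map

class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # First overload of _make_wrapper_subclass: sizes/strides copied from `elem`
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=elem.device, requires_grad=elem.requires_grad,
        )

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap inner tensors, run the op, then re-wrap plain tensor outputs
        unwrap = lambda t: t.elem if isinstance(t, WrapperTensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(
            lambda t: WrapperTensor(t) if isinstance(t, torch.Tensor) else t, out
        )
```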

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
2023-08-29 02:36:48 +00:00
ydwu4
31b0445702 Fix torch.compile with FakeTensor that has SymInt sizes (#107662)
**Motivation:**
When an input FakeTensor to torch.compile has SymInt sizes (e.g. under `make_fx(opt_f, tracing_mode="symbolic")`):
1. We cannot create a FakeTensor from that input in dynamo due to the SymInts.
2. We cannot check input tensors in the guard check function, and we abort because the tensor check calls sizes/strides.

For 1, we specialize the FakeTensor's SymInts using their hints. This is mostly safe, since inputs mostly have concrete shapes and are not computed from DynamicOutputShape ops. We throw a data-dependent error if a SymInt is unbacked.
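A rough sketch of that specialization step (the `node.hint` accessor is an assumption about the SymInt internals; the real code differs):
```python
import torch

def specialize_symint_sizes(t):
    # Sketch: replace each SymInt size with its concrete hint; unbacked
    # (hint-less) SymInts have no concrete value, so raise instead.
    sizes = []
    for s in t.size():
        if isinstance(s, torch.SymInt):
            hint = s.node.hint  # assumed accessor on the underlying SymNode
            if hint is None:
                raise RuntimeError("cannot specialize an unbacked SymInt size")
            sizes.append(int(hint))
        else:
            sizes.append(s)
    return tuple(sizes)
```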

For 2, we replace size/stride calls with the sym_* variants in TENSOR_CHECK guards' check function.

**Test Plan:**
See added tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107662
Approved by: https://github.com/ezyang
2023-08-23 05:27:57 +00:00
ydwu4
cbcd551045 Fix torch.compile FunctionalTensor inputs for higherOrderOps (#107604)
Before this PR, for the added [test](https://github.com/pytorch/pytorch/pull/107604/files#diff-c618f2274b6b5ccc533c580549d2e552edbd9fc5ac0da1aa4b00338525c8f78dR224), which feeds FunctionalTensorWrapper inputs to a higher-order operator, we hit an assertion error at this line of [code](https://github.com/pytorch/pytorch/pull/107604/files#diff-9f0663783bcd93e948e0491ef61b48123bdc9977bcc632fd707da578df13bfa1R1284).

The key change in this PR is to this [line](https://github.com/pytorch/pytorch/pull/107604/files#diff-9f0663783bcd93e948e0491ef61b48123bdc9977bcc632fd707da578df13bfa1L1263) of the check:
```python
        elif (
            isinstance(example_value, FakeTensor)
            and example_value.fake_mode is tx.fake_mode
        ):
```
The original intention of it seems to be dealing with the case where we want to wrap an fx proxy for an intermediate fake tensor that's produced by some tensor ops and an example value is provided (as is the case for higherOrderOps [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/higher_order_ops.py#L85)). A fakified FunctionalTensorWrapper(FakeTensor) always fails this check. This PR changes it to check whether the value has already been fakified by tx.fake_mode.
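Conceptually, the new condition is closer to the sketch below (illustrative only; the helper and its unwrapping are assumptions, not the PR's exact code):
```python
import torch
from torch._subclasses.fake_tensor import FakeTensor

def already_fakified_by(example_value, fake_mode):
    # Sketch: rather than requiring `example_value` to literally be a FakeTensor,
    # accept anything whose backing tensor was fakified by this translator's
    # fake mode, so a FunctionalTensorWrapper(FakeTensor) input also passes.
    inner = example_value
    if torch._is_functional_tensor(example_value):
        inner = torch._from_functional_tensor(example_value)
    return isinstance(inner, FakeTensor) and inner.fake_mode is fake_mode
```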

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107604
Approved by: https://github.com/zou3519
ghstack dependencies: #107569
2023-08-23 02:42:18 +00:00
ydwu4
a408920817 Reland fakify FunctionalTensor (#107569)
Try to rebase and reland https://github.com/pytorch/pytorch/pull/107062. One difference compared with the previous attempt is that the DTensor logic is kept the same as it was previously in `_clone_input`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107569
Approved by: https://github.com/zou3519
2023-08-22 15:46:25 +00:00
PyTorch MergeBot
96c5be8bc4 Revert "Fakify leaf of FunctionalTensor (#107062)"
This reverts commit 3349725766.

Reverted https://github.com/pytorch/pytorch/pull/107062 on behalf of https://github.com/ydwu4 due to This appears to have broken the test TestDTensorCompile.test_dtensor_fullgraph. Probably a land race ([comment](https://github.com/pytorch/pytorch/pull/107062#issuecomment-1685447747))
2023-08-21 00:30:16 +00:00
ydwu4
3349725766 Fakify leaf of FunctionalTensor (#107062)
This PR allows dynamo to fakify a FunctionalTensorWrapper by unwrapping it, replacing the inner tensor with a fake one, and wrapping it again. This lets a FunctionalTensorWrapper be passed as an input to dynamo.optimize, so we can support something like this:
```python
ff = torch.func.functionalize(f)
torch.compile(ff)(x)
```
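A rough sketch of the unwrap / replace / re-wrap idea (illustrative only; the real dynamo code path differs):
```python
import torch

def fakeify_functional(t, fake_mode):
    # Sketch: unwrap the C++ FunctionalTensorWrapper, fakify the inner tensor,
    # then re-wrap it so the input is still functional but backed by a FakeTensor.
    assert torch._is_functional_tensor(t)
    inner = torch._from_functional_tensor(t)
    fake_inner = fake_mode.from_tensor(inner)
    return torch._to_functional_tensor(fake_inner)
```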

This PR doesn't follow the `__tensor_flatten__` and `__tensor_unflatten__` protocol for now, because we're not sure of the plan for doing that for FunctionalTensorWrapper (it's implemented in C++).

**Test Plan:**
Add a new test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107062
Approved by: https://github.com/zou3519
ghstack dependencies: #107042
2023-08-19 17:33:42 +00:00
PyTorch MergeBot
3c11184ca8 Revert "Fakify leaf of FunctionalTensor (#107062)"
This reverts commit 6cb0128c8a.

Reverted https://github.com/pytorch/pytorch/pull/107062 on behalf of https://github.com/ZainRizvi due to This appears to have broken the test TestDTensorCompile.test_dtensor_fullgraph.  Probably a land race ([comment](https://github.com/pytorch/pytorch/pull/107062#issuecomment-1684124230))
2023-08-18 16:02:54 +00:00
ydwu4
6cb0128c8a Fakify leaf of FunctionalTensor (#107062)
This PR allows dynamo to fakify a FunctionalTensorWrapper by unwrapping it, replacing the inner tensor with a fake one, and wrapping it again. This lets a FunctionalTensorWrapper be passed as an input to dynamo.optimize, so we can support something like this:
```python
ff = torch.func.functionalize(f)
torch.compile(ff)(x)
```

This PR doesn't follow the `__tensor_flatten__` and `__tensor_unflatten__` protocol for now, because we're not sure of the plan for doing that for FunctionalTensorWrapper (it's implemented in C++).

**Test Plan:**
Add a new test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107062
Approved by: https://github.com/zou3519
ghstack dependencies: #107042
2023-08-18 03:05:45 +00:00
ydwu4
c71828b097 Lift non-FakeTensor restriction for compile (#107042)
Currently, we have an assertion that dynamo won't accept FakeTensor inputs unless we're exporting. This PR tries to remove that restriction, to finish https://github.com/pytorch/pytorch/pull/105679.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107042
Approved by: https://github.com/ezyang, https://github.com/zou3519
2023-08-15 20:58:56 +00:00
Michael Lazos
05eea20eb9 [dynamo] Simulate torch function enablement state (#105091)
Part of https://github.com/pytorch/pytorch/issues/93723

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105091
Approved by: https://github.com/voznesenskym, https://github.com/anijain2305
2023-07-13 17:42:20 +00:00
Michael Lazos
dfe7a1e089 [dynamo] Support wrapping + returning tensor subclasses (#104802)
as title - used for [tracing the FSDP collectives](d8cb80e382/torch/distributed/_functional_collectives.py (L425))

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104802
Approved by: https://github.com/jansel
2023-07-09 22:16:10 +00:00