Commit Graph

120 Commits

Author SHA1 Message Date
Tugsbayasgalan (Tugsuu) Manlaibaatar
ff32f6c93b Use freshly traced jit-traced module to be used in export analysis (#127577)
Summary: When we export already traced module, it seems to be modifying some global state causing the traced modules to fail to run. For now, we are only logging for test cases, so it is probs ok to trace fresh copy to be used in export for now.

Test Plan: CI

Differential Revision: D57983518

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127577
Approved by: https://github.com/pianpwk
2024-06-04 16:54:23 +00:00
Pian Pawakapan
8a31c2aa84 [export] allow complex guards as runtime asserts (#127129)
With the current state of export's dynamic shapes, we struggle with guards and constraints that are beyond the current dynamic shapes language, expressed with dims and derived dims. While we can compile and guarantee correctness for guards within the current language (e.g. min/max ranges, linear relationships, integer divisibility) we struggle to dynamically compile guards which extend beyond that.

For these "complex" guards, we typically do either of the following: 1) raise a constraint violation error, along the lines of "not all values of <symbol> in the specified range satisfy <guard>", with or without suggested fixes, 2) specialize to the provided static values and suggest removing dynamism, or 3) fail compilation due to some arbitrary unsupported case. Previous [work](https://github.com/pytorch/pytorch/pull/124949) went towards resolving this by disabling forced specializations, instead allowing the user to fail at runtime with incorrect inputs.

In this PR, relying on [hybrid backed-unbacked symints](https://github.com/pytorch/pytorch/issues/121749), [deferred runtime asserts](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/runtime_assert.py), and the function [_is_supported_equivalence()](d7de4c9d80/torch/fx/experimental/symbolic_shapes.py (L1824)), we add a flag `_allow_complex_guards_as_runtime_asserts` which allows the user to compile exported programs containing these guards and maintain dynamism, while adding correctness checks as runtime assertions in the graph.

Hybrid backed-unbacked symints allow us to easily bypass "implicit" guards emitted from computation - guards that we ~expect to be true. Popular examples revolve around reshapes:
```
# reshape
def forward(self, x, y):  # x: [s0, s1], y: [s2]
    return x.reshape([-1]) + y  # guard s0 * s1 = s2

This leads to the following exported program

class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0, s1]", y: "f32[s2]"):
        sym_size_int: "Sym(s2)" = torch.ops.aten.sym_size.int(y, 0)
        mul: "Sym(-s2)" = -1 * sym_size_int;  sym_size_int = None
        sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
        mul_1: "Sym(s0*s1)" = sym_size_int_1 * sym_size_int_2;  sym_size_int_1 = sym_size_int_2 = None
        add: "Sym(s0*s1 - s2)" = mul + mul_1;  mul = mul_1 = None
        eq: "Sym(Eq(s0*s1 - s2, 0))" = add == 0;  add = None
        _assert_scalar = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s0*s1 - s2, 0) on node 'eq'");  eq = None

        view: "f32[s0*s1]" = torch.ops.aten.view.default(x, [-1]);  x = None
        add_1: "f32[s0*s1]" = torch.ops.aten.add.Tensor(view, y);  view = y = None
        return (add_1,)
```
Another case is symbol divisibility:
```
def forward(self, x):  # x: [s0, s1]
    return x.reshape([-1, x.shape[0] - 1])  # Eq(Mod(s0 * s1, s0 - 1), 0)
```

Applying deferred runtime asserts also helps dynamic compilation for "explicit" complex guards that typically cause problems for export. For example we can generate runtime asserts for not-equal guards, and complex conditions like the following:
```
class Foo(torch.nn.Module):
    def forward(self, x, y):
        # check that negation of first guard also shows up as runtime assertion
        if x.shape[0] == y.shape[0]:  # False
            return x + y
        elif x.shape[0] == y.shape[0] ** 3:  # False
            return x + 2, y + 3
        elif x.shape[0] ** 2 == y.shape[0] * 3:  # True
            return x * 2.0, y * 3.0
```
For the above graph we will generate 3 runtime assertions: the negation of the first 2, and the 3rd condition as a guard.

One additional benefit here over the current state of exported programs is that this adds further correctness guarantees - previously with explicit complex guards, if compilation succeeded, the guards would be ignored at runtime, treated as given.

As shown above, the runtime asserts appear as math ops in the graph, generated by the sympy interpreter, resulting in an _assert_scalar call. There is an option to avoid adding these asserts into the graph, by setting `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1`. This results in the "original" computation graph, with dynamism, and any incorrect inputs will fail on ops during runtime. Further work could go into prettifying the printer, so the majority of the graph isn't guard-related.

Ideally this PR would subsume and remove the recently added [_disable_forced_specializations](https://github.com/pytorch/pytorch/pull/124949) flag, but that flag still handles one additional case of specialization: single-variable equalities where the symbol is solvable for a concrete value: see this [PR](https://github.com/pytorch/pytorch/pull/126925)

This PR doesn't change any behavior around data-dependent errors/unbacked symints yet, that could be further work.

NOTE: will take naming change suggestions for the flag :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127129
Approved by: https://github.com/avikchaudhuri
2024-05-29 17:15:25 +00:00
Pian Pawakapan
f206c5c628 [export] handle new roots & root swapping in derived dims suggested fixes (#125543)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125543

This PR address 2 issues with derived dim suggested fixes, 1) newly introduced roots, and 2) root swapping.

1 | Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final suggested fixes handle this correctly, but we can get intermediate results where related derived dims don't rely on a unified root, and are a mixture of min/max range and derived suggestions.

For example:
```
"dx": {"eq": 3*_dx-1, "max": 36}
"dy": {"eq": dx+1}
This should lead to suggested fixes
  _dx = Dim('_dx', max=12)
  dx = 3 * _dx - 1
  dy = 3 * _dx
```

This PR prettifies the suggested fixes routine by unifying to a single root, and making each intermediate suggestion either a derived dim or min/max range, not both.

2 | The current suggested fixes for derived dims can lead to root dims/derived dims being swapped, e.g. `dy - 1, dy` -> `dx, dx + 1`. This leads to problematic suggested fixes that look like `dy - 1 = Dim("dy - 1")` since we don't have access to the original variable name.

This PR only adds a suggested fix for the root dim, and removes all other derived suggestions.

For example, with the export test case test_derived_dim_out_of_order_simplified:
```
_dimz = torch.export.Dim("_dimz", min=6, max=8)
dimy = _dimz - 1
dimx = dimy - 1
dimz = torch.export.Dim("dimz", min=6, max=8)  # doesn't work, should be = _dimz

class Foo(torch.nn.Module):
    def forward(self, x, y, z):
        return x + y[1:] + z[2:]

foo = Foo()
u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
export(
    foo,
    (u, v, w),
    dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
)
```

Before:
```
Suggested fixes:
  _dimz = Dim('_dimz', min=3, max=9223372036854775807)  # 2 <= _dimz - 1 <= 9223372036854775806
  _dimz - 2 = Dim('_dimz - 2', min=4, max=6)
  _dimz = Dim('_dimz', min=2, max=9223372036854775806)  # 2 <= _dimz <= 9223372036854775806
  _dimz - 1 = _dimz - 1
  dimz = _dimz
```

New suggested fixes:
```
Suggested fixes:
  dimz = _dimz
```

Note: This assumes the specified derived relations between dims are correct. This should be valid because: 1) if the relation is plain wrong (e.g. (dx, dx - 1) provided with inputs (6, 4)), this gets caught in beforehand in produce_guards. 2) if the relation is correct but does not match the emitted guard, for example:
```
def forward(self, x, y):
    return x.reshape([-1]) + y  # guard: s0 * 2 = s1
dx = Dim("dx")
export(
    model,
    (torch.randn(6, 2), torch.randn(12)),
    dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )}
)
```
This produces two linear equations, leading to specialization since a) produce_guards is able to solve for a concrete value, and b) the export constraint solver will anyways force specializations due to range constraints.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125543
Approved by: https://github.com/avikchaudhuri
2024-05-28 20:41:43 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
9521528f71 Log export result of torch.jit.trace to scuba (#126900)
Summary: We want to track how well torch.jit.trace can be converted to export in large scale. As a first step, we log all of torch.jit.trace unittests whether we can convert the traced module to export module OR we can export the model directly

Test Plan: CI

Differential Revision: D57629682

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126900
Approved by: https://github.com/SherlockNoMad
2024-05-28 17:49:34 +00:00
Jiashen Cao
254783ce80 [Fix]: populate input parameter name when convert TorchScript to ExportedProgram (#126787)
## Goal
As title

## Design
Based on the fact that each TorchScript module has a `code` property which provides the original source code for the `forward` function, I implemented a function to extrapolate `forward` function signature by using the AST parser.

Some other tradeoff
* Directly parsing src code as string --> will be very buggy
* Directly using `compile` function in Python to get the function object --> raises a lot of exceptions because of missing packages or undefined variable names
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126787
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
2024-05-28 17:33:44 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
246311c944 Unconditionally add asserts after export (#127132)
Summary: Today AOTAutograd drops some of assert nodes so we reapply it after strict export.

Test Plan: CI

Reviewed By: angelayi

Differential Revision: D57786907

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127132
Approved by: https://github.com/zhxchen17
2024-05-28 06:31:39 +00:00
Sheng Fu
bbeb0906c4 Register creak_node_hook (#126671)
Differential Revision: D57469157

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126671
Approved by: https://github.com/angelayi
2024-05-24 23:32:15 +00:00
Jiashen Cao
041e8d73fd Separate non/strict functions in _export (#126718)
Move non/strict _export to different functions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126718
Approved by: https://github.com/angelayi
2024-05-23 17:41:23 +00:00
Edward Z. Yang
0d17aae242 Teach FakeTensor to fill in item_memo when converting scalar CPU tensor (#126245)
This PR requires a little justification, but let's start with what it does first:

1. When you have a 0d CPU scalar int64/float64 tensor input to a graph, we will preallocate a backed SymInt/SymFloat corresponding to what you would get if you call item() on this tensor. This means you can freely change your input to be a Python int/float or a Tensor with an item() call and end up with exactly the same level of expressivity (specifically, you can guard on the internal SymInt/SymFloat no matter what). By default, the source of the backed SymInt/SymFloat is `L['tensor'].item()`, but if you have promoted a float input into a Tensor, we will cancel out `torch.as_tensor(L['float']).item()` into just `L['float']`.
2. We switch wrap_symfloat to use this, instead of hand crafting the new SymNodeVariable. Everything works out, except that we carefully pass the item() result to tracked fakes (and not the fake Tensor argument)

OK, so why do this at all? There is some marginal benefit where now some item() calls on scalar inputs can be guarded on, but IMO this is a pretty marginal benefit, and if it was the only reason, I wouldn't do this. The real reason for this is that I need to be able to propagate fake tensors through the graphs that are produced by Dynamo, and if I am doing the old custom wrap_symfloat logic, there's no way I can do this, because ordinarily an item() call will cause an unbacked SymInt when I reallocate.

The other obvious way to solve the problem above is to make a HOP alternative that item() that "bakes in" the backed SymInt its supposed to return. But this strategy seems more parsimonious, and it does have the marginal benefit I mentioned above. The main downside is that what I have to do next, is make it so that when I run tensor computation, I also apply the equivalent operations to the SymInt/SymFloat as well. That's next PR.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126245
Approved by: https://github.com/eellison
ghstack dependencies: #126637
2024-05-22 15:25:38 +00:00
Jiashen Cao
99af1b3ab0 Refactor variables / function names related to non-strict export (#126458)
Improve variable and function naming for better clarity: `non strict` --> `aten`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126458
Approved by: https://github.com/angelayi
2024-05-18 06:05:14 +00:00
Tugsbayasgalan Manlaibaatar
bed1c600bb Experimental prototype for converting torch.jit.trace modules to export (#124449)
Differential Revision: [D56440613](https://our.internmc.facebook.com/intern/diff/D56440613)

We want to do this for following reasons:
1. There is current limitation in export tracing for torch.jit.trace d modules that cannot be easily upstreamed
2. We need to run internal CI regularly to understand feature gaps and continuously track them
3. Multiple people will be working on this prototype so it is better to have a checked in version so we don't always run into merge conflicts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124449
Approved by: https://github.com/angelayi, https://github.com/avikchaudhuri
2024-05-17 20:42:42 +00:00
Matthew Hoffman
81277baa0c Remove removed ruff rule TRY200 (#126256)
My TOML linter is complaining that "TRY200" is not acceptable for the `tool.ruff.lint` schema.

From the ruff docs: https://docs.astral.sh/ruff/rules/reraise-no-cause/

> This rule has been removed and its documentation is only available for historical reasons.
>
> This rule is identical to [B904](https://docs.astral.sh/ruff/rules/raise-without-from-inside-except/) which should be used instead.

and we are currently explicitly ignoring B904.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126256
Approved by: https://github.com/Skylion007
2024-05-17 16:31:05 +00:00
Wang, Eikan
08aa704d0c [1/N] Non-Tensor: Scalar Support: Enable aot compile to support aten operations with scalar input like alpha (#124177)
Some operations have a scalar input parameter, like `torch.add(a, b, alpha=2.0)`.  Currently, the aot compile does not support such a case because it requires the signature of the captured graph to align with the operation's signature. This means that some inputs in the captured graph may be scalar(float, int, bool, etc.). It breaks the assumption of `compile_fx_aot` as it assumes all the example inputs are tensor - 0f6ce45bcb/torch/_inductor/compile_fx.py (L1048)

This PR intends to support such cases by allowing not-aligned signature and filtering out the non-Tensor parameters.

Captured graph for `torch.add(a, b, alpha=2.0)`

```
opcode         name      target           args              kwargs
-------------  --------  ---------------  ----------------  --------------
placeholder    arg0_1    arg0_1           ()                {}
placeholder    arg1_1    arg1_1           ()                {}
call_function  add       aten.add.Tensor  (arg0_1, arg1_1)  {'alpha': 2.0}
output         output_1  output           ((add,),)         {}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124177
Approved by: https://github.com/jansel, https://github.com/desertfire, https://github.com/jgong5
2024-05-16 05:15:55 +00:00
Pian Pawakapan
e046c59e5b [export] handle aliased/unused params for unflattening (#125758)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125758

Aliased and unused params are currently an issue for strict-mode export. For a model like this:
```
def __init__(self):
    # ...
    self.alpha = nn.Parameter(torch.randn(4))
    self.beta = self.alpha
    self.gamma = self.alpha
def forward(self, x):
    return x + self.beta
```
Dynamo will trace only 1 parameter (beta) and assign a dynamo name (e.g. `L__self___beta`) which can be difficult to match to the correct FQN in the original eager module. This leads to export graph signature potentially having the incorrect target FQN for the parameter, leading to downstream issues unflattening (the parameter may be assigned to the wrong target attribute, mismatching the relevant placeholder node in the unflattened module).

This handles aliasing issues by assigning all tensors present in the state dict as module attributes, even if they're unused. Still, only the used tensors will appear in the graph's forward pass.

Another issue that exists is weight-sharing is not maintained in unflattening (all params/buffers are re-cloned) - handle this by checking tensor ids too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125758
Approved by: https://github.com/zhxchen17
2024-05-14 23:00:46 +00:00
Pian Pawakapan
c9a258e474 [export] handle constant aliasing for export (#125509)
Summary: Currently export will [error out](2b5ae2611e/torch/export/_trace.py (L477)) if a constant is aliased. This PR supports this by modifying ConstantAttrMap to map constants to a list of FQNs instead of a single FQN, populating the ExportedProgram constants dict to contain multiple entries to the same constant.

Test Plan: added test case in test_export.py

Differential Revision: D56955654

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125509
Approved by: https://github.com/angelayi, https://github.com/ydwu4
2024-05-10 00:14:37 +00:00
Tugsbayasgalan Manlaibaatar
0e419b9146 Fix graph partitioner and make runtime assertion work with submodules in export (#125793)
Summary: This fix does three things:

1. When we add inputs from partioner to the top level graph module, we insert in the order of partioner which is not guaranteed to be same as original graph inputs. This PR fixes that.
2. When we replace autograd ops with HOP, we create new submodules and access their outputs via getitem calls. As a result, previous node names associated with getitem gets updated, resulting in the graph being different from produced graph signature. So I just update the graph signature accordingly.
3. We run runtime_assertion pass before autograd HOP pass because the constraints won't be populated correctly.

Differential Revision: [D57130314](https://our.internmc.facebook.com/intern/diff/D57130314)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125793
Approved by: https://github.com/zhxchen17
2024-05-09 18:13:46 +00:00
Pian Pawakapan
f4b2d50fd7 [export] disable_forced_specializations (#124949)
Summary:
By default, some inferred dynamic shapes guards/constraints that are not expressible with the current dynamic shapes language will lead to specialization to the concrete input values provided. If disable_forced_specializations is set to True, we will not specialize, and will not perform runtime checks on such produced guards. Instead, we allow the user to specify arbitrary shapes, and fail during runtime if the inputs are invalid. Constraints expressible with the language (e.g. ranges, linear derived dims) will still be enforced, and behavior for all other guards remains the same.

Cases where we typically specialize are reshapes:
```
x: [4, 6]  # [s0, s1]
x = x.reshape([x.shape[0] - 1, -1])
# this emits a guard Mod(s0*s1, s0-1) = 0, we specialize on s0=4, s1=6

x: [4, 6], y: [24]  # [s0, s1], [s2]
x = x.reshape([-1]) + y
# this emits a guard s0*s1 = s2, we specialize on s0=4, s1=6, s2=24
```

For now only applicable for non-strict mode (need to figure out how to pass this flag into dynamo's call of produce_guards).

Test Plan: Added test case that checks compilation, runtime, and suggested fixes behavior.

Differential Revision: D56361177

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124949
Approved by: https://github.com/avikchaudhuri
2024-05-08 18:42:39 +00:00
angelayi
8be4c1bc2f [export] Add metadata for nodes insert_deferred_runtime_asserts (#125414)
Fixes [internal error](https://fb.workplace.com/groups/1075192433118967/permalink/1416709435633930/).

The issue is that the asserting nodes added in the `insert_deferred_runtime_assertion` pass do not contain metadata that the ExportedProgram requires the graph to have. One solution to fix this is to retrace the entire module, or another solution is to manually add back this metadata.

This diff implements the latter solution (manually add back the metadata) through hooking into fx.graph's `create_node` function, and adding export-specific metadata for every node that is created. The reason I did this is so that the `insert_deferred_runtime_assertion` does not have to know about what metadata export wants.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125414
Approved by: https://github.com/zhxchen17, https://github.com/BoyuanFeng
2024-05-07 23:15:21 +00:00
ydwu4
0302dc68bf [Reland] Fakify script object inputs and attributes for non-strict ex… (#125490)
A re-land of #124239.

This PR fakify ScriptObject inputs and attributes in export non-strict mode by default.

The basic idea is to only fakify the script object during tracing (i.e. aot_export). After we get the traced graph module, eagerly executing, serializing, or running more passes will use the real script objects. This is essentially treating the script object as constant tensor.

Concretely, we

fakify all the script object inputs, and module attributes (gathered by constant_attrs).
patch the module's attributes with fakified script object
right after aot_export, remove the patching (to avoid changing the original module) then modify the exported graph module's attribute to real script object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125490
Approved by: https://github.com/angelayi
2024-05-04 02:39:42 +00:00
Pian Pawakapan
ef757a5c00 [export] use tree_map for _flatten_dynamic_shapes (#125415)
Summary:
Fixing the implementation of `_flatten_dynamic_shapes()`, to follow how `_process_dynamic_shapes()` does it. The previous implementation would misinterpret some nested dynamic shapes specs, causing it to miss out on some shapes specs, for example with nested inputs/constant input tuples:

```
inputs = (
    (2, 1),
    (
        torch.randn(2, 1),
        torch.randn(2, 2),
        torch.randn(2, 3),
    )
)

dynamic_shapes = (
    (None, None),
    (
        None,
        None,
        None,
    )
)
```
This would get interpreted as 2 shapes specs for 2d and 3d tensors. Fixing so this doesn't happen.

Test Plan: Existing export tests

Differential Revision: D56894923

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125415
Approved by: https://github.com/angelayi
2024-05-03 04:59:17 +00:00
PyTorch MergeBot
f1f142c44f Revert "Fakify script object inputs and attributes for non-strict export (#124239)"
This reverts commit ecc2e034f7.

Reverted https://github.com/pytorch/pytorch/pull/124239 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/124239#issuecomment-2089305447))
2024-05-01 23:56:00 +00:00
Avik Chaudhuri
746da8755c switch tests from constrain_as* to torch._check* (#125253)
To fix data-dependent errors we want to recommend that people use `torch._check*` APIs. The `constrain_as*` APIs should be fully subsumed by them, and in the future we should kill them entirely.

Differential Revision: D56774333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125253
Approved by: https://github.com/ezyang
2024-05-01 21:01:27 +00:00
ydwu4
ecc2e034f7 Fakify script object inputs and attributes for non-strict export (#124239)
This PR fakify ScriptObject inputs and attributes in export non-strict mode by default.

The basic idea is to `only fakify the script object during tracing (i.e. aot_export)`. After we get the traced graph module, eagerly executing, serializing, or running more passes will use the real script objects. This is essentially treating the script object as constant tensor.

Concretely, we
1. fakify all the script object inputs, and module attributes (gathered by constant_attrs).
2. patch the module's attributes with fakified script object
3. right after aot_export, remove the patching (to avoid changing the original module) then modify the exported graph module's attribute to real script object.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124239
Approved by: https://github.com/zou3519
2024-04-30 15:57:25 +00:00
Avik Chaudhuri
e7846447e0 dynamic shapes builder API (#124898)
This PR introduces a new way of building `dynamic_shapes` for export. The idea is to build up a mapping from input tensors to the dynamic shapes that should be assigned to their corresponding fake tensors.

This mapping is automatically converted to the current form of `dynamic_shapes`, which must exactly match the structure of inputs. We do this by using pytree utils.

With the current `dynamic_shapes`, we had to be careful about user-defined classes that are registered with pytree, since  such classes are not necessarily polymorphic containers; they may be fine containing tensors, but not dynamic shapes. Thus we had decided to allow input instances of such classes to be associated with dynamic shapes in flattened form. This decision needs to be mirrored in this PR as well. To make it easier to keep these code paths in sync, we refactor the current recursive procedure for associating inputs with dynamic shapes to use the same pytree utils. This needs minor fixes to a few tests where `dynamic_shapes` were not exactly matching the structure of inputs.

Differential Revision: D56551992

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124898
Approved by: https://github.com/zhxchen17
2024-04-30 03:59:49 +00:00
Pian Pawakapan
946e202c07 [export] Restore user input names to unlifted graph modules (#124765)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/122842

Currently, calling ep.module() on an ExportedProgram leads to a GraphModule with a default forward signature (e.g. arg_0, arg_1, ...). This leads to original placeholder names disappearing for retracing/re-exporting.

Fixing this issue by creating a forward_arg_names field (will take renaming suggestions for this), that stores the positional & keyword arg names that are used. These names aren't present in the call_spec currently stored, and requires a major version bump for the ExportedProgram schema.

Test Plan: Tests exist for export, but names are now changed from generic (e.g. arg_0, arg_1) to follow user inputs (e.g. x, y)

Differential Revision: D56484994

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124765
Approved by: https://github.com/zhxchen17
2024-04-29 20:58:17 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
06b845dedc Make metadata serialization more strict (#124411)
Summary: When I was debugging an issue, this silent error makes the debugging harder. It is better to error earlier with more descriptive error message.

Test Plan: None

Differential Revision: D56312433

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124411
Approved by: https://github.com/zhxchen17
2024-04-29 02:11:40 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
cc06c00a56 Don't run auto grad safe mode when predispatch is on (#125066)
Summary: Title

Test Plan: CI

Differential Revision: D56646678

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125066
Approved by: https://github.com/zhxchen17
2024-04-29 01:53:23 +00:00
Zhengxu Chen
7bb89bcaa4 [export] Fix state dict reparametrization in non-strict. (#124847)
Summary:

There are multiple things implemented incorrectly in non strict for reparametrizing state dict:
1. The same fake tensor should be generated for duplicated weights.
2. We should snapshot state dict in the beginning to always hold the invariant that ep.state_dict == mod.state_dict()
3. We will overwrite real weights with fake weights if we don't restore the weights in LIFO ordering.
4. We don't turn on strict checking which could sliently fail on corner cases.

This diff aims to solve all these issues at once.

Test Plan: CI

Differential Revision: D56505020

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124847
Approved by: https://github.com/pianpwk
2024-04-25 22:44:16 +00:00
Pian Pawakapan
93a319a4fc [export] kill _process_constraints() (#123985)
The process for populating range_constraints follows separate methods for non-strict (`make_constraints`), and strict (`_process_constraints`). The strict method is somewhat more convoluted, and the analysis that Dynamo performs for strict is already present as part of the non-strict process in make_constraints (produce_guards(), running the export constraint solver).

This PR kills _process_constraints() and replaces calls with make_constraints, without duplicating the work that Dynamo already does.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123985
Approved by: https://github.com/avikchaudhuri
2024-04-25 16:58:57 +00:00
Zhengxu Chen
d40774f4ed [export] Fix up nn_module_stack for nodes occured around tracepoint ops. (#124457)
Summary: as title.

Test Plan:
hg checkout D55901896
buck run mode/opt torchrec/ir/tests:test_serializer -- --filter-regex test_serialize_deserialize_ebc

Differential Revision: D56340319

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124457
Approved by: https://github.com/tugsbayasgalan
2024-04-23 20:26:44 +00:00
Pian Pawakapan
e112792a69 [export] refactor _AddRuntimeAssertionsForInlineConstraintsPass (#124503)
Summary:
The current _AddRuntimeAssertionsForInlineConstraintsPass has 2 known issues caused by its use of torch.fx.Interpreter:
1. SymInt-related ops (e.g. item()) are executed, causing new Unbacked SymInts to appear in the graph during the pass.
2. The graph is reconstructed, and node names/indices can be different from before, causing mismatches with `module_call_graph`, and leading to issues during unflattening.

This refactors the pass to use PassBase instead of _ExportPassBaseDeprecatedDoNotUse, only constructing new nodes for assertions.

Test Plan: This pass is called on all strict-mode export calls with range_constraints, test that behavior remains unchanged.

Differential Revision: D56360137

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124503
Approved by: https://github.com/zhxchen17
2024-04-23 20:07:49 +00:00
Boyuan Feng
aa2da0cdd2 [Export] Add runtime assert to non-strict export (#123681)
This PR moves insert_deferred_runtime_asserts from dynamo to torch.fx.passes and uses it to add runtime assertion for non-strict export.

Differential Revision: D55944267

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123681
Approved by: https://github.com/tugsbayasgalan, https://github.com/angelayi
2024-04-18 16:13:27 +00:00
Pian Pawakapan
90d1720861 [export] Restore original placeholder names (part 3: constant input de/serialization) (#123590)
Summary:
note: breaking the original diff D55225818 into 3 parts (top-level renaming, higher-order-op subgraphs, constant input de/serialization) because of its size.

Stacked PR to restore original names to placeholder nodes, replacing the default names arg0_1, arg1_1, ...

This PR supports constant argument placeholder (e.g. forward(self, x, y=1)) names and de/serialization, by adding a name field for ConstantArguments in the graph signature, and ConstantInputSpec in the input specs for serialization.

Test Plan: verification checks on placeholder names for all export() calls, unit test in test/export/test_export.py

Differential Revision: D55506949

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123590
Approved by: https://github.com/angelayi, https://github.com/zhxchen17
2024-04-15 19:09:41 +00:00
Avik Chaudhuri
5961e23e76 primitive attribute assignment (#123898)
This PR ensures that assignment of attributes of primitive type work without needing any code changes in non-strict mode. (In a previous PR we banned attribute assignments of tensor type unless such attributes are registered as buffers.)

While strict mode errors on (all) attribute assignments, non-strict doesn't care, so one might assume that this kind of attribute assignment should already work in non-strict. However, there's a problem: we run through the program once for metadata collection and then run through it again for tracing, so the values observed during tracing (and potentially burned into the graph) do not reflect what should have been observed had the metadata collection pass not run.

So the only thing this PR needs to do is restore values of assigned attributes of primitive type once the metadata collection pass has run. We do this by moving the attribute assignment detecting context manager from the overall `aot_export` call in `_trace.py` to the metadata collection pass in `aot_autograd.py`, and extending it. The rest of the PR moves some utils around.

Differential Revision: D56047952

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123898
Approved by: https://github.com/angelayi
2024-04-13 05:27:52 +00:00
Pian Pawakapan
d0ccf599cc [export] Restore original placeholder names (part 2: higher-order-op subgraph naming) (#123587)
Summary:
note: breaking the original diff [D55225818](https://www.internalfb.com/diff/D55225818) into 3 parts (top-level renaming, higher-order-op subgraphs, constant input de/serialization) because of its size.

Stacked PR to restore original names to placeholder nodes, replacing the default names arg0_1, arg1_1, ...

This PR propagates node names to higher-order-op subgraph placeholders, retaining the top-level names and handling naming collisions by suffixing other non-placeholder nodes in the subgraph with an index. This is the same handling as in fx.Graph/fx.Node, but implemented separately as a pass.

Since the input schemas of HOO subgraphs are very different, they are enumerated in _name_hoo_subgraph_placeholders(). Currently cond, map_impl, and wrap_with_set_grad_enabled are handled, but other ops can be easily added.

Test Plan: verification checks on placeholder names for all export() calls, unit test in test/export/test_export.py

Differential Revision: D55456749

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123587
Approved by: https://github.com/angelayi
2024-04-11 22:40:46 +00:00
Avik Chaudhuri
10a03c56e5 fix leaky fake tensor on attribute assignment, support buffer assignment (#122337)
In non-strict, assignment of attributes in a model causes their state to contain fake tensors post-tracing, which leads to incorrect results on running the exported model. We now error when this happens, asking the user to use buffers instead.
Next, we add support for assignment of buffers. The final values of the buffers turn into outputs of the graph. Since the buffers are already lifted as inputs and populated with the initial values when the model is run, this leads to a simple programming model where the driver of the model can feed the outputs back as inputs for successive runs.

Differential Revision: D55146852

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122337
Approved by: https://github.com/bdhirsh, https://github.com/tugsbayasgalan
2024-04-11 18:08:31 +00:00
Pian Pawakapan
42c2a5477c [export] nn_module_stack to return class name str (#123308)
Previously, `node.meta["nn_module_stack"]` had type `Dict[str, Tuple[str, class]]` when exported, and later `Dict[str, Tuple[str, str]]` after de/serialization. This PR changes it to consistently be `Dict[str, Tuple[str, str]]` for round-trippability, i.e.
```
{..., 'L__self___conv': ('conv', 'torch.nn.modules.conv.Conv2d')}
```

`source_fn_stack` is left untouched in this PR.

note: the `Union[type, str]` type annotations in ONNX are because ONNX goes through both `export.export()` and `_dynamo.export()` (which still has the original `Dict[str, Tuple[str, class]]` format). nn_module_stack from `export.export()` should consistently have the new format, and we verify/test for that in `_trace.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123308
Approved by: https://github.com/zhxchen17, https://github.com/thiagocrepaldi
2024-04-05 21:48:22 +00:00
Pian Pawakapan
d7f23f6826 [export] Restore original placeholder names (part 1: top-level renaming) (#122904)
Summary:
This PR restores original names to placeholder nodes, replacing the default names arg0_1, arg1_1, and so on.

User inputs now follow the signature of mod.forward(), for example forward(x, y) produces nodes x, y. If the tensors are nested in dictionaries, lists, tuples, or dataclasses, the names are a concatenation of the path to the tensor, e.g. x = {'a': torch.randn(4), 'b': [torch.randn(4), torch.randn(4)]} produces nodes x_a, x_b_0, x_b_1.

Parameters, buffers, constants, and custom objects follow the FQN of the object, prefixed by "p", "b", "c", and "obj" respectively. For example, self.bar.l0.weight gets you p_bar_l0_weight.
Effect tokens are named token_1, token_2, and so on, since they are not grounded in model inputs or named attributes.

note: breaking the original diff into 3 parts (top-level renaming, higher-order-op subgraphs, constant input de/serialization) because of its size.

Examples:
```python
# params, buffers, constants, inputs, torch.cond

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_l0_weight: "f32[4, 4]", p_l0_bias: "f32[4]", c_alpha: "f32[4]", b_beta: "f32[4]", x_0_a: "f32[4, 4]", y: "f32[4, 4]"):
            # No stacktrace found for following nodes
            mul: "f32[4, 4]" = torch.ops.aten.mul.Tensor(x_0_a, x_0_a)
            t: "f32[4, 4]" = torch.ops.aten.t.default(p_l0_weight);  p_l0_weight = None
            addmm: "f32[4, 4]" = torch.ops.aten.addmm.default(p_l0_bias, y, t);  p_l0_bias = y = t = None
            return addmm

# model code

class Bar(torch.nn.Module):
    def forward(self, x):
        return x * x
class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bar = Bar()
        self.l0 = torch.nn.Linear(4, 4)
        self.alpha = torch.randn(4)
        self.register_buffer('beta', torch.randn(4))
    def forward(self, x, y):
        x = x[0]['a']
        mul = self.bar(x)
        z1 = self.l0(y)
        return z1

# custom objects, dataclasses, tokens, constant inputs

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, token_1: "f32[0]", obj_attr, data_x: "f32[4, 4]", data_y: "f32[4, 4]", mode):
            # No stacktrace found for following nodes
            mul: "f32[4, 4]" = torch.ops.aten.mul.Scalar(data_x, 30);  data_x = None
            div: "f32[4, 4]" = torch.ops.aten.div.Tensor_mode(data_y, 1.0, rounding_mode = 'floor');  data_y = None
            add: "f32[4, 4]" = torch.ops.aten.add.Tensor(mul, div);  mul = div = None
            with_effects = torch._higher_order_ops.effects.with_effects(token_1, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add);  token_1 = obj_attr = add = None
            getitem: "f32[0]" = with_effects[0]
            getitem_1: "f32[4, 4]" = with_effects[1];  with_effects = None
            return (getitem, getitem_1)

# model code

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
    def forward(self, data, a=1.0, mode="floor"):
        x = self.attr.add_tensor(data.x) + torch.div(data.y, a, rounding_mode=mode)
        x = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
        return x

dataclass
class DataClass:
    x: Tensor
    y: Tensor
register_dataclass_as_pytree_node(
    DataClass,
    serialized_type_name="test.DataClass"
)

args = (DataClass(x=torch.randn(4, 4), y=torch.randn(4, 4)), )
kwargs = {'mode': 'floor'}
ep = torch.export.export(Foo(), args, kwargs, strict=False)

```

Test Plan: verification checks on placeholder names for all export() calls, unit test in test/export/test_export.py

Differential Revision: D55456418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122904
Approved by: https://github.com/angelayi, https://github.com/thiagocrepaldi
2024-04-05 18:56:00 +00:00
Pian Pawakapan
4b1b4db231 [export] Add stack_trace for non-strict export (#121034)
This addresses 2 issues with stack_trace metadata:
- stack_trace is currently missing from nodes in non-strict export
- in strict mode, stack_trace is populated for placeholder nodes, which may not be well-defined (with multiple uses)

We filter the call stack during tracing for calls from forward() methods, or ops in `torch.__init__.py` (e.g. sym_size_int, sym_constrain_range, etc.) to populate stack_trace. A node-level check is also added to _export_non_strict().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121034
Approved by: https://github.com/angelayi
2024-04-04 22:35:33 +00:00
Boyuan Feng
64d743044d Add inline constraints to non-strict exported program (#123017)
Summary: This PR reduces the difference between strict and non-strict exported program by supporting inline_constraints for non-strict exported program,

Test Plan: CI

Differential Revision: D55547830

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123017
Approved by: https://github.com/angelayi
2024-04-02 18:16:16 +00:00
PyTorch MergeBot
3beb9d85a6 Revert "Add non strict inline constraints and runtime assertions to non-strict exported program (#122722)"
This reverts commit b693fff5d7.

Reverted https://github.com/pytorch/pytorch/pull/122722 on behalf of https://github.com/BoyuanFeng due to This breaks torchrec.distributed.tests.test_pt2.TestPt2: test_kjt__getitem__ ([comment](https://github.com/pytorch/pytorch/pull/122722#issuecomment-2026078351))
2024-03-28 20:42:35 +00:00
Boyuan Feng
b693fff5d7 Add non strict inline constraints and runtime assertions to non-strict exported program (#122722)
This PR reduces the difference between strict and non-strict exported program by

- Support `inline_constraints` for non-strict exported program
- Add runtime assertions for range constraints to non-strict exported program

After this PR, the following unit tests are no longer `expectedFailureNonStrict`:
- test_automatic_constrain_size
- test_export_with_inline_constraints
- test_redundant_asserts
- test_constrain_size_with_constrain_value
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122722
Approved by: https://github.com/pianpwk
2024-03-27 21:20:03 +00:00
Zhengxu Chen
0465a90b00 [export][reland] Fix unflattened submodule ordering. (#122341) (#122507)
Summary:

Make sure the order of submodules is the same as the original eager module.

bypass-github-export-checks

Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_unflatten_submodule_ordering

Differential Revision: D55251277

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122507
Approved by: https://github.com/tugsbayasgalan
2024-03-25 15:22:01 +00:00
Pian Pawakapan
3f99306452 [export] Remove from_export flag (#122500)
Summary: The flag from_export was incorrectly included in a previous diff (https://www.internalfb.com/diff/D54314379) - it was intended for helping with ExportedProgram verification, but was no longer needed in the final implementation.

Test Plan: Changes no functionality, test/export already covers everything

Differential Revision: D55205857

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122500
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
2024-03-22 22:55:14 +00:00
Zhengxu Chen
d1e8b97387 [export] Log module hierarchy. (#121970)
Summary:
We can also log the module hierarchy in the following format:
```
:ToplevelModule
sparse:SparshArch
dense:DenseArch
```
So that we can have more information recorded about model's identity.

Test Plan: CI

Differential Revision: D54921097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121970
Approved by: https://github.com/angelayi
2024-03-20 18:59:42 +00:00
Pian Pawakapan
7832efb242 [export] skip nn_module_stack verifier for non-fx.GraphModule modules (#122210)
Downstream users of torch.export may have different module classes (e.g. LoweredBackendModule), which cannot be checked for metadata in the same way. Add lines to skip this for non-fx.GraphModule modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122210
Approved by: https://github.com/angelayi, https://github.com/zhxchen17
2024-03-20 07:40:48 +00:00
Pian Pawakapan
c5ffebebab [export] allow Dim(1,2) for export dynamic shapes (v2 after revert) (#121910)
Creating this after [PR](https://github.com/pytorch/pytorch/pull/121642) got reverted.

Current dynamic shapes implementation fixes lower range of Dims to be 2 for analysis, but allows 0/1 shapes during runtime. This leads to failures when initializing Dim(1,2). This PR sets the lower bound to 0, and avoids erroring out when conflicting with the generated (2, maxsize) constraint during analysis.

Also resolves a derived dim constraints issue with the following code:
```
class Bar(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

dx = Dim("dx", min=1, max=3)
ep = export(
    Bar(),
    (torch.randn(2, 2), torch.randn(3, 2)),
    dynamic_shapes=({0: dx, 1: None}, {0: dx+1, 1: None})
)
print(ep.range_constraints)
```

In main:
```
{s0: ValueRanges(lower=2, upper=3, is_bool=False), s0 + 1: ValueRanges(lower=3, upper=4, is_bool=False)}
```

This PR:
```
{s0: ValueRanges(lower=1, upper=3, is_bool=False), s0 + 1: ValueRanges(lower=2, upper=4, is_bool=False)}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121910
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
2024-03-19 19:08:05 +00:00
Pian Pawakapan
3bd38928ba [export] Improve consistency for nn_module_stack metadata, add checks to _trace.py (#120661)
We would like to improve consistency for nn_module_stack metadata in torch.export.

This PR ensures that all tests in test/export/test_export.py has the following constraints:
- Remove nn_module_stack for all placeholder & output nodes, for all modules and submodules
- Ensure nn_module_stack is present for all other node types for the top-level module (there is still an issue with torch.cond submodules having empty fields)
- Add these checks to _export() in _trace.py (we would add this in the Verifier, but downstream apps construct ExportedPrograms separate from _export(), and metadata may not be maintained there)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120661
Approved by: https://github.com/avikchaudhuri
2024-03-16 21:44:52 +00:00
angelayi
ef25d83a62 [export] Add serialization support for tokens (#121552)
Differential Revision: [D54906766](https://our.internmc.facebook.com/intern/diff/D54906766)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121552
Approved by: https://github.com/zhxchen17
2024-03-15 16:15:11 +00:00
PyTorch MergeBot
bf7ac4ddf7 Revert "[export] allow Dim(1,2) for export dynamic shapes (#121642)"
This reverts commit a8dcbf2749.

Reverted https://github.com/pytorch/pytorch/pull/121642 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/121642#issuecomment-1996121710))
2024-03-13 23:51:20 +00:00