Commit Graph

142 Commits

Author SHA1 Message Date
Jiashen Cao
10d2373abd Add a registry for GraphModuleSerializer (#126550)
This PR adds a registration function and a global registry for GraphModuleSerializer. After this PR, custom serialization methods can be done through registration instead of subclassing for ease of maintenance.

## Changes
- Add a test case where it injects custom op to test serialization.
- Add custom op handler
- Change allowed op for verifier
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126550
Approved by: https://github.com/zhxchen17
2024-05-29 03:12:48 +00:00
Angela Yi
cb6ef68caa Propagate tokens in aotautograd (#127028)
Test Plan: `buck run mode/dev-nosan //aimp/experimental/pt2:pt2_export -- --model-entity-id 938593492 --output /tmp/938593492.zip --use-torchrec-eager-mp --use-manifold`

Differential Revision: D57750072

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127028
Approved by: https://github.com/tugsbayasgalan
2024-05-24 03:23:17 +00:00
Jiashen Cao
ac1f0befcf Remove redundant serialization code (#126803)
After https://github.com/pytorch/pytorch/pull/123308, we no longer need separate serialization path to handle different types that exist in the nn_module metadata. This PR cleans up the redundant code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126803
Approved by: https://github.com/angelayi
2024-05-22 03:14:17 +00:00
PyTorch MergeBot
f89500030b Revert "Remove redundant serialization code (#126249)"
This reverts commit aab448e381.

Reverted https://github.com/pytorch/pytorch/pull/126249 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing sigmoid/frontend:serialization_test internally ([comment](https://github.com/pytorch/pytorch/pull/126249#issuecomment-2118233656))
2024-05-17 19:19:02 +00:00
Jiashen Cao
aab448e381 Remove redundant serialization code (#126249)
After https://github.com/pytorch/pytorch/pull/123308, we no longer need separate serialization path to handle different types that exist in the `nn_module` metadata. This PR cleans up the redundant code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126249
Approved by: https://github.com/angelayi
2024-05-16 19:22:20 +00:00
Zhengxu Chen
3ccf107f01 [export] remove upgrader. (#125625)
Summary: talked to executorch team, seems we can remove this now.

Test Plan: CI

Differential Revision: D57013451

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125625
Approved by: https://github.com/larryliu0820
2024-05-09 16:30:12 +00:00
angelayi
0de9ce9bb3 [export] Fix serialization of empty torch artifact (#125542)
A previous PR added support for serializing/deserializing example inputs, but this fails when `example_inputs` is none.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125542
Approved by: https://github.com/pianpwk, https://github.com/BoyuanFeng, https://github.com/ydwu4
2024-05-07 15:54:45 +00:00
Zhengxu Chen
12a69afa6d [export] Fix deserializer node meta handling. (#125454)
Summary: The code seems not needed because serializer shouldn't make any meaningful decision about what goes to node metadata.

Test Plan: CI

Differential Revision: D56918543

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125454
Approved by: https://github.com/angelayi
2024-05-03 16:51:08 +00:00
Aaron Gokaslan
3e1fb96964 [BE]: RUF018 - ban assignment in assert (#125125)
Ban assignment inside of assert. Python code should ideally not break with assertions disabled. Adds a ruff lint rule to enforce this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125125
Approved by: https://github.com/ezyang
2024-04-28 21:41:36 +00:00
Edward Z. Yang
7aa6bd7fa0 Refactor all top level usages of record_shapeenv_event to ShapeEnv class (#123735)
This ensures that first argument to record_shapeenv_event is a ShapeEnv
so we can appropriately short circuit when recording is not in progress.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123735
Approved by: https://github.com/ysiraichi, https://github.com/zou3519, https://github.com/albanD
2024-04-27 20:36:40 +00:00
Aaron Gokaslan
2f3b0befed [BE]: Apply ruff FURB 118. (#124743)
Replaces various lambdas with operator.itemgetter which is more efficient (as it's a builtin function). Particularly useful for when lambdas are used as 'key' functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124743
Approved by: https://github.com/albanD, https://github.com/malfet
2024-04-26 14:34:52 +00:00
PyTorch MergeBot
e607dc8abb Revert "Refactor all top level usages of record_shapeenv_event to ShapeEnv class (#123735)"
This reverts commit 87bec7db4e.

Reverted https://github.com/pytorch/pytorch/pull/123735 on behalf of https://github.com/jeanschmidt due to Breaking internal signals, more info in D56587358 ([comment](https://github.com/pytorch/pytorch/pull/123735#issuecomment-2078695590))
2024-04-26 06:10:58 +00:00
angelayi
724f8dd8c5 [export] Serialize empty list based on argument type (#123748)
Fixes https://github.com/pytorch/pytorch/issues/123480

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123748
Approved by: https://github.com/zhxchen17
2024-04-25 23:03:27 +00:00
angelayi
84fb96130f [export] Fix check for optional tensor returns (#123739)
Sorry for the delay! Addressing issue in https://www.internalfb.com/diff/D55455000?dst_version_fbid=1599488570890576&transaction_fbid=776042617791884
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123739
Approved by: https://github.com/zhxchen17
2024-04-25 20:51:26 +00:00
Edward Z. Yang
87bec7db4e Refactor all top level usages of record_shapeenv_event to ShapeEnv class (#123735)
This ensures that first argument to record_shapeenv_event is a ShapeEnv
so we can appropriately short circuit when recording is not in progress.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123735
Approved by: https://github.com/ysiraichi, https://github.com/zou3519, https://github.com/albanD
ghstack dependencies: #124310, #124314, #124316, #124394, #124739, #124782, #124785
2024-04-25 14:02:48 +00:00
Pian Pawakapan
10b9d4d19c [export] handle Dim.lower = 0, 1 for ep.run_decompositions() (#123602)
Summary:
With pre-dispatch export and ep.run_decompositions(), range constraints are updated through looking at ShapeEnv.var_to_range. However the lower bounds on these may be incorrect - analysis on un-specialized symbols are done with lower bounds of 2, which mismatch with user-specified bounds (may be 0, 1).

This updates `_get_updated_range_constraints()` to use the old range constraints if possible.

Test Plan: Existing pre-dispatch/dynamic shapes test case.

Differential Revision: D55899872

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123602
Approved by: https://github.com/tugsbayasgalan
2024-04-19 21:29:36 +00:00
angelayi
74bedbb9e1 [export] Serialize rational symint ranges (#123884)
Some symints result in rational ranges like 10/3 which runs into an error ([example](https://www.internalfb.com/intern/everpaste/?handle=GMG2AxkeoFUrh-UDAFcE8pKPgjoUbsIXAAAB)).

Ed will eventually get rid(?) of these rational ranges but as a workaround export can just clamp the results during serialization time
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123884
Approved by: https://github.com/zhxchen17
2024-04-18 18:20:11 +00:00
Pian Pawakapan
90d1720861 [export] Restore original placeholder names (part 3: constant input de/serialization) (#123590)
Summary:
note: breaking the original diff D55225818 into 3 parts (top-level renaming, higher-order-op subgraphs, constant input de/serialization) because of its size.

Stacked PR to restore original names to placeholder nodes, replacing the default names arg0_1, arg1_1, ...

This PR supports constant argument placeholder (e.g. forward(self, x, y=1)) names and de/serialization, by adding a name field for ConstantArguments in the graph signature, and ConstantInputSpec in the input specs for serialization.

Test Plan: verification checks on placeholder names for all export() calls, unit test in test/export/test_export.py

Differential Revision: D55506949

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123590
Approved by: https://github.com/angelayi, https://github.com/zhxchen17
2024-04-15 19:09:41 +00:00
Zhengxu Chen
951582949b [export] Enforce final classes in serialization. (#123861)
Summary: as title, these are private API and not meant to be used across repos.

Test Plan: CI

Differential Revision: D56027954

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123861
Approved by: https://github.com/tugsbayasgalan
2024-04-12 15:44:56 +00:00
Pian Pawakapan
42c2a5477c [export] nn_module_stack to return class name str (#123308)
Previously, `node.meta["nn_module_stack"]` had type `Dict[str, Tuple[str, class]]` when exported, and later `Dict[str, Tuple[str, str]]` after de/serialization. This PR changes it to consistently be `Dict[str, Tuple[str, str]]` for round-trippability, i.e.
```
{..., 'L__self___conv': ('conv', 'torch.nn.modules.conv.Conv2d')}
```

`source_fn_stack` is left untouched in this PR.

note: the `Union[type, str]` type annotations in ONNX are because ONNX goes through both `export.export()` and `_dynamo.export()` (which still has the original `Dict[str, Tuple[str, class]]` format). nn_module_stack from `export.export()` should consistently have the new format, and we verify/test for that in `_trace.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123308
Approved by: https://github.com/zhxchen17, https://github.com/thiagocrepaldi
2024-04-05 21:48:22 +00:00
Pian Pawakapan
d7f23f6826 [export] Restore original placeholder names (part 1: top-level renaming) (#122904)
Summary:
This PR restores original names to placeholder nodes, replacing the default names arg0_1, arg1_1, and so on.

User inputs now follow the signature of mod.forward(), for example forward(x, y) produces nodes x, y. If the tensors are nested in dictionaries, lists, tuples, or dataclasses, the names are a concatenation of the path to the tensor, e.g. x = {'a': torch.randn(4), 'b': [torch.randn(4), torch.randn(4)]} produces nodes x_a, x_b_0, x_b_1.

Parameters, buffers, constants, and custom objects follow the FQN of the object, prefixed by "p", "b", "c", and "obj" respectively. For example, self.bar.l0.weight gets you p_bar_l0_weight.
Effect tokens are named token_1, token_2, and so on, since they are not grounded in model inputs or named attributes.

note: breaking the original diff into 3 parts (top-level renaming, higher-order-op subgraphs, constant input de/serialization) because of its size.

Examples:
```python
# params, buffers, constants, inputs, torch.cond

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_l0_weight: "f32[4, 4]", p_l0_bias: "f32[4]", c_alpha: "f32[4]", b_beta: "f32[4]", x_0_a: "f32[4, 4]", y: "f32[4, 4]"):
            # No stacktrace found for following nodes
            mul: "f32[4, 4]" = torch.ops.aten.mul.Tensor(x_0_a, x_0_a)
            t: "f32[4, 4]" = torch.ops.aten.t.default(p_l0_weight);  p_l0_weight = None
            addmm: "f32[4, 4]" = torch.ops.aten.addmm.default(p_l0_bias, y, t);  p_l0_bias = y = t = None
            return addmm

# model code

class Bar(torch.nn.Module):
    def forward(self, x):
        return x * x
class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bar = Bar()
        self.l0 = torch.nn.Linear(4, 4)
        self.alpha = torch.randn(4)
        self.register_buffer('beta', torch.randn(4))
    def forward(self, x, y):
        x = x[0]['a']
        mul = self.bar(x)
        z1 = self.l0(y)
        return z1

# custom objects, dataclasses, tokens, constant inputs

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, token_1: "f32[0]", obj_attr, data_x: "f32[4, 4]", data_y: "f32[4, 4]", mode):
            # No stacktrace found for following nodes
            mul: "f32[4, 4]" = torch.ops.aten.mul.Scalar(data_x, 30);  data_x = None
            div: "f32[4, 4]" = torch.ops.aten.div.Tensor_mode(data_y, 1.0, rounding_mode = 'floor');  data_y = None
            add: "f32[4, 4]" = torch.ops.aten.add.Tensor(mul, div);  mul = div = None
            with_effects = torch._higher_order_ops.effects.with_effects(token_1, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add);  token_1 = obj_attr = add = None
            getitem: "f32[0]" = with_effects[0]
            getitem_1: "f32[4, 4]" = with_effects[1];  with_effects = None
            return (getitem, getitem_1)

# model code

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
    def forward(self, data, a=1.0, mode="floor"):
        x = self.attr.add_tensor(data.x) + torch.div(data.y, a, rounding_mode=mode)
        x = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
        return x

dataclass
class DataClass:
    x: Tensor
    y: Tensor
register_dataclass_as_pytree_node(
    DataClass,
    serialized_type_name="test.DataClass"
)

args = (DataClass(x=torch.randn(4, 4), y=torch.randn(4, 4)), )
kwargs = {'mode': 'floor'}
ep = torch.export.export(Foo(), args, kwargs, strict=False)

```

Test Plan: verification checks on placeholder names for all export() calls, unit test in test/export/test_export.py

Differential Revision: D55456418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122904
Approved by: https://github.com/angelayi, https://github.com/thiagocrepaldi
2024-04-05 18:56:00 +00:00
Josh Fromm
0c8a165b43 [Export] Improve metadata and output parsing during deserialization (#122793)
Summary:
Deserialization of metadata could encounter a bug where commas are used in valid metadata names. This specifically occurs when a split of a `torch.nn.Sequential` stack is used, but may have other possible triggers. Because the deserialization relies on a comma based string split, such names trigger an error. This change uses a simple regular expression to ignore commas within parentheses to avoid the issue.

I add a test that constructs one such problematic sequential stack and show that it can be properly round-tripped with the improved splitting.

Similarly, deserialization could fail when outputs are not a tensor type. Although such outputs like None or constants are not very useful, they do show up in graphs and export should be able to support them. This change improves output node parsing and adds a corresponding test.

Test Plan: buck test //caffe2/test:test_export -- TestSerialize

Differential Revision: D55391674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122793
Approved by: https://github.com/zhxchen17
2024-04-05 00:25:37 +00:00
angelayi
ed457c7dbe [export] Add torch_fn (#122693)
This PR adds a new metadata, `torch_fn` which is meant to replace `source_fn_stack` as `source_fn_stack` is not entirely well defined between strict/nonstrict. Previous discussion [here](https://docs.google.com/document/d/1sPmmsmh6rZFWH03QBOe49MaXrQkP8SxoG8AOMb-pFk4/edit#heading=h.anmx9qknhvm).

`torch_fn` represents the torch function that a particular aten operator came from. For example, `torch.nn.Linear` goes down to the `torch.nn.functional.linear` at the `__torch_function__` layer, and then `aten.t/aten.addmm` in the `__torch_dispatch__` layer. So the nodes `aten.t/aten.addmm` will now have the `torch_fn` metadata containing the `torch.nn.functional.linear`.

The `torch_fn` metadata is a tuple of 2 strings: a unique identifier for each torch function call, and the actual torch function `f"{fn.__class__}.{fn.__name__}"`. The purpose of the first value is to distinguish between 2 consecutive calls to the same function. For example, if we had 2 calls to `torch.nn.Linear`, the nodes and corresponding metadata would look something like:
```
aten.t - ("linear_1", "builtin_function_or_method.linear"),
aten.addmm - ("linear_1", "builtin_function_or_method.linear"),
aten.t - ("linear_2", "builtin_function_or_method.linear"),
aten.addmm - ("linear_2", "builtin_function_or_method.linear"),
```

Higher order ops -- currently we can get the torch_fn metadata for nodes within the HOO's subgraph, but after retracing, this becomes the `(cond, higher_order_op.cond)` :( This is because `fx_traceback.set_current_meta` points to the cond node in the toplevel graph, rather than the original node in the subgraph. I think this is because `fx.Interpreter` does not go into the cond subgraphs. (will discuss with Yidi more ab this)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122693
Approved by: https://github.com/tugsbayasgalan
2024-03-30 06:47:15 +00:00
Josh Fromm
0c47f8028e Keep example_inputs when saving and loading ExportedProgram (#122618)
Summary:
`torch.export` is a powerful tool for creating a structured and shareable package from arbitrary pytorch code. One great use case of `torch.export` is sharing models or subgraphs in a way that allows results to be easily replicated. However, in the current implementation of `export`, the `example_inputs` field is thrown out. When trying to replicate bugs, benchmarks, or behaviors, losing the original input shapes and values makes the process much messier.

This change adds saving and loading for the `example_inputs` attribute of an `ExportedProgram` when using `torch.export.save` and `torch.export.load`. This simple addition makes `ExportedPrograms`s a fantastic tool for performance and accuracy replication. For example, with this change we enable the following workflow:

```
# Script to create a reproducible accuracy issue with my model.
kwargs = {"fastmath_mode": True}
exp_program = export(my_model, sample_inputs, kwargs)
result = exp_program.module()(*sample_inputs, **kwargs)
# Uhoh, I dont like that result, lets send the module to a colleague to take a look.
torch.export.save(exp_program, "my_model.pt2")
```

My colleague can then easily reproduce my results llike so:

```
# Script to load and reproduce results from a saved ExportedProgram.
loaded_program = torch.export.load("my_model.pt2")
# The following line is enabled by this Diff, we pull out the arguments
# and options that caused the issue.
args, kwargs = loaded_program.example_inputs
reproduced_result = loaded_program.module()(*args, **kwargs)
# Oh I see what happened here, lets fix it.
```

Being able to share exact inputs and arguments makes `ExportedPrograms` much
more clean and powerful with little downside. The main potential issue with this change
is that it does slightly increase the size of saved programs. However, the size of
inputs will be much smaller than parameters in most cases. I am curious to hear
discussion on saved file size though.

The deserialization of `example_inputs` is currently implemented as `Optional`. Although this wont effect users of `export.save` and `export.load`, it does give backwards compatibility to any direct users of `serialize` and `deserialize`.

Test Plan:
This diff includes a new test which exercises the save / load flow with multiple args and kwargs.

```
buck test //caffe2/test:test_export -- TestSerialize
```

Differential Revision: D55294614

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122618
Approved by: https://github.com/zhxchen17
2024-03-26 03:32:44 +00:00
Pian Pawakapan
3f99306452 [export] Remove from_export flag (#122500)
Summary: The flag from_export was incorrectly included in a previous diff (https://www.internalfb.com/diff/D54314379) - it was intended for helping with ExportedProgram verification, but was no longer needed in the final implementation.

Test Plan: Changes no functionality, test/export already covers everything

Differential Revision: D55205857

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122500
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
2024-03-22 22:55:14 +00:00
Sherlock Huang
ae913175c3 Fix GraphModuleDeserializer (#122342)
Summary: self.constants is used in self.deserialize_signature()

Test Plan: CI

Differential Revision: D55152971

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122342
Approved by: https://github.com/zhxchen17
2024-03-21 02:27:39 +00:00
Zhengxu Chen
f8565c4a28 [sigmoid] Clean up serialization API. (#122102)
Summary: Entirely remove the old serializer code to avoid further confusion and code bloat.

Test Plan: CI

Reviewed By: SherlockNoMad

Differential Revision: D54857118

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122102
Approved by: https://github.com/tugsbayasgalan
2024-03-20 03:45:36 +00:00
Pian Pawakapan
c5ffebebab [export] allow Dim(1,2) for export dynamic shapes (v2 after revert) (#121910)
Creating this after [PR](https://github.com/pytorch/pytorch/pull/121642) got reverted.

Current dynamic shapes implementation fixes lower range of Dims to be 2 for analysis, but allows 0/1 shapes during runtime. This leads to failures when initializing Dim(1,2). This PR sets the lower bound to 0, and avoids erroring out when conflicting with the generated (2, maxsize) constraint during analysis.

Also resolves a derived dim constraints issue with the following code:
```
class Bar(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

dx = Dim("dx", min=1, max=3)
ep = export(
    Bar(),
    (torch.randn(2, 2), torch.randn(3, 2)),
    dynamic_shapes=({0: dx, 1: None}, {0: dx+1, 1: None})
)
print(ep.range_constraints)
```

In main:
```
{s0: ValueRanges(lower=2, upper=3, is_bool=False), s0 + 1: ValueRanges(lower=3, upper=4, is_bool=False)}
```

This PR:
```
{s0: ValueRanges(lower=1, upper=3, is_bool=False), s0 + 1: ValueRanges(lower=2, upper=4, is_bool=False)}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121910
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
2024-03-19 19:08:05 +00:00
PyTorch MergeBot
d56ab7b020 Revert "[torch export][serialize] create a more compact stacktrace format for serialization (#121675)"
This reverts commit eae89138d8.

Reverted https://github.com/pytorch/pytorch/pull/121675 on behalf of https://github.com/jeanschmidt due to It seems that this PR broke lint jobs, I am reverting to confirm if this is the case ([comment](https://github.com/pytorch/pytorch/pull/121675#issuecomment-2007919486))
2024-03-19 19:02:09 +00:00
Wenting Wang
eae89138d8 [torch export][serialize] create a more compact stacktrace format for serialization (#121675)
Summary:
- we want fx nodes' stack trace format to be backward compatible and same as before in the program we export
- however in the serialized format, we would want to show a more compact stack_trace format, otherwise the nodes attributes are dominated by stack traces
- the diff implements the minimal in serialization process to dedupe node stack traces by resorting to a fileinfo_list and a filename_to_abbrev map, so we can use index to represent filenames, use lineno to represent lines.

Test Plan:
# llm
base on D54497918
```
buck2 run @//mode/dev-nosan fbcode//executorch/examples/models/llama2:export_llama -- -c ~/stories110M.pt -p ~/params.json
```
set up breakpoint after serialization/deserialization
- serialize
```
(Pdb) v_meta = [n.meta for n in exported_program.graph_module.graph.nodes]
(Pdb) paste_client.create_phabricator_paste_object(paste_creation_client_id=1093956601162697, content=str(v_meta)).number
1193647450
(Pdb) json_program = json.dumps(_dataclass_to_dict(serialized_graph.co_fileinfo_ordered_list),cls=EnumEncoder)
(Pdb) json_bytes = json_program.encode('utf-8')
(Pdb) paste_client.create_phabricator_paste_object(paste_creation_client_id=1093956601162697, content=str(json_bytes)).number
1193604333
(Pdb) sys.getsizeof(json_bytes)
3846
(Pdb) compressed_bytes = zstd.ZstdCompressor().compress(json_bytes)
(Pdb) sys.getsizeof(compressed_bytes)
1139
```
in P1193647450 (before serialization), search for `stack_trace`
in P1193604333 (after serialization), search for `stack_trace` and `co_fileinfo_ordered_list`

[note: didn't do compression in this diff since the size is pretty small and it adds complexity if we do compression]
- deserialize
```
(Pdb) v_meta = [n.meta for n in deserialized_exported_program.graph_module.graph.nodes]
(Pdb) paste_client.create_phabricator_paste_object(paste_creation_client_id=1093956601162697, content=str(v_meta)).number
1193629435
```
in P1193629435, search for `stack_trace`

# ads

Differential Revision: D54654443

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121675
Approved by: https://github.com/angelayi
2024-03-19 17:58:12 +00:00
Pian Pawakapan
3bd38928ba [export] Improve consistency for nn_module_stack metadata, add checks to _trace.py (#120661)
We would like to improve consistency for nn_module_stack metadata in torch.export.

This PR ensures that all tests in test/export/test_export.py has the following constraints:
- Remove nn_module_stack for all placeholder & output nodes, for all modules and submodules
- Ensure nn_module_stack is present for all other node types for the top-level module (there is still an issue with torch.cond submodules having empty fields)
- Add these checks to _export() in _trace.py (we would add this in the Verifier, but downstream apps construct ExportedPrograms separate from _export(), and metadata may not be maintained there)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120661
Approved by: https://github.com/avikchaudhuri
2024-03-16 21:44:52 +00:00
Wenting Wang
dfc5e9325d format caffe2/torch/_export/serde/serialize.py (#121670)
Summary: black caffe2/torch/_export/serde/serialize.py

Test Plan: tests

Differential Revision: D54654847

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121670
Approved by: https://github.com/angelayi
2024-03-15 21:30:16 +00:00
angelayi
ef25d83a62 [export] Add serialization support for tokens (#121552)
Differential Revision: [D54906766](https://our.internmc.facebook.com/intern/diff/D54906766)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121552
Approved by: https://github.com/zhxchen17
2024-03-15 16:15:11 +00:00
Zhengxu Chen
c409292197 [sigmoid] Use deserializer from oss. (#121839)
Summary:
Old path:
thrift -> thrift deserializer -> graph module.
new path:
thrift -> python dataclass -> oss deserializer -> graph_module

Test Plan:
CI
buck2 test mode/dev-nosan caffe2/test/inductor/fb:test_aot_inductor_pt2_inference

Reviewed By: SherlockNoMad

Differential Revision: D54855251

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121839
Approved by: https://github.com/angelayi
2024-03-14 18:38:58 +00:00
PyTorch MergeBot
bf7ac4ddf7 Revert "[export] allow Dim(1,2) for export dynamic shapes (#121642)"
This reverts commit a8dcbf2749.

Reverted https://github.com/pytorch/pytorch/pull/121642 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/121642#issuecomment-1996121710))
2024-03-13 23:51:20 +00:00
Pian Pawakapan
a8dcbf2749 [export] allow Dim(1,2) for export dynamic shapes (#121642)
Current dynamic shapes implementation fixes lower range of Dims to be 2 for analysis, but allows 0/1 shapes during runtime. This leads to failures when initializing Dim(1,2). This PR sets the lower bound to 0, and avoids erroring out when conflicting with the generated (2, maxsize) constraint during analysis.

Also resolves a derived dim constraints issue with the following code:
```
class Bar(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

dx = Dim("dx", min=1, max=3)
ep = export(
    Bar(),
    (torch.randn(2, 2), torch.randn(3, 2)),
    dynamic_shapes=({0: dx, 1: None}, {0: dx+1, 1: None})
)
print(ep.range_constraints)
```

In main:
```
{s0: ValueRanges(lower=2, upper=3, is_bool=False), s0 + 1: ValueRanges(lower=3, upper=4, is_bool=False)}
```

This PR:
```
{s0: ValueRanges(lower=1, upper=3, is_bool=False), s0 + 1: ValueRanges(lower=2, upper=4, is_bool=False)}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121642
Approved by: https://github.com/avikchaudhuri
2024-03-13 22:59:07 +00:00
Sherlock Huang
dd568f4207 [Export, AOTInductor] Populate ShapeEnv's var_to_val during deserialization (#121759)
Summary:
Deserialization didn't populate ShapeEnv's `var_to_val` field properly, and AOTInductor is relying on this field to compile dynamic shape properly.
As a result, when AOTI failed at compiling a deserialized ExportedProgram.

Test Plan: buck2 test  mode/dev-nosan caffe2/test/inductor/fb:test_aot_inductor_pt2_inference

Differential Revision: D54559494

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121759
Approved by: https://github.com/avikchaudhuri
2024-03-13 21:28:25 +00:00
Zhengxu Chen
76f1461892 [export] Serialize union fields with single entry dict. (#121263) (#121337)
Summary:

remove "$type" and "$value" fields, instead only serialize as {type: value} for union fields directly.

bypass-github-export-checks

Test Plan: CI

Differential Revision: D54600943

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121337
Approved by: https://github.com/tugsbayasgalan
2024-03-07 21:24:28 +00:00
PyTorch MergeBot
23fb37fa41 Revert "[export] Serialize union fields with single entry dict. (#121263)"
This reverts commit 7feabe9b73.

Reverted https://github.com/pytorch/pytorch/pull/121263 on behalf of https://github.com/osalpekar due to A large number of inductor benchmarking jobs failing starting this PR. See for details: 7feabe9b73 ([comment](https://github.com/pytorch/pytorch/pull/121263#issuecomment-1981680049))
2024-03-06 19:58:55 +00:00
Zhengxu Chen
7feabe9b73 [export] Serialize union fields with single entry dict. (#121263)
Summary: remove "$type" and "$value" fields, instead only serialize as {type: value} for union fields directly.

Test Plan: CI

Differential Revision: D54553770

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121263
Approved by: https://github.com/tugsbayasgalan
2024-03-06 18:16:16 +00:00
Shruthi GN
ef9e89984c [pytorch] Support output types that are non tensors (#120804)
Summary:
per title
This is needed because some modules return None and non tensors as output

Test Plan: sandcastle?

Reviewed By: zhxchen17

Differential Revision: D54311609

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120804
Approved by: https://github.com/zhxchen17
2024-02-29 02:49:10 +00:00
Avik Chaudhuri
5472923998 derived dim (#118729)
With the current `Dim`-based dynamic shapes API for export, one can express that shapes of different input shapes must be equal by reusing the same `Dim`. However, non-trivial relationships between such input shapes cannot be expressed.

Recently we are seeing more and more examples of code that require this additional expressibility, e.g., where a pair of shapes might differ by one, or a shape might be double another (or simply even).

This PR introduces the concept of a "derived" `Dim`, i.e., a linear arithmetic expression over a `Dim`. By using a combination of `Dim`s and derived `Dim`s to specify input shapes, the desired relationships can be expressed naturally. E.g., a pair of shapes might be `dim` and `dim + 1`, or `dim` and `2*dim`, or even `2*dim` and `dim + 1`.

We extend the current infrastructure that translates `Dim`s to deprecated `dynamic_dim`-based constraints to work with derived `Dim`s. As usual, we raise constraint violation errors when shape guards cannot be verified given a dynamic shapes spec; suggest fixes; and raise runtime errors when future inputs violate the spec.

Importantly, some guards that used to cause forced specializations in the constraint solver because they were deemed "too complex" now do not do so, because they can now be specified as constraints. Since this was what motivated the introduction of a `disable_constraint_solver` flag to some internal APIs, we may not need that flag any more.

Note that shapes of placeholders in exported programs can now contain symbolic expressions and not just symbols.

Differential Revision: D53254587

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118729
Approved by: https://github.com/ezyang
2024-02-28 19:48:32 +00:00
ydwu4
ac2ba7889d [export] turn on replace_set_grad_with_hop_pass in pre_dispatch (#119915)
This PR turns on replace_set_grad_with_hop_pass for pre_dispatch export. To do that, we need to propagate the meta-data from original submodule to the new higher order op and fix the names of nodes as is required by the _sig_to_specs pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119915
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #119732, #119736, #119810, #119913, #119914
2024-02-17 02:18:35 +00:00
suo
8e029dc616 [export] fix tuple return with symints (#119829)
as title.

Differential Revision: [D53726648](https://our.internmc.facebook.com/intern/diff/D53726648/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119829
Approved by: https://github.com/zhxchen17, https://github.com/khabinov
2024-02-14 01:16:38 +00:00
suo
f15b517055 [export] suppress type error (#119720)
Differential Revision: [D53681243](https://our.internmc.facebook.com/intern/diff/D53681243/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119720
Approved by: https://github.com/kit1980, https://github.com/huydhn
2024-02-12 22:54:36 +00:00
suo
82248f0b1c [export] improve FakeTensor serialization (#119531)
Recently we made it possible to serialize ExportedPrograms with fake parameters/buffers/etc.

The serialization regime was kind of whacky; basically we serialized a stub and reassembled the FakeTensor using metadata that we had stashed elsewhere in the Graph state.

This was bad for a few reasons:
- Storing the metadata separately from the actual serialized object caused situations where you could have one but not the other. An example case is if you had a FakeTensor contained inside a TorchBind object—there was no obviously place to store the metadata for this. This actually happens—TensorQueue in fbgemm does this.
- It created an annoying cycle: we had to deserialize the Graph's tensor metadata in order to deserialize (potentially faked) constants, but we need constants in order to deserialize the Graph.

This fixes all that. The basic idea is to patch the reducer function for FakeTensor at serialization time, and serialize a copy of the FakeTensor metadata. We already are policing BC for the TensorMeta schema struct so it's not a net increase in the BC surface.

As a bonus, I fixed a weird bug with torchbind tracing where we were accidentally reinterpreting a torch.ScriptObject as a torch.ScriptModule (which was the root cause of some weird behavior @bahuang was seeing last week).

Differential Revision: [D53601251](https://our.internmc.facebook.com/intern/diff/D53601251/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119531
Approved by: https://github.com/zhxchen17
2024-02-12 19:28:08 +00:00
suo
5747ec24b4 [export] fix canonicalization for input mutations (#119533)
The comparison was off: user_input_mutation and buffer_mutation had the same numeric value, which led the comparison to move to the next element of the tuple and try to compare `None` to `spec.buffer_mutation.buffer_name`, which doesn't work. So make them different numbers.

Differential Revision: [D53601300](https://our.internmc.facebook.com/intern/diff/D53601300/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119533
Approved by: https://github.com/zhxchen17
2024-02-09 18:30:39 +00:00
angelayi
b181e52a8f [export] Support non-tensor tuple hoo outputs (#119402)
There's an internal custom op which has a None output, so when it becomes auto_functionalized, the HOO's output is (None, Tensor, Tensor, ...). This PR adds support for the None output, and any int/bool outputs from HOOs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119402
Approved by: https://github.com/suo, https://github.com/avikchaudhuri
2024-02-08 16:54:40 +00:00
Michael Suo
0e2330d84c fix lint (#119395)
Summary: as title

Test Plan: lint

Differential Revision: D53532399

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119395
Approved by: https://github.com/tugsbayasgalan, https://github.com/malfet
2024-02-07 19:06:41 +00:00
Michael Suo
f79ae7599a [export] fakify module state in nonstrict (#119297)
Summary:
Previously, we were not fakifying module state explicitly in the nonstrict path.

This led to errors when modules were constructed under a fake mode, since the user-provided fake mode was clashing with the one that we had constructed internally to fakify the inputs.

This fixes things to use a single fake mode for everything.

As a side effect, this raised the question of how we ought to serialize state_dicts/constants that might be fake tensors. Naively calling torch.save understandably explodes—so this diff piggybacks on our infra for doing this on meta["val"]. Open to revising this, I'm low confidence that it's the best way to do it.

Test Plan: unit tests

Differential Revision: D53484942

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119297
Approved by: https://github.com/tugsbayasgalan
2024-02-07 17:12:22 +00:00