pytorch/torch/_export
Pian Pawakapan c20bc18d59 [export] allow static constraints in dynamic_shapes (#121860)
This PR allows users to specify int values for dimensions in dynamic_shapes as well as None, for example:

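```
import torch
from torch.export import export


class Foo(torch.nn.Module):
    def forward(self, x, y, z):
        # Placeholder body; any computation on x, y, z works here.
        return x + 1, y * 2, z.sin()


foo = Foo()
inputs = (torch.randn(4, 6), torch.randn(5, 4), torch.randn(3, 3))

for dynamic_shapes in [
    None,  # no spec at all: every dimension stays static
    ((4, 6), (5, 4), (3, 3)),  # every dimension pinned with an int
    ((None, 6), None, {0: 3, 1: 3}),  # mix of None, ints, and the dict form
]:
    _ = export(foo, inputs, dynamic_shapes=dynamic_shapes)
```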

All of the above should produce the same ExportedProgram.
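All three specs are equivalent because each one pins every dimension to its example size. The following is a minimal, torch-free sketch of that resolution. `pinned_sizes` and `resolve` are hypothetical helpers, not functions in `torch._export`. They handle only `None`, ints, and the dict form, not the `Dim` objects used for genuinely dynamic dimensions.

```python
from typing import Dict, List, Optional, Sequence, Tuple, Union

# Hypothetical aliases for this sketch: an int pins a dim, None keeps the example size.
DimSpec = Optional[int]
TensorSpec = Union[None, Sequence[DimSpec], Dict[int, DimSpec]]


def pinned_sizes(shape: Tuple[int, ...], spec: TensorSpec) -> Tuple[int, ...]:
    """Return the size every dimension of one tensor is fixed to."""
    if spec is None:
        # No spec for this tensor: all dims stay at the example input's sizes.
        return tuple(shape)
    if isinstance(spec, dict):
        # Dict form: dims missing from the dict behave like None.
        per_dim = [spec.get(i) for i in range(len(shape))]
    else:
        per_dim = list(spec)
    return tuple(size if d is None else d for size, d in zip(shape, per_dim))


def resolve(shapes: List[Tuple[int, ...]], dynamic_shapes) -> List[Tuple[int, ...]]:
    """Apply a whole dynamic_shapes spec (or None) to a list of input shapes."""
    per_tensor = [None] * len(shapes) if dynamic_shapes is None else dynamic_shapes
    return [pinned_sizes(s, t) for s, t in zip(shapes, per_tensor)]


shapes = [(4, 6), (5, 4), (3, 3)]
for dynamic_shapes in [
    None,
    ((4, 6), (5, 4), (3, 3)),
    ((None, 6), None, {0: 3, 1: 3}),
]:
    print(resolve(shapes, dynamic_shapes))  # [(4, 6), (5, 4), (3, 3)] each time
```

The sketch takes the ints at face value. Validating them against the actual input sizes is a separate check.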

This is done by temporarily creating a static dim constraint during analysis, i.e. a value range where `vr.lower == vr.upper`. These constraints are then deleted during `_process_constraints()`, so they do not show up in the final ExportedProgram's `range_constraints`.
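The bookkeeping described above can be pictured with a small stand-alone sketch. `ValueRange`, `static_constraint`, and `prune_static` are hypothetical stand-ins: the real analysis works on sympy symbols and PyTorch's own value-range type, and the string keys here are made up for illustration. A pinned dim becomes the degenerate range `[n, n]`, and pruning keeps only ranges with `lower < upper`.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class ValueRange:
    """Hypothetical stand-in for a symbol's value range (inclusive bounds)."""
    lower: int
    upper: int

    @property
    def is_static(self) -> bool:
        # A range that admits exactly one value is a static constraint.
        return self.lower == self.upper


def static_constraint(size: int) -> ValueRange:
    # A dim pinned to `size` is modeled as the degenerate range [size, size].
    return ValueRange(size, size)


def prune_static(range_constraints: Dict[str, ValueRange]) -> Dict[str, ValueRange]:
    # Drop the temporary static entries so only genuinely dynamic ranges remain.
    return {name: vr for name, vr in range_constraints.items() if not vr.is_static}


constraints = {
    "x.size()[0]": static_constraint(4),  # hypothetical key for a pinned dim
    "y.size()[0]": ValueRange(2, 64),  # hypothetical dynamic dim, kept for contrast
}
print(prune_static(constraints))  # {'y.size()[0]': ValueRange(lower=2, upper=64)}
```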

`export()` also fails if a static shape is mis-specified, for example:
```
_ = export(foo, inputs, dynamic_shapes=((5, None), None, None))
```
leads to `torch._dynamo.exc.UserError: Static shape constraint of 5 does not match input size of 4, for L['x'].size()[0]`
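Below is a minimal sketch of the kind of check behind that error. It assumes specs use the tuple or dict forms shown earlier. `check_static_dims` is hypothetical, and it raises `ValueError` where export itself raises `torch._dynamo.exc.UserError`. Only the message text is modeled on the one quoted.

```python
from typing import Dict, Optional, Sequence, Union

# Hypothetical spec type for this sketch: tuple/list of int-or-None, or dict form.
Spec = Union[None, Sequence[Optional[int]], Dict[int, Optional[int]]]


def check_static_dims(name: str, shape: Sequence[int], spec: Spec) -> None:
    """Raise if any int in `spec` disagrees with the example input's size."""
    if spec is None:
        return
    items = spec.items() if isinstance(spec, dict) else enumerate(spec)
    for dim, expected in items:
        # None leaves the dim at its example size, so only ints are checked.
        if expected is not None and expected != shape[dim]:
            raise ValueError(
                f"Static shape constraint of {expected} does not match "
                f"input size of {shape[dim]}, for L['{name}'].size()[{dim}]"
            )


try:
    check_static_dims("x", (4, 6), (5, None))
except ValueError as e:
    print(e)  # Static shape constraint of 5 does not match input size of 4, for L['x'].size()[0]
```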
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121860
Approved by: https://github.com/avikchaudhuri
2024-03-21 16:59:59 +00:00
| Name | Last commit | Date |
| --- | --- | --- |
| `db` | Enable local_partial_types (#118467) | 2024-01-28 13:38:22 +00:00 |
| `pass_infra` | [export] ExportPassBase + view_copy pass (#100000) | 2023-04-26 21:01:25 +00:00 |
| `passes` | [export] Improve consistency for nn_module_stack metadata, add checks to _trace.py (#120661) | 2024-03-16 21:44:52 +00:00 |
| `serde` | Fix GraphModuleDeserializer (#122342) | 2024-03-21 02:27:39 +00:00 |
| `__init__.py` | make _process_dynamic_shapes an implementation detail (#121713) | 2024-03-13 08:33:00 +00:00 |
| `error.py` | Add experimental export() API (#100034) | 2023-04-28 06:12:59 +00:00 |
| `exported_program.py` | [export] Remove CallSpec (#117671) | 2024-02-08 17:19:03 +00:00 |
| `non_strict_utils.py` | [export] allow static constraints in dynamic_shapes (#121860) | 2024-03-21 16:59:59 +00:00 |
| `pass_base.py` | [HigherOrderOp] change signature of map_impl (#117161) | 2024-01-13 02:50:46 +00:00 |
| `utils.py` | fix accidental specialization with faketensor input checks (#121460) | 2024-03-08 08:02:37 +00:00 |
| `verifier.py` | [sigmoid] Clean up serialization API. (#122102) | 2024-03-20 03:45:36 +00:00 |
| `wrappers.py` | [export] Use forward hooks to capture module signatures. (#120468) | 2024-02-27 17:44:06 +00:00 |