This PR allows users to specify int values for dimensions in dynamic_shapes as well as None, for example:
```
class Foo(torch.nn.Module):
def forward(self, x, y, z):
...
foo = Foo()
inputs = (torch.randn(4, 6), torch.randn(5, 4), torch.randn(3, 3))
for dynamic_shapes in [
None
((4, 6), (5, 4), (3, 3)),
((None, 6), None, {0: 3, 1: 3})
]:
_ = export(foo, inputs, dynamic_shapes=dynamic_shapes)
```
All of the above should produce the same ExportedProgram.
This is done by temporarily creating a static dim constraint during analysis, where vr.lower == vr.upper. These constraints are then deleted during _process_constraints(), and do not show up in the final ExportedProgram's range_constraints.
Additionally, export() will also fail if the shapes are mis-specified, for example:
```
_ = export(foo, inputs, dynamic_shapes=((5, None), None, None))
```
leads to `torch._dynamo.exc.UserError: Static shape constraint of 5 does not match input size of 4, for L['x'].size()[0]`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121860
Approved by: https://github.com/avikchaudhuri
Creating this after [PR](https://github.com/pytorch/pytorch/pull/121642) got reverted.
Current dynamic shapes implementation fixes lower range of Dims to be 2 for analysis, but allows 0/1 shapes during runtime. This leads to failures when initializing Dim(1,2). This PR sets the lower bound to 0, and avoids erroring out when conflicting with the generated (2, maxsize) constraint during analysis.
Also resolves a derived dim constraints issue with the following code:
```
class Bar(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
dx = Dim("dx", min=1, max=3)
ep = export(
Bar(),
(torch.randn(2, 2), torch.randn(3, 2)),
dynamic_shapes=({0: dx, 1: None}, {0: dx+1, 1: None})
)
print(ep.range_constraints)
```
In main:
```
{s0: ValueRanges(lower=2, upper=3, is_bool=False), s0 + 1: ValueRanges(lower=3, upper=4, is_bool=False)}
```
This PR:
```
{s0: ValueRanges(lower=1, upper=3, is_bool=False), s0 + 1: ValueRanges(lower=2, upper=4, is_bool=False)}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121910
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
Current dynamic shapes implementation fixes lower range of Dims to be 2 for analysis, but allows 0/1 shapes during runtime. This leads to failures when initializing Dim(1,2). This PR sets the lower bound to 0, and avoids erroring out when conflicting with the generated (2, maxsize) constraint during analysis.
Also resolves a derived dim constraints issue with the following code:
```
class Bar(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
dx = Dim("dx", min=1, max=3)
ep = export(
Bar(),
(torch.randn(2, 2), torch.randn(3, 2)),
dynamic_shapes=({0: dx, 1: None}, {0: dx+1, 1: None})
)
print(ep.range_constraints)
```
In main:
```
{s0: ValueRanges(lower=2, upper=3, is_bool=False), s0 + 1: ValueRanges(lower=3, upper=4, is_bool=False)}
```
This PR:
```
{s0: ValueRanges(lower=1, upper=3, is_bool=False), s0 + 1: ValueRanges(lower=2, upper=4, is_bool=False)}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121642
Approved by: https://github.com/avikchaudhuri
With the current `Dim`-based dynamic shapes API for export, one can express that shapes of different input shapes must be equal by reusing the same `Dim`. However, non-trivial relationships between such input shapes cannot be expressed.
Recently we are seeing more and more examples of code that require this additional expressibility, e.g., where a pair of shapes might differ by one, or a shape might be double another (or simply even).
This PR introduces the concept of a "derived" `Dim`, i.e., a linear arithmetic expression over a `Dim`. By using a combination of `Dim`s and derived `Dim`s to specify input shapes, the desired relationships can be expressed naturally. E.g., a pair of shapes might be `dim` and `dim + 1`, or `dim` and `2*dim`, or even `2*dim` and `dim + 1`.
We extend the current infrastructure that translates `Dim`s to deprecated `dynamic_dim`-based constraints to work with derived `Dim`s. As usual, we raise constraint violation errors when shape guards cannot be verified given a dynamic shapes spec; suggest fixes; and raise runtime errors when future inputs violate the spec.
Importantly, some guards that used to cause forced specializations in the constraint solver because they were deemed "too complex" now do not do so, because they can now be specified as constraints. Since this was what motivated the introduction of a `disable_constraint_solver` flag to some internal APIs, we may not need that flag any more.
Note that shapes of placeholders in exported programs can now contain symbolic expressions and not just symbols.
Differential Revision: D53254587
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118729
Approved by: https://github.com/ezyang
Summary: Exposes `dynamic_shapes` api at multiple levels so it's easier to replace the old API `dynamic_dim()` with the new API `Dim()`.
Test Plan: CI
Differential Revision: D53246409
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118695
Approved by: https://github.com/ydwu4
Summary:
In `torch.export.export(f, args, kwargs, ..., dynamic_shpapes=None, ...)`, `dataclass` is an acceptable type of inputs (for args and kwargs). The `dynamic_shapes` of the `dataclass` inputs needs to be the same `dataclass` type which replaces each tensor attributes with `dynamic_shapes` of the corresponding tensors. (https://github.com/pytorch/pytorch/blob/main/torch/export/dynamic_shapes.py#L375)
However, some `dataclass` may have limitations on the types of attributes (e.g., having to be tensors) such that the same `dataclass` cannot be constructed for dynamic shapes.
For an input of `dataclass` type, this task enables a `dynamic_shapes` of a tuple type that specifies dynamic shape specifications for each tensor of the input in the same order as the input dataclass type's flatten_fn (https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py#L103)
Test Plan: buck test //caffe2/test:test_export
Differential Revision: D52932856
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117917
Approved by: https://github.com/avikchaudhuri