Initial prototype for dynamic int inputs. This allows users to run `torch.compile(f)(DynamicInt(4))`, compiling dynamically and using the underlying hint at runtime.
Current behavior:
- Also works in eager mode (mostly by virtue of subclassing int), as a scalar input to torch functions, numpy, math, etc. For example, `x = DynamicInt(3); torch.randn(x); torch.add(y, z, alpha=x); np.arange(x)` all act as if x = 3 (see the sketch after this list).
- Arithmetic ops return new DynamicInts rather than plain ints; `DynamicInt(3) * 2` yields `DynamicInt(6)`. This works via SymNode magic methods, though coverage may not be 100% (for example, I had to explicitly override floordiv to avoid int casting), and it does not extend to non-magic-method ops (e.g. `math.cos(x)`). The alternative is to cast to int on every operation, but I opted for this behavior so that dynamism propagates in non-compiled regions.
- Doesn't ban fullgraph=False; DynamicInt objects might be leaked back to the user, but I think this is fine, since they can be cast to ints when needed.
- Dynamo only allocates one symbol per DynamicInt; specifying the same DynamicInt for multiple inputs leads to input deduplication, with a guard installed.
- We don't raise on int specialization (in allowlist/maybe_mark_dynamic style), but this is an easy change if needed.
- DynamicInts as nn.Module attributes are handled.
- We don't guard on the DynamicInt's id; e.g., users can do the following without recompiling (maybe we should guard?):
```python
x = DynamicInt(4)
f(x)
f(1)
f(DynamicInt(3)) # same as f(3)
```
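A minimal sketch of the behaviors above, assuming the prototype API from this PR (the class is documented at `torch.fx.experimental.sym_node.DynamicInt`):

```python
import torch
from torch.fx.experimental.sym_node import DynamicInt

x = DynamicInt(3)

# Eager: behaves like the plain int 3 (DynamicInt subclasses int).
t = torch.randn(x)   # same as torch.randn(3)
assert t.shape[0] == 3

# Arithmetic returns a new DynamicInt rather than a plain int,
# so dynamism propagates through non-compiled regions.
y = x * 2            # DynamicInt(6)

# Compiled: the argument becomes a dynamic symbol with hint 3; since
# there is no guard on the DynamicInt's id, later calls with other
# values should reuse the same compiled graph.
@torch.compile
def f(n):
    return torch.ones(n) * n

f(x)                 # compiles dynamically, hint = 3
f(DynamicInt(5))     # same as f(5); no recompile expected
```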
Follow-up work:
- Specifying shape constraints, either at the int level, e.g.
```python
DynamicInt(64, name="s0", constraints=["s0 % 32 == 0", "s0 <= 1024"])
```
or at the compilation level, e.g. something like
```python
s0 = DynamicInt(64, name="s0")
s1 = DynamicInt(128, name="s1")
with some_compiler_config.dynamic_int_constraints(["s1 == 2*s0", "s0 % 32 == 0"]):
f(s0, s1)
```
This should subsume the need for specifying derived SymInts?
- SymFloat support: currently it seems backed floats are specialized by the tensorify float pass, and there's no handling in inductor.
- Propagating dynamism in tensor constructors, e.g. `x = DynamicInt(4); torch.randn(x)` could annotate `_dynamo_dynamic_indices`.
Differential Revision: D81698719
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162194
Approved by: https://github.com/bobrenjc93
```{eval-rst}
.. currentmodule:: torch.fx.experimental
```

# torch.fx.experimental

:::{warning}
These APIs are experimental and subject to change without notice.
:::

```{eval-rst}
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
```
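A minimal usage sketch (hedged; this is an experimental prototype, so the exact behavior may change):

```python
import torch
from torch.fx.experimental.sym_node import DynamicInt

n = DynamicInt(8)   # behaves like the plain int 8 in eager mode
compiled = torch.compile(lambda k: torch.zeros(k))
compiled(n)         # compiled with a dynamic size whose hint is 8
```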
## torch.fx.experimental.symbolic_shapes

```{eval-rst}
.. currentmodule:: torch.fx.experimental.symbolic_shapes
```

```{eval-rst}
.. automodule:: torch.fx.experimental.symbolic_shapes
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    ShapeEnv
    DimDynamic
    StrictMinMaxConstraint
    RelaxedUnspecConstraint
    EqualityConstraint
    SymbolicContext
    StatelessSymbolicContext
    StatefulSymbolicContext
    SubclassSymbolicContext
    DimConstraints
    ShapeEnvSettings
    ConvertIntKey
    CallMethodKey
    PropagateUnbackedSymInts
    DivideByKey
    InnerTensorKey
    Specialization

    hint_int
    is_concrete_int
    is_concrete_bool
    is_concrete_float
    has_free_symbols
    has_free_unbacked_symbols
    guard_or_true
    guard_or_false
    guard_size_oblivious
    sym_and
    sym_eq
    sym_or
    constrain_range
    constrain_unify
    canonicalize_bool_expr
    statically_known_true
    statically_known_false
    has_static_value
    lru_cache
    check_consistent
    compute_unbacked_bindings
    rebind_unbacked
    resolve_unbacked_bindings
    is_accessor_node
```
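As a hedged illustration (a sketch, not part of the generated reference): `statically_known_true` takes a fast path only when a condition is provable at compile time, without installing a guard.

```python
import torch
from torch.fx.experimental.symbolic_shapes import statically_known_true

def maybe_squeeze(x: torch.Tensor) -> torch.Tensor:
    # True only if the leading dim is provably 1 with no new guard;
    # on plain (non-symbolic) tensors this is an ordinary bool check.
    if statically_known_true(x.shape[0] == 1):
        return x.squeeze(0)
    return x
```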
## torch.fx.experimental.proxy_tensor

```{eval-rst}
.. currentmodule:: torch.fx.experimental.proxy_tensor
```

```{eval-rst}
.. automodule:: torch.fx.experimental.proxy_tensor
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    make_fx
    handle_sym_dispatch
    get_proxy_mode
    maybe_enable_thunkify
    maybe_disable_thunkify
```
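For orientation, a small `make_fx` sketch (hedged; see the generated entries above for full signatures):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.sin(x) + x

# Trace f into an fx.GraphModule by executing it with proxy tensors.
gm = make_fx(f)(torch.randn(3))
print(gm.graph)
```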