pytorch/torch/_export/constraints.py
Angela Yi 1d077f28ed [export] Constraints API (#98433)
Wrapper for users to insert constraints into model code.

The constraints are not preserved in the graph after tracing through make_fx, so retracing with dynamo/make_fx will not work. This will be supported once torch._assert support is implemented; at that point the constrain_range calls can be converted to torch._assert calls.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98433
Approved by: https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
2023-04-13 21:20:10 +00:00

29 lines
938 B
Python

from typing import Optional
from torch.fx.experimental.symbolic_shapes import constrain_range
from torch.utils._sympy.value_ranges import ValueRangeError
# TODO: we want to hide this min/max stuff under some abstraction similar to
# DynamicDim
def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None):
"""
Add min/max constraint on the intermediate symbol at tracing time
"""
constrain_range(symbol, min=min, max=max)
return symbol
def constrain_as_size(symbol, min: int = 2, max: Optional[int] = None):
"""
Add min/max constraint on the intermediate symbol which will be used as a size
"""
# TODO: we should investigate turning off 0/1 specialization for unbacked
# SymInts
if min < 2:
raise ValueRangeError(
"Unable to set min size to be <= 2 because we specialize on 0/1 sizes."
)
return constrain_as_value(symbol, min, max)