mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Differential Revision: D45279206 Pull Request resolved: https://github.com/pytorch/pytorch/pull/99992 Approved by: https://github.com/angelayi, https://github.com/gmagogsfm
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
from typing import Optional
|
|
|
|
from torch._dynamo import allow_in_graph
|
|
from torch.fx.experimental.symbolic_shapes import constrain_range
|
|
from torch.utils._sympy.value_ranges import ValueRangeError
|
|
|
|
|
|
# TODO: we want to hide this min/max stuff under some abstraction similar to
|
|
# DynamicDim
|
|
@allow_in_graph
|
|
def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None):
|
|
"""
|
|
Add min/max constraint on the intermediate symbol at tracing time
|
|
"""
|
|
|
|
constrain_range(symbol, min=min, max=max)
|
|
return symbol
|
|
|
|
|
|
# TODO: we want to hide this min/max stuff under some abstraction similar to
|
|
# DynamicDim
|
|
@allow_in_graph
|
|
def constrain_as_size(symbol, min: int = 2, max: Optional[int] = None):
|
|
"""
|
|
Add min/max constraint on the intermediate symbol which will be used as a size
|
|
"""
|
|
|
|
# TODO: we should investigate turning off 0/1 specialization for unbacked
|
|
# SymInts
|
|
if min < 2:
|
|
raise ValueRangeError(
|
|
"Unable to set min size to be <= 2 because we specialize on 0/1 sizes."
|
|
)
|
|
return constrain_as_value(symbol, min, max)
|