mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[HigherOrderOp] expose torch.cond (#110293)
This pr expose torch._higher_order_ops.cond as torch.cond. 1. Need to add #noqa: F811 to the _check calls in torch/__init__.py to address some confusing linter error "Redefinition of unused 'cond'" but only one cond is imported and for these lines that have this error, they don't define the cond but just use it as an argument. 2. Also add cond to the list that allows it to be traced through so as dynamo could trigger the CondHigherOrder logic instead of creating a TorchVariable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110293 Approved by: https://github.com/zou3519
This commit is contained in:
parent
0a5bb1c2eb
commit
d84bcb9c8c
|
|
@ -1,4 +1,4 @@
|
||||||
.. _control_flow_cond:
|
.. _cond:
|
||||||
|
|
||||||
Control Flow - Cond
|
Control Flow - Cond
|
||||||
====================
|
====================
|
||||||
|
|
@ -501,7 +501,7 @@ Graph breaks can also be encountered on data-dependent control flow (``if
|
||||||
x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot
|
x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot
|
||||||
possibly deal with without generating code for a combinatorially exploding
|
possibly deal with without generating code for a combinatorially exploding
|
||||||
number of paths. In such cases, users will need to rewrite their code using
|
number of paths. In such cases, users will need to rewrite their code using
|
||||||
special control flow operators. Currently, we support :ref:`torch.cond <control_flow_cond>`
|
special control flow operators. Currently, we support :ref:`torch.cond <cond>`
|
||||||
to express if-else like control flow (more coming soon!).
|
to express if-else like control flow (more coming soon!).
|
||||||
|
|
||||||
Data-Dependent Accesses
|
Data-Dependent Accesses
|
||||||
|
|
@ -540,7 +540,7 @@ Read More
|
||||||
torch.compiler_transformations
|
torch.compiler_transformations
|
||||||
torch.compiler_ir
|
torch.compiler_ir
|
||||||
generated/exportdb/index
|
generated/exportdb/index
|
||||||
control_flow_cond
|
cond
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:caption: Deep Dive for PyTorch Developers
|
:caption: Deep Dive for PyTorch Developers
|
||||||
|
|
|
||||||
|
|
@ -718,6 +718,18 @@ Export Path
|
||||||
export
|
export
|
||||||
generated/exportdb/index
|
generated/exportdb/index
|
||||||
|
|
||||||
|
Control Flow
|
||||||
|
------------
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This feature is a prototype and may have compatibility breaking changes in the future.
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
cond
|
||||||
|
|
||||||
Optimizations
|
Optimizations
|
||||||
-------------
|
-------------
|
||||||
.. autosummary::
|
.. autosummary::
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,4 @@
|
||||||
from torch._higher_order_ops.cond import ( # noqa: F401
|
from torch import cond # noqa: F401
|
||||||
cond,
|
from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
|
||||||
UnsupportedAliasMutationException,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ._map import map # noqa: F401
|
from ._map import map # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,7 @@ __all__ = [
|
||||||
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
|
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
|
||||||
'SymBool', 'sym_not',
|
'SymBool', 'sym_not',
|
||||||
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
|
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
|
||||||
'export', 'autocast',
|
'export', 'autocast', 'cond',
|
||||||
]
|
]
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
@ -986,7 +986,7 @@ def is_warn_always_enabled() -> builtins.bool:
|
||||||
# These error checking functions must be kept consistent with their C++
|
# These error checking functions must be kept consistent with their C++
|
||||||
# equivalents. Their C++ equivalents are mentioned where applicable.
|
# equivalents. Their C++ equivalents are mentioned where applicable.
|
||||||
|
|
||||||
def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
|
def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]): # noqa: F811
|
||||||
if not isinstance(cond, (builtins.bool, torch.SymBool)):
|
if not isinstance(cond, (builtins.bool, torch.SymBool)):
|
||||||
raise TypeError(f'cond must be a bool, but got {type(cond)}')
|
raise TypeError(f'cond must be a bool, but got {type(cond)}')
|
||||||
|
|
||||||
|
|
@ -1010,7 +1010,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab
|
||||||
|
|
||||||
raise error_type(message_evaluated)
|
raise error_type(message_evaluated)
|
||||||
|
|
||||||
def _check(cond, message=None):
|
def _check(cond, message=None): # noqa: F811
|
||||||
r"""Throws error containing an optional message if the specified condition
|
r"""Throws error containing an optional message if the specified condition
|
||||||
is False.
|
is False.
|
||||||
|
|
||||||
|
|
@ -1041,7 +1041,7 @@ def _check_is_size(i, message=None):
|
||||||
_check(i >= 0, message)
|
_check(i >= 0, message)
|
||||||
torch.fx.experimental.symbolic_shapes._advise_is_size(i)
|
torch.fx.experimental.symbolic_shapes._advise_is_size(i)
|
||||||
|
|
||||||
def _check_index(cond, message=None):
|
def _check_index(cond, message=None): # noqa: F811
|
||||||
r"""Throws error containing an optional message if the specified condition
|
r"""Throws error containing an optional message if the specified condition
|
||||||
is False.
|
is False.
|
||||||
|
|
||||||
|
|
@ -1058,7 +1058,7 @@ def _check_index(cond, message=None):
|
||||||
"""
|
"""
|
||||||
_check_with(IndexError, cond, message)
|
_check_with(IndexError, cond, message)
|
||||||
|
|
||||||
def _check_value(cond, message=None):
|
def _check_value(cond, message=None): # noqa: F811
|
||||||
r"""Throws error containing an optional message if the specified condition
|
r"""Throws error containing an optional message if the specified condition
|
||||||
is False.
|
is False.
|
||||||
|
|
||||||
|
|
@ -1075,7 +1075,7 @@ def _check_value(cond, message=None):
|
||||||
"""
|
"""
|
||||||
_check_with(ValueError, cond, message)
|
_check_with(ValueError, cond, message)
|
||||||
|
|
||||||
def _check_type(cond, message=None):
|
def _check_type(cond, message=None): # noqa: F811
|
||||||
r"""Throws error containing an optional message if the specified condition
|
r"""Throws error containing an optional message if the specified condition
|
||||||
is False.
|
is False.
|
||||||
|
|
||||||
|
|
@ -1092,7 +1092,7 @@ def _check_type(cond, message=None):
|
||||||
"""
|
"""
|
||||||
_check_with(TypeError, cond, message)
|
_check_with(TypeError, cond, message)
|
||||||
|
|
||||||
def _check_not_implemented(cond, message=None):
|
def _check_not_implemented(cond, message=None): # noqa: F811
|
||||||
r"""Throws error containing an optional message if the specified condition
|
r"""Throws error containing an optional message if the specified condition
|
||||||
is False.
|
is False.
|
||||||
|
|
||||||
|
|
@ -1109,7 +1109,7 @@ def _check_not_implemented(cond, message=None):
|
||||||
"""
|
"""
|
||||||
_check_with(NotImplementedError, cond, message)
|
_check_with(NotImplementedError, cond, message)
|
||||||
|
|
||||||
def _check_tensor_all_with(error_type, cond, message=None):
|
def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
|
||||||
if not torch.is_tensor(cond):
|
if not torch.is_tensor(cond):
|
||||||
raise TypeError(f'cond must be a tensor, but got {type(cond)}')
|
raise TypeError(f'cond must be a tensor, but got {type(cond)}')
|
||||||
|
|
||||||
|
|
@ -1120,7 +1120,7 @@ def _check_tensor_all_with(error_type, cond, message=None):
|
||||||
_check_with(error_type, cond._is_all_true().item(), message)
|
_check_with(error_type, cond._is_all_true().item(), message)
|
||||||
|
|
||||||
# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
|
# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
|
||||||
def _check_tensor_all(cond, message=None):
|
def _check_tensor_all(cond, message=None): # noqa: F811
|
||||||
r"""Throws error containing an optional message if the specified condition
|
r"""Throws error containing an optional message if the specified condition
|
||||||
is False.
|
is False.
|
||||||
|
|
||||||
|
|
@ -1761,6 +1761,7 @@ def compile(model: Optional[Callable] = None, *,
|
||||||
|
|
||||||
from torch import export as export
|
from torch import export as export
|
||||||
|
|
||||||
|
from torch._higher_order_ops import cond
|
||||||
|
|
||||||
def _register_device_module(device_type, module):
|
def _register_device_module(device_type, module):
|
||||||
r"""Register an external runtime module of the specific :attr:`device_type`
|
r"""Register an external runtime module of the specific :attr:`device_type`
|
||||||
|
|
|
||||||
|
|
@ -215,6 +215,7 @@ def _allowed_function_ids():
|
||||||
torch.func.vmap,
|
torch.func.vmap,
|
||||||
deprecated_func.vmap,
|
deprecated_func.vmap,
|
||||||
torch.nn.functional.triplet_margin_with_distance_loss,
|
torch.nn.functional.triplet_margin_with_distance_loss,
|
||||||
|
torch.cond,
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .cond import cond
|
||||||
|
|
@ -8,8 +8,7 @@ import torch.fx.traceback as fx_traceback
|
||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
|
|
||||||
from torch._C import DispatchKey
|
from torch._C import DispatchKey
|
||||||
from torch._dynamo.exc import CondOpArgsMismatchError
|
from torch._functorch.utils import exposed_in
|
||||||
from torch._dynamo.utils import disable_cache_limit
|
|
||||||
|
|
||||||
from torch._higher_order_ops.utils import autograd_not_implemented
|
from torch._higher_order_ops.utils import autograd_not_implemented
|
||||||
from torch._ops import HigherOrderOperator
|
from torch._ops import HigherOrderOperator
|
||||||
|
|
@ -42,6 +41,7 @@ class UnsupportedAliasMutationException(RuntimeError):
|
||||||
reason: str
|
reason: str
|
||||||
|
|
||||||
|
|
||||||
|
@exposed_in("torch")
|
||||||
def cond(pred, true_fn, false_fn, operands):
|
def cond(pred, true_fn, false_fn, operands):
|
||||||
r"""
|
r"""
|
||||||
Conditionally applies `true_fn` or `false_fn`.
|
Conditionally applies `true_fn` or `false_fn`.
|
||||||
|
|
@ -142,7 +142,7 @@ def cond(pred, true_fn, false_fn, operands):
|
||||||
raise RuntimeError("torch.cond requires dynamo support.")
|
raise RuntimeError("torch.cond requires dynamo support.")
|
||||||
|
|
||||||
with _set_compilation_env():
|
with _set_compilation_env():
|
||||||
with disable_cache_limit():
|
with torch._dynamo.utils.disable_cache_limit():
|
||||||
return torch.compile(cond_op, backend="eager", fullgraph=True)(
|
return torch.compile(cond_op, backend="eager", fullgraph=True)(
|
||||||
pred, true_fn, false_fn, operands
|
pred, true_fn, false_fn, operands
|
||||||
)
|
)
|
||||||
|
|
@ -198,7 +198,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||||
flat_true_outs, _ = pytree.tree_flatten(true_outs)
|
flat_true_outs, _ = pytree.tree_flatten(true_outs)
|
||||||
flat_false_outs, _ = pytree.tree_flatten(false_outs)
|
flat_false_outs, _ = pytree.tree_flatten(false_outs)
|
||||||
if len(flat_true_outs) != len(flat_false_outs):
|
if len(flat_true_outs) != len(flat_false_outs):
|
||||||
raise CondOpArgsMismatchError(
|
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||||
f"Expected to return same number of outputs but got:"
|
f"Expected to return same number of outputs but got:"
|
||||||
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
|
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
|
||||||
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
|
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
|
||||||
|
|
@ -208,7 +208,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||||
true_out = flat_true_outs[i]
|
true_out = flat_true_outs[i]
|
||||||
false_out = flat_false_outs[i]
|
false_out = flat_false_outs[i]
|
||||||
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
|
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
|
||||||
raise CondOpArgsMismatchError(
|
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||||
f"Expected each tensor to have same metadata but got:"
|
f"Expected each tensor to have same metadata but got:"
|
||||||
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
|
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
|
||||||
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
|
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
|
||||||
|
|
@ -291,7 +291,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
|
||||||
true_meta = _extract_tensor_metadata(true_out)
|
true_meta = _extract_tensor_metadata(true_out)
|
||||||
false_meta = _extract_tensor_metadata(false_out)
|
false_meta = _extract_tensor_metadata(false_out)
|
||||||
if true_meta != false_meta:
|
if true_meta != false_meta:
|
||||||
raise CondOpArgsMismatchError(
|
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||||
f"Expected each tensor to have same metadata but got:"
|
f"Expected each tensor to have same metadata but got:"
|
||||||
f"\n {true_fn.__name__} returns {true_meta}"
|
f"\n {true_fn.__name__} returns {true_meta}"
|
||||||
f"\n {false_fn.__name__} returns {false_meta}"
|
f"\n {false_fn.__name__} returns {false_meta}"
|
||||||
|
|
|
||||||
|
|
@ -297,6 +297,7 @@ def get_ignored_functions() -> Set[Callable]:
|
||||||
torch.set_vital,
|
torch.set_vital,
|
||||||
torch.read_vitals,
|
torch.read_vitals,
|
||||||
torch.vmap,
|
torch.vmap,
|
||||||
|
torch.cond,
|
||||||
torch.frombuffer,
|
torch.frombuffer,
|
||||||
torch.asarray,
|
torch.asarray,
|
||||||
torch._functional_sym_constrain_range,
|
torch._functional_sym_constrain_range,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user