[HigherOrderOp] expose torch.cond (#110293)

This PR exposes torch._higher_order_ops.cond as torch.cond.

1. We need to add # noqa: F811 to the _check* definitions in torch/__init__.py to silence a confusing linter error, "Redefinition of unused 'cond'": only one cond is imported, and the flagged lines do not define cond, they merely take it as a parameter name.
2. We also add cond to the list of functions that dynamo traces through, so dynamo triggers the CondHigherOrder logic instead of creating a TorchVariable. A usage sketch follows below.
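
As a usage sketch of the newly exposed API (illustrative only; the branch functions and shapes below are not from this PR), the eager path checks for dynamo support and internally routes through torch.compile with the eager backend, as shown in the cond.py hunk further down:

```python
import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)
# pred may be a boolean scalar tensor; both branches must take the same operands
# and return tensors with matching metadata. Eager torch.cond requires dynamo
# support, since it compiles cond_op under the hood.
out = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
```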

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110293
Approved by: https://github.com/zou3519
ydwu4 2023-10-06 14:40:21 -07:00 committed by PyTorch MergeBot
parent 0a5bb1c2eb
commit d84bcb9c8c
9 changed files with 36 additions and 22 deletions

View File

@@ -1,4 +1,4 @@
-.. _control_flow_cond:
+.. _cond:

 Control Flow - Cond
 ====================

View File

@@ -501,7 +501,7 @@ Graph breaks can also be encountered on data-dependent control flow (``if
 x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot
 possibly deal with without generating code for a combinatorially exploding
 number of paths. In such cases, users will need to rewrite their code using
-special control flow operators. Currently, we support :ref:`torch.cond <control_flow_cond>`
+special control flow operators. Currently, we support :ref:`torch.cond <cond>`
 to express if-else like control flow (more coming soon!).

 Data-Dependent Accesses
@@ -540,7 +540,7 @@ Read More
    torch.compiler_transformations
    torch.compiler_ir
    generated/exportdb/index
-   control_flow_cond
+   cond

 .. toctree::
    :caption: Deep Dive for PyTorch Developers

View File

@@ -718,6 +718,18 @@ Export Path
     export
     generated/exportdb/index

+Control Flow
+------------
+
+.. warning::
+    This feature is a prototype and may have compatibility breaking changes in the future.
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    cond
+
 Optimizations
 -------------

 .. autosummary::

View File

@@ -1,6 +1,4 @@
-from torch._higher_order_ops.cond import (  # noqa: F401
-    cond,
-    UnsupportedAliasMutationException,
-)
+from torch import cond  # noqa: F401
+from torch._higher_order_ops.cond import UnsupportedAliasMutationException  # noqa: F401

 from ._map import map  # noqa: F401

View File

@@ -56,7 +56,7 @@ __all__ = [
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
     'SymBool', 'sym_not',
     'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
-    'export', 'autocast',
+    'export', 'autocast', 'cond',
 ]

 ################################################################################
@@ -986,7 +986,7 @@ def is_warn_always_enabled() -> builtins.bool:

 # These error checking functions must be kept consistent with their C++
 # equivalents. Their C++ equivalents are mentioned where applicable.
-def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
+def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):  # noqa: F811
     if not isinstance(cond, (builtins.bool, torch.SymBool)):
         raise TypeError(f'cond must be a bool, but got {type(cond)}')
@@ -1010,7 +1010,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab
         raise error_type(message_evaluated)

-def _check(cond, message=None):
+def _check(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
@@ -1041,7 +1041,7 @@ def _check_is_size(i, message=None):
     _check(i >= 0, message)
     torch.fx.experimental.symbolic_shapes._advise_is_size(i)

-def _check_index(cond, message=None):
+def _check_index(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
@@ -1058,7 +1058,7 @@ def _check_index(cond, message=None):
     """
     _check_with(IndexError, cond, message)

-def _check_value(cond, message=None):
+def _check_value(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
@@ -1075,7 +1075,7 @@ def _check_value(cond, message=None):
     """
     _check_with(ValueError, cond, message)

-def _check_type(cond, message=None):
+def _check_type(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
@@ -1092,7 +1092,7 @@ def _check_type(cond, message=None):
     """
     _check_with(TypeError, cond, message)

-def _check_not_implemented(cond, message=None):
+def _check_not_implemented(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
@@ -1109,7 +1109,7 @@ def _check_not_implemented(cond, message=None):
     """
     _check_with(NotImplementedError, cond, message)

-def _check_tensor_all_with(error_type, cond, message=None):
+def _check_tensor_all_with(error_type, cond, message=None):  # noqa: F811
     if not torch.is_tensor(cond):
         raise TypeError(f'cond must be a tensor, but got {type(cond)}')
@@ -1120,7 +1120,7 @@ def _check_tensor_all_with(error_type, cond, message=None):
     _check_with(error_type, cond._is_all_true().item(), message)

 # C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
-def _check_tensor_all(cond, message=None):
+def _check_tensor_all(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
@@ -1761,6 +1761,7 @@ def compile(model: Optional[Callable] = None, *,

 from torch import export as export
+from torch._higher_order_ops import cond

 def _register_device_module(device_type, module):
     r"""Register an external runtime module of the specific :attr:`device_type`

View File

@@ -215,6 +215,7 @@ def _allowed_function_ids():
                 torch.func.vmap,
                 deprecated_func.vmap,
                 torch.nn.functional.triplet_margin_with_distance_loss,
+                torch.cond,
             ):
                 continue
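
Because torch.cond is now excluded from _allowed_function_ids, dynamo traces into the call and applies its CondHigherOrder handling instead of wrapping the function as a TorchVariable. A hedged sketch of the compiled path (function names and shapes are illustrative, not from this PR):

```python
import torch

def true_fn(x):
    return x + 1

def false_fn(x):
    return x - 1

@torch.compile(backend="eager", fullgraph=True)
def branch(pred, x):
    # Captured as a single conditional higher-order op, so the data-dependent
    # branch does not force a graph break even with fullgraph=True.
    return torch.cond(pred, true_fn, false_fn, (x,))

branch(torch.tensor(True), torch.randn(3))
```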

View File

@@ -0,0 +1 @@
+from .cond import cond
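
Together with the torch/__init__.py hunk above, this one-line package __init__ completes a short re-export chain; a quick illustrative check (not part of the PR):

```python
import torch
from torch._higher_order_ops.cond import cond as cond_impl

# torch/__init__.py imports cond from torch._higher_order_ops, whose __init__
# re-exports it from .cond, so the public name is the very same function object.
assert torch.cond is cond_impl
```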

View File

@@ -8,8 +8,7 @@ import torch.fx.traceback as fx_traceback
 import torch.utils._pytree as pytree

 from torch._C import DispatchKey
-from torch._dynamo.exc import CondOpArgsMismatchError
-from torch._dynamo.utils import disable_cache_limit
+from torch._functorch.utils import exposed_in
 from torch._higher_order_ops.utils import autograd_not_implemented
 from torch._ops import HigherOrderOperator
@@ -42,6 +41,7 @@ class UnsupportedAliasMutationException(RuntimeError):
     reason: str

+@exposed_in("torch")
 def cond(pred, true_fn, false_fn, operands):
     r"""
     Conditionally applies `true_fn` or `false_fn`.
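
The exposed_in decorator from torch._functorch.utils is, as far as I can tell (treat the body below as an assumption, not a quotation of the real source), a tiny helper that rewrites __module__ so that cond is attributed to the top-level torch namespace in docs and repr:

```python
# Approximate sketch of exposed_in (assumption; see torch/_functorch/utils.py).
def exposed_in(module):
    def wrapper(fn):
        fn.__module__ = module  # e.g. "torch", so cond is documented as torch.cond
        return fn
    return wrapper
```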
@@ -142,7 +142,7 @@ def cond(pred, true_fn, false_fn, operands):
         raise RuntimeError("torch.cond requires dynamo support.")

     with _set_compilation_env():
-        with disable_cache_limit():
+        with torch._dynamo.utils.disable_cache_limit():
             return torch.compile(cond_op, backend="eager", fullgraph=True)(
                 pred, true_fn, false_fn, operands
             )
@@ -198,7 +198,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
     flat_true_outs, _ = pytree.tree_flatten(true_outs)
     flat_false_outs, _ = pytree.tree_flatten(false_outs)
     if len(flat_true_outs) != len(flat_false_outs):
-        raise CondOpArgsMismatchError(
+        raise torch._dynamo.exc.CondOpArgsMismatchError(
             f"Expected to return same number of outputs but got:"
             f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
             f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
@@ -208,7 +208,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
         true_out = flat_true_outs[i]
         false_out = flat_false_outs[i]
         if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
-            raise CondOpArgsMismatchError(
+            raise torch._dynamo.exc.CondOpArgsMismatchError(
                 f"Expected each tensor to have same metadata but got:"
                 f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
                 f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
@@ -291,7 +291,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
         true_meta = _extract_tensor_metadata(true_out)
         false_meta = _extract_tensor_metadata(false_out)
         if true_meta != false_meta:
-            raise CondOpArgsMismatchError(
+            raise torch._dynamo.exc.CondOpArgsMismatchError(
                 f"Expected each tensor to have same metadata but got:"
                 f"\n {true_fn.__name__} returns {true_meta}"
                 f"\n {false_fn.__name__} returns {false_meta}"

View File

@@ -297,6 +297,7 @@ def get_ignored_functions() -> Set[Callable]:
         torch.set_vital,
         torch.read_vitals,
         torch.vmap,
+        torch.cond,
         torch.frombuffer,
         torch.asarray,
         torch._functional_sym_constrain_range,
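
Adding torch.cond to get_ignored_functions keeps it out of the __torch_function__ override machinery; an illustrative check (not part of the PR):

```python
import torch
from torch.overrides import get_ignored_functions

# torch.cond is now listed among the callables that __torch_function__ skips.
assert torch.cond in get_ignored_functions()
```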