[HigherOrderOp] expose torch.cond (#110293)

This pr expose torch._higher_order_ops.cond as torch.cond.

1. Need to add #noqa: F811 to the _check calls in torch/__init__.py to address some confusing linter error "Redefinition of unused 'cond'" but only one cond is imported and for these lines that have this error, they don't define the cond but just use it as an argument.
2. Also add cond to the list that allows it to be traced through so as dynamo could trigger the CondHigherOrder logic instead of creating a TorchVariable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110293
Approved by: https://github.com/zou3519
This commit is contained in:
ydwu4 2023-10-06 14:40:21 -07:00 committed by PyTorch MergeBot
parent 0a5bb1c2eb
commit d84bcb9c8c
9 changed files with 36 additions and 22 deletions

View File

@ -1,4 +1,4 @@
.. _control_flow_cond:
.. _cond:
Control Flow - Cond
====================

View File

@ -501,7 +501,7 @@ Graph breaks can also be encountered on data-dependent control flow (``if
x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot
possibly deal with without generating code for a combinatorially exploding
number of paths. In such cases, users will need to rewrite their code using
special control flow operators. Currently, we support :ref:`torch.cond <control_flow_cond>`
special control flow operators. Currently, we support :ref:`torch.cond <cond>`
to express if-else like control flow (more coming soon!).
Data-Dependent Accesses
@ -540,7 +540,7 @@ Read More
torch.compiler_transformations
torch.compiler_ir
generated/exportdb/index
control_flow_cond
cond
.. toctree::
:caption: Deep Dive for PyTorch Developers

View File

@ -718,6 +718,18 @@ Export Path
export
generated/exportdb/index
Control Flow
------------
.. warning::
This feature is a prototype and may have compatibility breaking changes in the future.
.. autosummary::
:toctree: generated
:nosignatures:
cond
Optimizations
-------------
.. autosummary::

View File

@ -1,6 +1,4 @@
from torch._higher_order_ops.cond import ( # noqa: F401
cond,
UnsupportedAliasMutationException,
)
from torch import cond # noqa: F401
from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
from ._map import map # noqa: F401

View File

@ -56,7 +56,7 @@ __all__ = [
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
'SymBool', 'sym_not',
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
'export', 'autocast',
'export', 'autocast', 'cond',
]
################################################################################
@ -986,7 +986,7 @@ def is_warn_always_enabled() -> builtins.bool:
# These error checking functions must be kept consistent with their C++
# equivalents. Their C++ equivalents are mentioned where applicable.
def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]): # noqa: F811
if not isinstance(cond, (builtins.bool, torch.SymBool)):
raise TypeError(f'cond must be a bool, but got {type(cond)}')
@ -1010,7 +1010,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab
raise error_type(message_evaluated)
def _check(cond, message=None):
def _check(cond, message=None): # noqa: F811
r"""Throws error containing an optional message if the specified condition
is False.
@ -1041,7 +1041,7 @@ def _check_is_size(i, message=None):
_check(i >= 0, message)
torch.fx.experimental.symbolic_shapes._advise_is_size(i)
def _check_index(cond, message=None):
def _check_index(cond, message=None): # noqa: F811
r"""Throws error containing an optional message if the specified condition
is False.
@ -1058,7 +1058,7 @@ def _check_index(cond, message=None):
"""
_check_with(IndexError, cond, message)
def _check_value(cond, message=None):
def _check_value(cond, message=None): # noqa: F811
r"""Throws error containing an optional message if the specified condition
is False.
@ -1075,7 +1075,7 @@ def _check_value(cond, message=None):
"""
_check_with(ValueError, cond, message)
def _check_type(cond, message=None):
def _check_type(cond, message=None): # noqa: F811
r"""Throws error containing an optional message if the specified condition
is False.
@ -1092,7 +1092,7 @@ def _check_type(cond, message=None):
"""
_check_with(TypeError, cond, message)
def _check_not_implemented(cond, message=None):
def _check_not_implemented(cond, message=None): # noqa: F811
r"""Throws error containing an optional message if the specified condition
is False.
@ -1109,7 +1109,7 @@ def _check_not_implemented(cond, message=None):
"""
_check_with(NotImplementedError, cond, message)
def _check_tensor_all_with(error_type, cond, message=None):
def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
if not torch.is_tensor(cond):
raise TypeError(f'cond must be a tensor, but got {type(cond)}')
@ -1120,7 +1120,7 @@ def _check_tensor_all_with(error_type, cond, message=None):
_check_with(error_type, cond._is_all_true().item(), message)
# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
def _check_tensor_all(cond, message=None):
def _check_tensor_all(cond, message=None): # noqa: F811
r"""Throws error containing an optional message if the specified condition
is False.
@ -1761,6 +1761,7 @@ def compile(model: Optional[Callable] = None, *,
from torch import export as export
from torch._higher_order_ops import cond
def _register_device_module(device_type, module):
r"""Register an external runtime module of the specific :attr:`device_type`

View File

@ -215,6 +215,7 @@ def _allowed_function_ids():
torch.func.vmap,
deprecated_func.vmap,
torch.nn.functional.triplet_margin_with_distance_loss,
torch.cond,
):
continue

View File

@ -0,0 +1 @@
from .cond import cond

View File

@ -8,8 +8,7 @@ import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dynamo.exc import CondOpArgsMismatchError
from torch._dynamo.utils import disable_cache_limit
from torch._functorch.utils import exposed_in
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
@ -42,6 +41,7 @@ class UnsupportedAliasMutationException(RuntimeError):
reason: str
@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
r"""
Conditionally applies `true_fn` or `false_fn`.
@ -142,7 +142,7 @@ def cond(pred, true_fn, false_fn, operands):
raise RuntimeError("torch.cond requires dynamo support.")
with _set_compilation_env():
with disable_cache_limit():
with torch._dynamo.utils.disable_cache_limit():
return torch.compile(cond_op, backend="eager", fullgraph=True)(
pred, true_fn, false_fn, operands
)
@ -198,7 +198,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
flat_true_outs, _ = pytree.tree_flatten(true_outs)
flat_false_outs, _ = pytree.tree_flatten(false_outs)
if len(flat_true_outs) != len(flat_false_outs):
raise CondOpArgsMismatchError(
raise torch._dynamo.exc.CondOpArgsMismatchError(
f"Expected to return same number of outputs but got:"
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
@ -208,7 +208,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
true_out = flat_true_outs[i]
false_out = flat_false_outs[i]
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
raise CondOpArgsMismatchError(
raise torch._dynamo.exc.CondOpArgsMismatchError(
f"Expected each tensor to have same metadata but got:"
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
@ -291,7 +291,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
true_meta = _extract_tensor_metadata(true_out)
false_meta = _extract_tensor_metadata(false_out)
if true_meta != false_meta:
raise CondOpArgsMismatchError(
raise torch._dynamo.exc.CondOpArgsMismatchError(
f"Expected each tensor to have same metadata but got:"
f"\n {true_fn.__name__} returns {true_meta}"
f"\n {false_fn.__name__} returns {false_meta}"

View File

@ -297,6 +297,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.set_vital,
torch.read_vitals,
torch.vmap,
torch.cond,
torch.frombuffer,
torch.asarray,
torch._functional_sym_constrain_range,