From 4f1a56cd425fd0959f5a0b51b07550e9f0449336 Mon Sep 17 00:00:00 2001 From: Tobias Ringwald Date: Thu, 16 May 2024 20:58:24 +0000 Subject: [PATCH] Switched from parameter in can_cast to from_. (#126030) Fixes #126012. `from` is a reserved keyword in Python, thus we can't make the C++ impl available with `from` as function parameter. This PR changes the name to `from_` and also adjusts the docs. If we want to preserve backwards compatibility, we can leave the C++ name as-is and only fix the docs. However, `torch.can_cast(from_=torch.int, to=torch.int)` won't work then. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126030 Approved by: https://github.com/albanD --- aten/src/ATen/native/TypeProperties.cpp | 4 ++-- aten/src/ATen/native/native_functions.yaml | 2 +- .../check_forward_backward_compatibility.py | 2 ++ torch/_torch_docs.py | 4 ++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index 7091e4f78ae..4afc7619c2e 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -191,8 +191,8 @@ ScalarType result_type(const Scalar& scalar1, const Scalar& scalar2) { return result_type(state); } -bool can_cast(const at::ScalarType from, const at::ScalarType to) { - return at::canCast(from, to); +bool can_cast(const at::ScalarType from_, const at::ScalarType to) { + return at::canCast(from_, to); } ScalarType promote_types(ScalarType type1, ScalarType type2) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8cf229c69c2..10d8b1ad79c 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7714,7 +7714,7 @@ - func: result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType -- func: can_cast(ScalarType from, ScalarType to) -> bool +- func: can_cast(ScalarType from_, ScalarType to) -> bool variants: function - func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 285e410a79e..81b85a4fe42 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -140,6 +140,8 @@ ALLOW_LIST = [ ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv2d_pointwise.binary", datetime.date(2024, 12, 31)), + # BC-breaking change in can_cast signature: 'from' -> 'from_' + ("aten::can_cast", datetime.date(2024, 5, 31)), ] ALLOW_LIST_COMPILED = [ diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index ba8e899dc94..6d22f9dcf98 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2195,13 +2195,13 @@ Example:: add_docstr( torch.can_cast, r""" -can_cast(from, to) -> bool +can_cast(from_, to) -> bool Determines if a type conversion is allowed under PyTorch casting rules described in the type promotion :ref:`documentation `. Args: - from (dtype): The original :class:`torch.dtype`. + from\_ (dtype): The original :class:`torch.dtype`. to (dtype): The target :class:`torch.dtype`. Example::