diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 224b7c755c9..1886e65fc1e 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -23,6 +23,14 @@
 #include
 #include
 #else
+#include <ATen/ops/_cast_Byte_native.h>
+#include <ATen/ops/_cast_Char_native.h>
+#include <ATen/ops/_cast_Double_native.h>
+#include <ATen/ops/_cast_Float_native.h>
+#include <ATen/ops/_cast_Half_native.h>
+#include <ATen/ops/_cast_Int_native.h>
+#include <ATen/ops/_cast_Long_native.h>
+#include <ATen/ops/_cast_Short_native.h>
 #include
 #include
 #include
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f20b87b3604..c9709db3290 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5,6 +5,38 @@
 # representing ScalarType's. They are now superseded by usage of
 # `aten::to()`. The ops remain here for backward compatibility purposes.
 
+# DEPRECATED. DO NOT USE
+- func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
 # Computes the gradient of current tensor w.r.t. graph leaves.
 - func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
   manual_cpp_binding: True
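Note: as the comment in native_functions.yaml says, these ops are kept only for backward compatibility and are superseded by `aten::to()`. A minimal eager-mode sketch of the intended equivalence (assumes a build with the restored ops; uses only public `torch` APIs):

```python
import torch

a = torch.rand(2, 2) * 100.0

# Each _cast_* op should behave like Tensor.to with the matching dtype.
assert torch.equal(torch._cast_Byte(a), a.to(torch.uint8))
assert torch.equal(torch._cast_Long(a), a.to(torch.int64))

# New code should call Tensor.to directly instead of the deprecated ops.
b = a.to(torch.int64, non_blocking=False)
```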
diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json
index 15a5d4e80d5..d01d41d3799 100644
--- a/test/allowlist_for_publicAPI.json
+++ b/test/allowlist_for_publicAPI.json
@@ -1111,6 +1111,14 @@
     "_amp_update_scale_",
     "_assert_async",
     "_batch_norm_impl_index",
+    "_cast_Byte",
+    "_cast_Char",
+    "_cast_Double",
+    "_cast_Float",
+    "_cast_Half",
+    "_cast_Int",
+    "_cast_Long",
+    "_cast_Short",
     "_choose_qparams_per_tensor",
     "_coalesce",
     "_compute_linear_combination",
diff --git a/test/cpp/lazy/test_lazy_ops.cpp b/test/cpp/lazy/test_lazy_ops.cpp
index 215e4018d4c..63cc28b89dc 100644
--- a/test/cpp/lazy/test_lazy_ops.cpp
+++ b/test/cpp/lazy/test_lazy_ops.cpp
@@ -135,6 +135,84 @@ TEST_F(LazyOpsTest, TestIsSigned) {
   });
 }
 
+TEST_F(LazyOpsTest, TestCastByte) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Byte(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Byte(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastChar) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Char(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Char(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastShort) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Short(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Short(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastInt) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Int(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Int(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastLong) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Long(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Long(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastFloat) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Float(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Float(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
 TEST_F(LazyOpsTest, TestRetainType) {
   torch::Tensor lazy_a = torch::zeros(
       {2, 2}, torch::TensorOptions(torch::kByte).device(torch::kLazy));
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index 6eac2fb8f99..5a962dfa57c 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -141,15 +141,6 @@ ALLOW_LIST = [
     ("c10d::.*", datetime.date(9999, 1, 1)),
     # Previously MPS_only did not support backward
     ("aten::_fused_rms_norm", datetime.date(2025, 12, 30)),
-    # These casting ops were deprecated in PyTorch 1
-    ("aten::_cast_Half", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Short", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Long", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Int", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Float", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Double", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Char", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Byte", datetime.date(9999, 1, 1), None, True),
 ]
 
 ALLOW_LIST_COMPILED = [
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index fc0f7238a5c..6de46ead64e 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -1414,6 +1414,14 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch._C.unify_type_list",
         "torch._C.vitals_enabled",
         "torch._C.wait",
+        "torch._cast_Byte",
+        "torch._cast_Char",
+        "torch._cast_Double",
+        "torch._cast_Float",
+        "torch._cast_Half",
+        "torch._cast_Int",
+        "torch._cast_Long",
+        "torch._cast_Short",
         "torch._choose_qparams_per_tensor",
         "torch._chunk_cat",
         "torch._coalesce",
diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp
index ec60edc60a8..a5cd6f4e3a4 100644
--- a/torch/csrc/jit/codegen/fuser/codegen.cpp
+++ b/torch/csrc/jit/codegen/fuser/codegen.cpp
@@ -181,6 +181,7 @@ struct RHSTemplate {
 static std::string encodeRHS(const Node* n) {
   static std::unordered_map<NodeKind, RHSTemplate> simple_map_ops = {
       // unary
+      {aten::_cast_Float, "static_cast<float>(${0})"},
       {aten::abs, "fabs(${0})"},
       {aten::sigmoid, {"1.f / (1.f + expf(-${0}))", "1. / (1. + exp(-${0}))"}},
       {aten::relu, "${0} < 0 ? 0.f : ${0} "},
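Note: listing the casts in `torch_c_binding_in_graph_functions` (trace_rules.py above) lets Dynamo trace them into the captured FX graph instead of graph-breaking on them. A minimal sketch, assuming a build with the restored ops (`fullgraph=True` would raise if a graph break occurred):

```python
import torch

@torch.compile(fullgraph=True)
def f(x):
    # Traced into the graph as a call to torch._cast_Float.
    return torch._cast_Float(x, False) + 1.0

print(f(torch.ones(2, 2, dtype=torch.int64)))
```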
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 46ed838fbe4..8dfa836f87b 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -34,6 +34,8 @@ namespace {
 // carefully read the code first, as we rely on these assumptions.
 bool isSimpleMap(Node* node) {
   static OperatorSet simple_mappable{{
+      "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
+
       "aten::abs(Tensor self) -> Tensor",
       "aten::acos(Tensor self) -> Tensor",
       "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp
index 682022ac83e..7b0fed5dc15 100644
--- a/torch/csrc/jit/passes/guard_elimination.cpp
+++ b/torch/csrc/jit/passes/guard_elimination.cpp
@@ -340,6 +340,8 @@ struct GuardElimination {
     case aten::reciprocal:
     case aten::addcmul:
     case aten::where:
+    case aten::_cast_Float:
+    case aten::_cast_Long:
     case aten::__and__:
     case aten::__or__:
     case aten::__xor__:
diff --git a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
index 6d68a9ff654..f7c0e0a7cac 100644
--- a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
+++ b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
@@ -23,9 +23,9 @@ static void PrepareDivisionForONNXOnBlock(Block* block) {
           subgraph->insertNode(subgraph->createNumToTensor(input))
               ->output();
       longtensor->node()->copyMetadata(input->node());
-      auto* cast = subgraph->create(at::onnx::Cast, 1);
-      cast->addInput(longtensor);
-      cast->i_(attr::to, 1);
+      auto* nonblocking = subgraph->insertConstant(0);
+      auto* cast =
+          subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
       cast->copyMetadata(*it);
       return subgraph->insertNode(cast)->output();
     });
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 58ad1c4214d..18068f2f78c 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -1517,6 +1517,50 @@ class ShapePropagator : public PropertyPropBase {
           return {};
         }};
 
+    static const auto get_cast_scalar_type = [](Node* node) -> at::ScalarType {
+      switch (node->kind()) {
+        case aten::_cast_Byte:
+          return at::kByte;
+        case aten::_cast_Char:
+          return at::kChar;
+        case aten::_cast_Double:
+          return at::kDouble;
+        case aten::_cast_Float:
+          return at::kFloat;
+        case aten::_cast_Half:
+          return at::kHalf;
+        case aten::_cast_Int:
+          return at::kInt;
+        case aten::_cast_Long:
+          return at::kLong;
+        case aten::_cast_Short:
+          return at::kShort;
+        default:
+          TORCH_INTERNAL_ASSERT(
+              false,
+              "unknown node kind in get_cast_scalar_type: ",
+              node->kind().toQualString());
+      }
+    };
+    static const register_formula_for cast_ops{
+        {
+            "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
+        },
+        [](Node* node) -> type_vec_t {
+          if (auto type =
+                  node->namedInput(attr::self)->type()->cast<TensorType>()) {
+            return {type->withScalarType(get_cast_scalar_type(node))};
+          }
+          return {};
+        }};
+
     // First, try to match one of the registered formulas to their operator
     // sets.
     for (auto& entry : shape_formulas) {
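Note: the shape_analysis.cpp formula above encodes that a cast preserves sizes and only rewrites the scalar type. A quick eager-mode check of that contract (a sketch; the formula itself only runs during JIT shape propagation):

```python
import torch

@torch.jit.script
def to_long(x: torch.Tensor) -> torch.Tensor:
    return torch._cast_Long(x, False)

out = to_long(torch.rand(2, 3))
assert out.shape == (2, 3)       # sizes preserved
assert out.dtype == torch.int64  # scalar type rewritten to Long
```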
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp
index 1e8b55f1513..ac0cd61fd2f 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp
@@ -11,6 +11,7 @@ const OperatorMap<std::string>& get_tensorexpr_elementwise_set() {
   // clang-format off
   static const OperatorMap<std::string> tensorexpr_elementwise_set{
       {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
+      {"aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor", "unary"},
       {"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
       {"aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
       {"aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
diff --git a/torch/csrc/jit/tensorexpr/lowerings.cpp b/torch/csrc/jit/tensorexpr/lowerings.cpp
index 713af124e6e..76bfe3e9ce4 100644
--- a/torch/csrc/jit/tensorexpr/lowerings.cpp
+++ b/torch/csrc/jit/tensorexpr/lowerings.cpp
@@ -1554,6 +1554,22 @@ int nnc_lowerings_lazy_registration() {
             [](const ExprHandle& a) { return trunc(a); });
       });
 
+  RegisterNNCLoweringsFunction aten__cast_Float(
+      {"aten::_cast_Float(Tensor self, bool non_blocking=False) -> (Tensor)"},
+      [](const std::vector<ArgValue>& inputs,
+         const std::vector<ExprHandle>& outputShape,
+         const std::vector<ExprHandle>& outputStrides,
+         const std::optional<ScalarType>& outputType,
+         at::Device device) {
+        return computeOneOperand(
+            "aten_cast_float",
+            inputs,
+            outputShape,
+            outputStrides,
+            outputType,
+            [](const ExprHandle& a) { return cast<float>(a); });
+      });
+
   RegisterNNCLoweringsFunction aten_to(
       {"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
        "aten::to.dtype_layout(Tensor(a) self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset8.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset8.py
index 1c1b8ce04b9..bde07260808 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset8.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset8.py
@@ -191,7 +191,7 @@ def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
 def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
     if to_type is None:
         return input
-    return g.op("Cast", input, to_i=symbolic_helper.cast_pytorch_to_onnx[to_type])
+    return getattr(opset9, f"_cast_{to_type}")(g, input, False)
 
 
 def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py
index 177c98d8dde..9b7aba64ef3 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py
@@ -15,6 +15,7 @@ import math
 import sys
 import warnings
 from typing import TYPE_CHECKING
+from typing_extensions import deprecated
 
 import torch
 import torch._C._onnx as _C_onnx
@@ -1965,8 +1966,8 @@ def wrap_logical_op_with_cast_to(to_type):
     def decorator(fn):
         @functools.wraps(fn)
         def wrap_with_cast(g, input, other):
-            to_i = symbolic_helper.cast_pytorch_to_onnx[to_type]
-            return fn(g, g.op("Cast", input, to_i=to_i), g.op("Cast", other, to_i=to_i))
+            to_cast_func = globals()[f"_cast_{to_type}"]
+            return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False))
 
         return wrap_with_cast
 
@@ -3330,6 +3331,60 @@ def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_co
     symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
 
 
+@_onnx_symbolic("aten::_cast_Byte")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8)
+
+
+@_onnx_symbolic("aten::_cast_Char")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8)
+
+
+@_onnx_symbolic("aten::_cast_Short")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16)
+
+
+@_onnx_symbolic("aten::_cast_Int")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)
+
+
+@_onnx_symbolic("aten::_cast_Long")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
+
+
+@_onnx_symbolic("aten::_cast_Half")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+
+
+@_onnx_symbolic("aten::_cast_Float")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+
+
+@_onnx_symbolic("aten::_cast_Double")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
+
+
+@_onnx_symbolic("aten::_cast_Bool")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
+
+
 @_onnx_symbolic("aten::empty")
 @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
 def empty(
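Note: with the symbolics above restored, the TorchScript ONNX exporter lowers each `aten::_cast_*` node to a single ONNX `Cast`. A minimal export sketch (assumes a build with the restored ops and where `torch.onnx.export` accepts `dynamo=False` to select the TorchScript exporter):

```python
import io
import torch

class CastModule(torch.nn.Module):
    def forward(self, x):
        # Deprecated op; exported as Cast(to=FLOAT) by _cast_Float above.
        return torch._cast_Float(x, False)

buf = io.BytesIO()
torch.onnx.export(
    CastModule(),
    (torch.ones(2, 2, dtype=torch.int64),),
    buf,
    opset_version=9,
    dynamo=False,  # assumption: route through the TorchScript exporter
)
```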