Revert "[BC-Breaking] Remove long-deprecated casting functions from native_functions.yaml (#164641)"

This reverts commit 64108bdbed.

Reverted https://github.com/pytorch/pytorch/pull/164641 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164641#issuecomment-3386346474))

PyTorch MergeBot committed 2025-10-09 15:42:49 +00:00
parent a7fa1a91e3
commit 3d1fa40ae1
15 changed files with 261 additions and 15 deletions

View File

@@ -23,6 +23,14 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_cast_Byte_native.h>
+#include <ATen/ops/_cast_Char_native.h>
+#include <ATen/ops/_cast_Double_native.h>
+#include <ATen/ops/_cast_Float_native.h>
+#include <ATen/ops/_cast_Half_native.h>
+#include <ATen/ops/_cast_Int_native.h>
+#include <ATen/ops/_cast_Long_native.h>
+#include <ATen/ops/_cast_Short_native.h>
 #include <ATen/ops/_dim_arange_native.h>
 #include <ATen/ops/_efficientzerotensor_native.h>
 #include <ATen/ops/_empty_affine_quantized.h>

View File

@@ -5,6 +5,38 @@
 # representing ScalarType's. They are now superseded by usage of
 # `aten::to()`. The ops remain here for backward compatibility purposes.
 
+# DEPRECATED. DO NOT USE
+- func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
 # Computes the gradient of current tensor w.r.t. graph leaves.
 - func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
   manual_cpp_binding: True
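
Note: as the yaml comment says, these ops are legacy aliases for `aten::to()`. A minimal sketch of the equivalence, assuming the Python bindings this revert restores:

```python
# Hedged sketch: each restored _cast_* op is expected to behave like a plain
# dtype conversion via Tensor.to(); non_blocking mirrors the Tensor.to() flag.
import torch

x = torch.randn(2, 2)
assert torch.equal(torch._cast_Byte(x), x.to(torch.uint8))
assert torch.equal(torch._cast_Float(x.to(torch.long)), x.to(torch.long).float())
```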

View File

@@ -1111,6 +1111,14 @@
     "_amp_update_scale_",
     "_assert_async",
     "_batch_norm_impl_index",
+    "_cast_Byte",
+    "_cast_Char",
+    "_cast_Double",
+    "_cast_Float",
+    "_cast_Half",
+    "_cast_Int",
+    "_cast_Long",
+    "_cast_Short",
     "_choose_qparams_per_tensor",
     "_coalesce",
     "_compute_linear_combination",

View File

@@ -135,6 +135,84 @@ TEST_F(LazyOpsTest, TestIsSigned) {
   });
 }
 
+TEST_F(LazyOpsTest, TestCastByte) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Byte(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Byte(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastChar) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Char(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Char(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastShort) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Short(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Short(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastInt) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Int(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Int(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastLong) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Long(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Long(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastFloat) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Float(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Float(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
 TEST_F(LazyOpsTest, TestRetainType) {
   torch::Tensor lazy_a = torch::zeros(
       {2, 2}, torch::TensorOptions(torch::kByte).device(torch::kLazy));

View File

@@ -141,15 +141,6 @@ ALLOW_LIST = [
     ("c10d::.*", datetime.date(9999, 1, 1)),
     # Previously MPS_only did not support backward
     ("aten::_fused_rms_norm", datetime.date(2025, 12, 30)),
-    # These casting ops were deprecated in PyTorch 1
-    ("aten::_cast_Half", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Short", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Long", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Int", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Float", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Double", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Char", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Byte", datetime.date(9999, 1, 1), None, True),
 ]
 
 ALLOW_LIST_COMPILED = [
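
For context, a rough, self-contained sketch of how such date-gated allowlist entries behave (the real checker matches full schemas and carries extra per-entry fields; this is an illustration, not the actual implementation):

```python
# Hypothetical reduction of the allowlist idea: an entry suppresses
# compatibility failures for matching ops until its expiry date passes.
import datetime
import re

ALLOW = [("aten::_cast_Byte", datetime.date(9999, 1, 1))]

def allowed(op_name, today=datetime.date(2025, 10, 9)):
    return any(re.match(pat, op_name) and today < expiry for pat, expiry in ALLOW)

print(allowed("aten::_cast_Byte"))  # True: failures for this op are ignored
```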

View File

@@ -1414,6 +1414,14 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch._C.unify_type_list",
         "torch._C.vitals_enabled",
         "torch._C.wait",
+        "torch._cast_Byte",
+        "torch._cast_Char",
+        "torch._cast_Double",
+        "torch._cast_Float",
+        "torch._cast_Half",
+        "torch._cast_Int",
+        "torch._cast_Long",
+        "torch._cast_Short",
         "torch._choose_qparams_per_tensor",
         "torch._chunk_cat",
         "torch._coalesce",

View File

@@ -181,6 +181,7 @@ struct RHSTemplate {
 static std::string encodeRHS(const Node* n) {
   static std::unordered_map<NodeKind, RHSTemplate> simple_map_ops = {
       // unary
+      {aten::_cast_Float, "static_cast<float>(${0})"},
       {aten::abs, "fabs(${0})"},
       {aten::sigmoid, {"1.f / (1.f + expf(-${0}))", "1. / (1. + exp(-${0}))"}},
       {aten::relu, "${0} < 0 ? 0.f : ${0} "},
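
For context, `${0}` in these RHS templates stands for the node's first input; a minimal sketch of the substitution (the operand name `t0` is a hypothetical stand-in):

```python
# Hedged illustration of RHS template expansion during fused-kernel codegen:
# "${0}" is replaced by the name of the node's first input expression.
rhs_template = "static_cast<float>(${0})"  # entry for aten::_cast_Float above
print(rhs_template.replace("${0}", "t0"))  # static_cast<float>(t0)
```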

View File

@@ -34,6 +34,8 @@ namespace {
 // carefully read the code first, as we rely on these assumptions.
 bool isSimpleMap(Node* node) {
   static OperatorSet simple_mappable{{
+      "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
       "aten::abs(Tensor self) -> Tensor",
       "aten::acos(Tensor self) -> Tensor",
       "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",

View File

@@ -340,6 +340,8 @@ struct GuardElimination {
       case aten::reciprocal:
       case aten::addcmul:
       case aten::where:
+      case aten::_cast_Float:
+      case aten::_cast_Long:
       case aten::__and__:
       case aten::__or__:
       case aten::__xor__:

View File

@@ -23,9 +23,9 @@ static void PrepareDivisionForONNXOnBlock(Block* block) {
           subgraph->insertNode(subgraph->createNumToTensor(input))
               ->output();
       longtensor->node()->copyMetadata(input->node());
-      auto* cast = subgraph->create(at::onnx::Cast, 1);
-      cast->addInput(longtensor);
-      cast->i_(attr::to, 1);
+      auto* nonblocking = subgraph->insertConstant(0);
+      auto* cast =
+          subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
       cast->copyMetadata(*it);
       return subgraph->insertNode(cast)->output();
     });

View File

@@ -1517,6 +1517,50 @@ class ShapePropagator : public PropertyPropBase {
         return {};
       }};
 
+  static const auto get_cast_scalar_type = [](Node* node) -> at::ScalarType {
+    switch (node->kind()) {
+      case aten::_cast_Byte:
+        return at::kByte;
+      case aten::_cast_Char:
+        return at::kChar;
+      case aten::_cast_Double:
+        return at::kDouble;
+      case aten::_cast_Float:
+        return at::kFloat;
+      case aten::_cast_Half:
+        return at::kHalf;
+      case aten::_cast_Int:
+        return at::kInt;
+      case aten::_cast_Long:
+        return at::kLong;
+      case aten::_cast_Short:
+        return at::kShort;
+      default:
+        TORCH_INTERNAL_ASSERT(
+            false,
+            "unknown node kind in get_cast_scalar_type: ",
+            node->kind().toQualString());
+    }
+  };
+
+  static const register_formula_for cast_ops{
+      {
+          "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
+          "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
+          "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
+          "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
+          "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
+          "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
+          "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
+          "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
+      },
+      [](Node* node) -> type_vec_t {
+        if (auto type =
+                node->namedInput(attr::self)->type()->cast<TensorType>()) {
+          return {type->withScalarType(get_cast_scalar_type(node))};
+        }
+        return {};
+      }};
 
 // First, try to match one of the registered formulas to their operator
 // sets.
 for (auto& entry : shape_formulas) {
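
The restored formula keeps the input's sizes and only swaps in the cast's scalar type; a small sketch of the intended semantics from Python:

```python
# Hedged sketch: _cast_Long maps to at::kLong (torch.int64) while shape
# propagation preserves the input's sizes.
import torch

x = torch.randn(2, 3)    # float32
y = torch._cast_Long(x)  # deprecated alias for x.to(torch.long)
assert y.shape == x.shape and y.dtype == torch.int64
```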

View File

@@ -11,6 +11,7 @@ const OperatorMap<std::string>& get_tensorexpr_elementwise_set() {
   // clang-format off
   static const OperatorMap<std::string> tensorexpr_elementwise_set{
       {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
+      {"aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor", "unary"},
       {"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
       {"aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
       {"aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},

View File

@@ -1554,6 +1554,22 @@ int nnc_lowerings_lazy_registration() {
           [](const ExprHandle& a) { return trunc(a); });
       });
 
+  RegisterNNCLoweringsFunction aten__cast_Float(
+      {"aten::_cast_Float(Tensor self, bool non_blocking=False) -> (Tensor)"},
+      [](const std::vector<ArgValue>& inputs,
+         const std::vector<ExprHandle>& outputShape,
+         const std::vector<ExprHandle>& outputStrides,
+         const std::optional<ScalarType>& outputType,
+         at::Device device) {
+        return computeOneOperand(
+            "aten_cast_float",
+            inputs,
+            outputShape,
+            outputStrides,
+            outputType,
+            [](const ExprHandle& a) { return cast<float>(a); });
+      });
+
   RegisterNNCLoweringsFunction aten_to(
       {"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
        "aten::to.dtype_layout(Tensor(a) self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",

View File

@@ -191,7 +191,7 @@ def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
 def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
     if to_type is None:
         return input
-    return g.op("Cast", input, to_i=symbolic_helper.cast_pytorch_to_onnx[to_type])
+    return getattr(opset9, f"_cast_{to_type}")(g, input, False)
 
 
 def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
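
The restored `_cast_to_type` resolves a helper by name at runtime; a self-contained sketch of that getattr dispatch pattern (the namespace below is a hypothetical stand-in for the opset9 module):

```python
# Hedged illustration of the getattr dispatch used in _cast_to_type above.
from types import SimpleNamespace

opset = SimpleNamespace(
    _cast_Float=lambda g, inp, non_blocking: f"Cast({inp} -> FLOAT)",
    _cast_Long=lambda g, inp, non_blocking: f"Cast({inp} -> INT64)",
)

def cast_to_type(g, inp, to_type):
    if to_type is None:
        return inp
    # Resolve e.g. opset._cast_Float from the type name, then call it.
    return getattr(opset, f"_cast_{to_type}")(g, inp, False)

print(cast_to_type(None, "x", "Float"))  # Cast(x -> FLOAT)
```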

View File

@@ -15,6 +15,7 @@ import math
 import sys
 import warnings
 from typing import TYPE_CHECKING
+from typing_extensions import deprecated
 
 import torch
 import torch._C._onnx as _C_onnx
@@ -1965,8 +1966,8 @@ def wrap_logical_op_with_cast_to(to_type):
     def decorator(fn):
         @functools.wraps(fn)
         def wrap_with_cast(g, input, other):
-            to_i = symbolic_helper.cast_pytorch_to_onnx[to_type]
-            return fn(g, g.op("Cast", input, to_i=to_i), g.op("Cast", other, to_i=to_i))
+            to_cast_func = globals()[f"_cast_{to_type}"]
+            return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False))
 
         return wrap_with_cast
 
@@ -3330,6 +3331,60 @@ def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts):
     symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
 
 
+@_onnx_symbolic("aten::_cast_Byte")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8)
+
+
+@_onnx_symbolic("aten::_cast_Char")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8)
+
+
+@_onnx_symbolic("aten::_cast_Short")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16)
+
+
+@_onnx_symbolic("aten::_cast_Int")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)
+
+
+@_onnx_symbolic("aten::_cast_Long")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
+
+
+@_onnx_symbolic("aten::_cast_Half")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+
+
+@_onnx_symbolic("aten::_cast_Float")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+
+
+@_onnx_symbolic("aten::_cast_Double")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
+
+
+@_onnx_symbolic("aten::_cast_Bool")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
+
+
 @_onnx_symbolic("aten::empty")
 @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
 def empty(