Revert "[BC-Breaking] Remove long-deprecated casting functions from native_functions.yaml (#164641)"
This reverts commit 64108bdbed.
Reverted https://github.com/pytorch/pytorch/pull/164641 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164641#issuecomment-3386346474))
This commit is contained in: parent a7fa1a91e3, commit 3d1fa40ae1
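The ops restored by this revert are thin dtype conversions that native_functions.yaml (see the hunk below) documents as superseded by `aten::to()`. As a minimal illustrative sketch, not part of this commit, the eager-mode equivalence looks like this (input values are arbitrary):

    import torch

    x = torch.randint(0, 100, (2, 2), dtype=torch.int64)

    a = torch._cast_Float(x)   # deprecated op restored by this revert
    b = x.to(torch.float32)    # recommended replacement via aten::to()

    assert a.dtype == b.dtype == torch.float32
    assert torch.equal(a, b)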
@@ -23,6 +23,14 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_cast_Byte_native.h>
+#include <ATen/ops/_cast_Char_native.h>
+#include <ATen/ops/_cast_Double_native.h>
+#include <ATen/ops/_cast_Float_native.h>
+#include <ATen/ops/_cast_Half_native.h>
+#include <ATen/ops/_cast_Int_native.h>
+#include <ATen/ops/_cast_Long_native.h>
+#include <ATen/ops/_cast_Short_native.h>
 #include <ATen/ops/_dim_arange_native.h>
 #include <ATen/ops/_efficientzerotensor_native.h>
 #include <ATen/ops/_empty_affine_quantized.h>

@@ -5,6 +5,38 @@
 # representing ScalarType's. They are now superseded by usage of
 # `aten::to()`. The ops remain here for backward compatibility purposes.

+# DEPRECATED. DO NOT USE
+- func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
 # Computes the gradient of current tensor w.r.t. graph leaves.
 - func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
   manual_cpp_binding: True

@@ -1111,6 +1111,14 @@
     "_amp_update_scale_",
     "_assert_async",
     "_batch_norm_impl_index",
+    "_cast_Byte",
+    "_cast_Char",
+    "_cast_Double",
+    "_cast_Float",
+    "_cast_Half",
+    "_cast_Int",
+    "_cast_Long",
+    "_cast_Short",
     "_choose_qparams_per_tensor",
     "_coalesce",
     "_compute_linear_combination",

@@ -135,6 +135,84 @@ TEST_F(LazyOpsTest, TestIsSigned) {
   });
 }

+TEST_F(LazyOpsTest, TestCastByte) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Byte(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Byte(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastChar) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Char(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Char(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastShort) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Short(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Short(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastInt) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Int(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Int(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastLong) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Long(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Long(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
+TEST_F(LazyOpsTest, TestCastFloat) {
+  torch::Tensor a =
+      torch::rand(
+          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
+      100.0;
+  torch::Tensor b = torch::_cast_Float(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor lazy_a = CopyToDevice(a, device);
+    torch::Tensor lazy_b = torch::_cast_Float(lazy_a);
+    AllEqual(b, lazy_b);
+  });
+}
+
 TEST_F(LazyOpsTest, TestRetainType) {
   torch::Tensor lazy_a = torch::zeros(
       {2, 2}, torch::TensorOptions(torch::kByte).device(torch::kLazy));

@@ -141,15 +141,6 @@ ALLOW_LIST = [
     ("c10d::.*", datetime.date(9999, 1, 1)),
     # Previously MPS_only did not support backward
     ("aten::_fused_rms_norm", datetime.date(2025, 12, 30)),
-    # These casting ops were deprecated in PyTorch 1
-    ("aten::_cast_Half", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Short", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Long", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Int", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Float", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Double", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Char", datetime.date(9999, 1, 1), None, True),
-    ("aten::_cast_Byte", datetime.date(9999, 1, 1), None, True),
 ]

 ALLOW_LIST_COMPILED = [

@@ -1414,6 +1414,14 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch._C.unify_type_list",
         "torch._C.vitals_enabled",
         "torch._C.wait",
+        "torch._cast_Byte",
+        "torch._cast_Char",
+        "torch._cast_Double",
+        "torch._cast_Float",
+        "torch._cast_Half",
+        "torch._cast_Int",
+        "torch._cast_Long",
+        "torch._cast_Short",
         "torch._choose_qparams_per_tensor",
         "torch._chunk_cat",
         "torch._coalesce",

@@ -181,6 +181,7 @@ struct RHSTemplate {
 static std::string encodeRHS(const Node* n) {
   static std::unordered_map<NodeKind, RHSTemplate> simple_map_ops = {
       // unary
+      {aten::_cast_Float, "static_cast<float>(${0})"},
       {aten::abs, "fabs(${0})"},
       {aten::sigmoid, {"1.f / (1.f + expf(-${0}))", "1. / (1. + exp(-${0}))"}},
       {aten::relu, "${0} < 0 ? 0.f : ${0} "},

@@ -34,6 +34,8 @@ namespace {
 // carefully read the code first, as we rely on these assumptions.
 bool isSimpleMap(Node* node) {
   static OperatorSet simple_mappable{{
+      "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
+
       "aten::abs(Tensor self) -> Tensor",
       "aten::acos(Tensor self) -> Tensor",
       "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",

@@ -340,6 +340,8 @@ struct GuardElimination {
       case aten::reciprocal:
       case aten::addcmul:
       case aten::where:
+      case aten::_cast_Float:
+      case aten::_cast_Long:
       case aten::__and__:
       case aten::__or__:
       case aten::__xor__:

@@ -23,9 +23,9 @@ static void PrepareDivisionForONNXOnBlock(Block* block) {
             subgraph->insertNode(subgraph->createNumToTensor(input))
                 ->output();
         longtensor->node()->copyMetadata(input->node());
-        auto* cast = subgraph->create(at::onnx::Cast, 1);
-        cast->addInput(longtensor);
-        cast->i_(attr::to, 1);
+        auto* nonblocking = subgraph->insertConstant(0);
+        auto* cast =
+            subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
         cast->copyMetadata(*it);
         return subgraph->insertNode(cast)->output();
       });

@@ -1517,6 +1517,50 @@ class ShapePropagator : public PropertyPropBase {
           return {};
         }};

+    static const auto get_cast_scalar_type = [](Node* node) -> at::ScalarType {
+      switch (node->kind()) {
+        case aten::_cast_Byte:
+          return at::kByte;
+        case aten::_cast_Char:
+          return at::kChar;
+        case aten::_cast_Double:
+          return at::kDouble;
+        case aten::_cast_Float:
+          return at::kFloat;
+        case aten::_cast_Half:
+          return at::kHalf;
+        case aten::_cast_Int:
+          return at::kInt;
+        case aten::_cast_Long:
+          return at::kLong;
+        case aten::_cast_Short:
+          return at::kShort;
+        default:
+          TORCH_INTERNAL_ASSERT(
+              false,
+              "unknown node kind in get_cast_scalar_type: ",
+              node->kind().toQualString());
+      }
+    };
+    static const register_formula_for cast_ops{
+        {
+            "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
+            "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
+        },
+        [](Node* node) -> type_vec_t {
+          if (auto type =
+                  node->namedInput(attr::self)->type()->cast<TensorType>()) {
+            return {type->withScalarType(get_cast_scalar_type(node))};
+          }
+          return {};
+        }};
+
     // First, try to match one of the registered formulas to their operator
     // sets.
     for (auto& entry : shape_formulas) {

@@ -11,6 +11,7 @@ const OperatorMap<std::string>& get_tensorexpr_elementwise_set() {
   // clang-format off
   static const OperatorMap<std::string> tensorexpr_elementwise_set{
       {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
+      {"aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor", "unary"},
       {"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
       {"aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
       {"aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},

@@ -1554,6 +1554,22 @@ int nnc_lowerings_lazy_registration() {
             [](const ExprHandle& a) { return trunc(a); });
       });

+  RegisterNNCLoweringsFunction aten__cast_Float(
+      {"aten::_cast_Float(Tensor self, bool non_blocking=False) -> (Tensor)"},
+      [](const std::vector<ArgValue>& inputs,
+         const std::vector<ExprHandle>& outputShape,
+         const std::vector<ExprHandle>& outputStrides,
+         const std::optional<ScalarType>& outputType,
+         at::Device device) {
+        return computeOneOperand(
+            "aten_cast_float",
+            inputs,
+            outputShape,
+            outputStrides,
+            outputType,
+            [](const ExprHandle& a) { return cast<float>(a); });
+      });
+
   RegisterNNCLoweringsFunction aten_to(
       {"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
        "aten::to.dtype_layout(Tensor(a) self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",

@@ -191,7 +191,7 @@ def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
 def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
     if to_type is None:
         return input
-    return g.op("Cast", input, to_i=symbolic_helper.cast_pytorch_to_onnx[to_type])
+    return getattr(opset9, f"_cast_{to_type}")(g, input, False)


 def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):

@@ -15,6 +15,7 @@ import math
 import sys
 import warnings
 from typing import TYPE_CHECKING
+from typing_extensions import deprecated

 import torch
 import torch._C._onnx as _C_onnx

@@ -1965,8 +1966,8 @@ def wrap_logical_op_with_cast_to(to_type):
     def decorator(fn):
         @functools.wraps(fn)
         def wrap_with_cast(g, input, other):
-            to_i = symbolic_helper.cast_pytorch_to_onnx[to_type]
-            return fn(g, g.op("Cast", input, to_i=to_i), g.op("Cast", other, to_i=to_i))
+            to_cast_func = globals()[f"_cast_{to_type}"]
+            return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False))

         return wrap_with_cast

@@ -3330,6 +3331,60 @@ def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_co
     symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)


+@_onnx_symbolic("aten::_cast_Byte")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8)
+
+
+@_onnx_symbolic("aten::_cast_Char")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8)
+
+
+@_onnx_symbolic("aten::_cast_Short")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16)
+
+
+@_onnx_symbolic("aten::_cast_Int")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)
+
+
+@_onnx_symbolic("aten::_cast_Long")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
+
+
+@_onnx_symbolic("aten::_cast_Half")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+
+
+@_onnx_symbolic("aten::_cast_Float")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+
+
+@_onnx_symbolic("aten::_cast_Double")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
+
+
+@_onnx_symbolic("aten::_cast_Bool")
+@deprecated("Avoid using this function and create a Cast node instead")
+def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
+    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
+
+
 @_onnx_symbolic("aten::empty")
 @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
 def empty(
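The `@deprecated` messages in the ONNX hunk above advise creating a `Cast` node directly instead of calling these helpers. A minimal sketch of that pattern, mirroring the restored function bodies (the function name below is hypothetical and not part of this commit):

    import torch._C._onnx as _C_onnx

    # Hypothetical symbolic function: emit the ONNX Cast node directly
    # rather than calling opset9._cast_Float(g, x, False).
    def cast_to_float_symbolic(g, x):
        return g.op("Cast", x, to_i=_C_onnx.TensorProtoDataType.FLOAT)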