mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[BC-Breaking] Remove long-deprecated casting functions from native_functions.yaml (#164641)"
This reverts commit 64108bdbed.
Reverted https://github.com/pytorch/pytorch/pull/164641 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164641#issuecomment-3386346474))
This commit is contained in:
parent
a7fa1a91e3
commit
3d1fa40ae1
|
|
@ -23,6 +23,14 @@
|
||||||
#include <ATen/Functions.h>
|
#include <ATen/Functions.h>
|
||||||
#include <ATen/NativeFunctions.h>
|
#include <ATen/NativeFunctions.h>
|
||||||
#else
|
#else
|
||||||
|
#include <ATen/ops/_cast_Byte_native.h>
|
||||||
|
#include <ATen/ops/_cast_Char_native.h>
|
||||||
|
#include <ATen/ops/_cast_Double_native.h>
|
||||||
|
#include <ATen/ops/_cast_Float_native.h>
|
||||||
|
#include <ATen/ops/_cast_Half_native.h>
|
||||||
|
#include <ATen/ops/_cast_Int_native.h>
|
||||||
|
#include <ATen/ops/_cast_Long_native.h>
|
||||||
|
#include <ATen/ops/_cast_Short_native.h>
|
||||||
#include <ATen/ops/_dim_arange_native.h>
|
#include <ATen/ops/_dim_arange_native.h>
|
||||||
#include <ATen/ops/_efficientzerotensor_native.h>
|
#include <ATen/ops/_efficientzerotensor_native.h>
|
||||||
#include <ATen/ops/_empty_affine_quantized.h>
|
#include <ATen/ops/_empty_affine_quantized.h>
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,38 @@
|
||||||
# representing ScalarType's. They are now superseded by usage of
|
# representing ScalarType's. They are now superseded by usage of
|
||||||
# `aten::to()`. The ops remain here for backward compatibility purposes.
|
# `aten::to()`. The ops remain here for backward compatibility purposes.
|
||||||
|
|
||||||
|
# DEPRECATED. DO NOT USE
|
||||||
|
- func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
|
||||||
|
variants: function
|
||||||
|
|
||||||
|
# DEPRECATED. DO NOT USE
|
||||||
|
- func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor
|
||||||
|
variants: function
|
||||||
|
|
||||||
|
# DEPRECATED. DO NOT USE
|
||||||
|
- func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor
|
||||||
|
variants: function
|
||||||
|
|
||||||
|
# DEPRECATED. DO NOT USE
|
||||||
|
- func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor
|
||||||
|
variants: function
|
||||||
|
|
||||||
|
# DEPRECATED. DO NOT USE
|
||||||
|
- func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor
|
||||||
|
variants: function
|
||||||
|
|
||||||
|
# DEPRECATED. DO NOT USE
|
||||||
|
- func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor
|
||||||
|
variants: function
|
||||||
|
|
||||||
|
# DEPRECATED. DO NOT USE
|
||||||
|
- func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor
|
||||||
|
variants: function
|
||||||
|
|
||||||
|
# DEPRECATED. DO NOT USE
|
||||||
|
- func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor
|
||||||
|
variants: function
|
||||||
|
|
||||||
# Computes the gradient of current tensor w.r.t. graph leaves.
|
# Computes the gradient of current tensor w.r.t. graph leaves.
|
||||||
- func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
|
- func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
|
||||||
manual_cpp_binding: True
|
manual_cpp_binding: True
|
||||||
|
|
|
||||||
|
|
@ -1111,6 +1111,14 @@
|
||||||
"_amp_update_scale_",
|
"_amp_update_scale_",
|
||||||
"_assert_async",
|
"_assert_async",
|
||||||
"_batch_norm_impl_index",
|
"_batch_norm_impl_index",
|
||||||
|
"_cast_Byte",
|
||||||
|
"_cast_Char",
|
||||||
|
"_cast_Double",
|
||||||
|
"_cast_Float",
|
||||||
|
"_cast_Half",
|
||||||
|
"_cast_Int",
|
||||||
|
"_cast_Long",
|
||||||
|
"_cast_Short",
|
||||||
"_choose_qparams_per_tensor",
|
"_choose_qparams_per_tensor",
|
||||||
"_coalesce",
|
"_coalesce",
|
||||||
"_compute_linear_combination",
|
"_compute_linear_combination",
|
||||||
|
|
|
||||||
|
|
@ -135,6 +135,84 @@ TEST_F(LazyOpsTest, TestIsSigned) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(LazyOpsTest, TestCastByte) {
|
||||||
|
torch::Tensor a =
|
||||||
|
torch::rand(
|
||||||
|
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
|
||||||
|
100.0;
|
||||||
|
torch::Tensor b = torch::_cast_Byte(a);
|
||||||
|
ForEachDevice([&](const torch::Device& device) {
|
||||||
|
torch::Tensor lazy_a = CopyToDevice(a, device);
|
||||||
|
torch::Tensor lazy_b = torch::_cast_Byte(lazy_a);
|
||||||
|
AllEqual(b, lazy_b);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LazyOpsTest, TestCastChar) {
|
||||||
|
torch::Tensor a =
|
||||||
|
torch::rand(
|
||||||
|
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
|
||||||
|
100.0;
|
||||||
|
torch::Tensor b = torch::_cast_Char(a);
|
||||||
|
ForEachDevice([&](const torch::Device& device) {
|
||||||
|
torch::Tensor lazy_a = CopyToDevice(a, device);
|
||||||
|
torch::Tensor lazy_b = torch::_cast_Char(lazy_a);
|
||||||
|
AllEqual(b, lazy_b);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LazyOpsTest, TestCastShort) {
|
||||||
|
torch::Tensor a =
|
||||||
|
torch::rand(
|
||||||
|
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
|
||||||
|
100.0;
|
||||||
|
torch::Tensor b = torch::_cast_Short(a);
|
||||||
|
ForEachDevice([&](const torch::Device& device) {
|
||||||
|
torch::Tensor lazy_a = CopyToDevice(a, device);
|
||||||
|
torch::Tensor lazy_b = torch::_cast_Short(lazy_a);
|
||||||
|
AllEqual(b, lazy_b);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LazyOpsTest, TestCastInt) {
|
||||||
|
torch::Tensor a =
|
||||||
|
torch::rand(
|
||||||
|
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
|
||||||
|
100.0;
|
||||||
|
torch::Tensor b = torch::_cast_Int(a);
|
||||||
|
ForEachDevice([&](const torch::Device& device) {
|
||||||
|
torch::Tensor lazy_a = CopyToDevice(a, device);
|
||||||
|
torch::Tensor lazy_b = torch::_cast_Int(lazy_a);
|
||||||
|
AllEqual(b, lazy_b);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LazyOpsTest, TestCastLong) {
|
||||||
|
torch::Tensor a =
|
||||||
|
torch::rand(
|
||||||
|
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
|
||||||
|
100.0;
|
||||||
|
torch::Tensor b = torch::_cast_Long(a);
|
||||||
|
ForEachDevice([&](const torch::Device& device) {
|
||||||
|
torch::Tensor lazy_a = CopyToDevice(a, device);
|
||||||
|
torch::Tensor lazy_b = torch::_cast_Long(lazy_a);
|
||||||
|
AllEqual(b, lazy_b);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LazyOpsTest, TestCastFloat) {
|
||||||
|
torch::Tensor a =
|
||||||
|
torch::rand(
|
||||||
|
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
|
||||||
|
100.0;
|
||||||
|
torch::Tensor b = torch::_cast_Float(a);
|
||||||
|
ForEachDevice([&](const torch::Device& device) {
|
||||||
|
torch::Tensor lazy_a = CopyToDevice(a, device);
|
||||||
|
torch::Tensor lazy_b = torch::_cast_Float(lazy_a);
|
||||||
|
AllEqual(b, lazy_b);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(LazyOpsTest, TestRetainType) {
|
TEST_F(LazyOpsTest, TestRetainType) {
|
||||||
torch::Tensor lazy_a = torch::zeros(
|
torch::Tensor lazy_a = torch::zeros(
|
||||||
{2, 2}, torch::TensorOptions(torch::kByte).device(torch::kLazy));
|
{2, 2}, torch::TensorOptions(torch::kByte).device(torch::kLazy));
|
||||||
|
|
|
||||||
|
|
@ -141,15 +141,6 @@ ALLOW_LIST = [
|
||||||
("c10d::.*", datetime.date(9999, 1, 1)),
|
("c10d::.*", datetime.date(9999, 1, 1)),
|
||||||
# Previously MPS_only did not support backward
|
# Previously MPS_only did not support backward
|
||||||
("aten::_fused_rms_norm", datetime.date(2025, 12, 30)),
|
("aten::_fused_rms_norm", datetime.date(2025, 12, 30)),
|
||||||
# These casting ops were deprecated in PyTorch 1
|
|
||||||
("aten::_cast_Half", datetime.date(9999, 1, 1), None, True),
|
|
||||||
("aten::_cast_Short", datetime.date(9999, 1, 1), None, True),
|
|
||||||
("aten::_cast_Long", datetime.date(9999, 1, 1), None, True),
|
|
||||||
("aten::_cast_Int", datetime.date(9999, 1, 1), None, True),
|
|
||||||
("aten::_cast_Float", datetime.date(9999, 1, 1), None, True),
|
|
||||||
("aten::_cast_Double", datetime.date(9999, 1, 1), None, True),
|
|
||||||
("aten::_cast_Char", datetime.date(9999, 1, 1), None, True),
|
|
||||||
("aten::_cast_Byte", datetime.date(9999, 1, 1), None, True),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
ALLOW_LIST_COMPILED = [
|
ALLOW_LIST_COMPILED = [
|
||||||
|
|
|
||||||
|
|
@ -1414,6 +1414,14 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
||||||
"torch._C.unify_type_list",
|
"torch._C.unify_type_list",
|
||||||
"torch._C.vitals_enabled",
|
"torch._C.vitals_enabled",
|
||||||
"torch._C.wait",
|
"torch._C.wait",
|
||||||
|
"torch._cast_Byte",
|
||||||
|
"torch._cast_Char",
|
||||||
|
"torch._cast_Double",
|
||||||
|
"torch._cast_Float",
|
||||||
|
"torch._cast_Half",
|
||||||
|
"torch._cast_Int",
|
||||||
|
"torch._cast_Long",
|
||||||
|
"torch._cast_Short",
|
||||||
"torch._choose_qparams_per_tensor",
|
"torch._choose_qparams_per_tensor",
|
||||||
"torch._chunk_cat",
|
"torch._chunk_cat",
|
||||||
"torch._coalesce",
|
"torch._coalesce",
|
||||||
|
|
|
||||||
|
|
@ -181,6 +181,7 @@ struct RHSTemplate {
|
||||||
static std::string encodeRHS(const Node* n) {
|
static std::string encodeRHS(const Node* n) {
|
||||||
static std::unordered_map<NodeKind, RHSTemplate> simple_map_ops = {
|
static std::unordered_map<NodeKind, RHSTemplate> simple_map_ops = {
|
||||||
// unary
|
// unary
|
||||||
|
{aten::_cast_Float, "static_cast<float>(${0})"},
|
||||||
{aten::abs, "fabs(${0})"},
|
{aten::abs, "fabs(${0})"},
|
||||||
{aten::sigmoid, {"1.f / (1.f + expf(-${0}))", "1. / (1. + exp(-${0}))"}},
|
{aten::sigmoid, {"1.f / (1.f + expf(-${0}))", "1. / (1. + exp(-${0}))"}},
|
||||||
{aten::relu, "${0} < 0 ? 0.f : ${0} "},
|
{aten::relu, "${0} < 0 ? 0.f : ${0} "},
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,8 @@ namespace {
|
||||||
// carefully read the code first, as we rely on these assumptions.
|
// carefully read the code first, as we rely on these assumptions.
|
||||||
bool isSimpleMap(Node* node) {
|
bool isSimpleMap(Node* node) {
|
||||||
static OperatorSet simple_mappable{{
|
static OperatorSet simple_mappable{{
|
||||||
|
"aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
|
||||||
|
|
||||||
"aten::abs(Tensor self) -> Tensor",
|
"aten::abs(Tensor self) -> Tensor",
|
||||||
"aten::acos(Tensor self) -> Tensor",
|
"aten::acos(Tensor self) -> Tensor",
|
||||||
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
|
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
|
||||||
|
|
|
||||||
|
|
@ -340,6 +340,8 @@ struct GuardElimination {
|
||||||
case aten::reciprocal:
|
case aten::reciprocal:
|
||||||
case aten::addcmul:
|
case aten::addcmul:
|
||||||
case aten::where:
|
case aten::where:
|
||||||
|
case aten::_cast_Float:
|
||||||
|
case aten::_cast_Long:
|
||||||
case aten::__and__:
|
case aten::__and__:
|
||||||
case aten::__or__:
|
case aten::__or__:
|
||||||
case aten::__xor__:
|
case aten::__xor__:
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,9 @@ static void PrepareDivisionForONNXOnBlock(Block* block) {
|
||||||
subgraph->insertNode(subgraph->createNumToTensor(input))
|
subgraph->insertNode(subgraph->createNumToTensor(input))
|
||||||
->output();
|
->output();
|
||||||
longtensor->node()->copyMetadata(input->node());
|
longtensor->node()->copyMetadata(input->node());
|
||||||
auto* cast = subgraph->create(at::onnx::Cast, 1);
|
auto* nonblocking = subgraph->insertConstant(0);
|
||||||
cast->addInput(longtensor);
|
auto* cast =
|
||||||
cast->i_(attr::to, 1);
|
subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
|
||||||
cast->copyMetadata(*it);
|
cast->copyMetadata(*it);
|
||||||
return subgraph->insertNode(cast)->output();
|
return subgraph->insertNode(cast)->output();
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -1517,6 +1517,50 @@ class ShapePropagator : public PropertyPropBase {
|
||||||
return {};
|
return {};
|
||||||
}};
|
}};
|
||||||
|
|
||||||
|
static const auto get_cast_scalar_type = [](Node* node) -> at::ScalarType {
|
||||||
|
switch (node->kind()) {
|
||||||
|
case aten::_cast_Byte:
|
||||||
|
return at::kByte;
|
||||||
|
case aten::_cast_Char:
|
||||||
|
return at::kChar;
|
||||||
|
case aten::_cast_Double:
|
||||||
|
return at::kDouble;
|
||||||
|
case aten::_cast_Float:
|
||||||
|
return at::kFloat;
|
||||||
|
case aten::_cast_Half:
|
||||||
|
return at::kHalf;
|
||||||
|
case aten::_cast_Int:
|
||||||
|
return at::kInt;
|
||||||
|
case aten::_cast_Long:
|
||||||
|
return at::kLong;
|
||||||
|
case aten::_cast_Short:
|
||||||
|
return at::kShort;
|
||||||
|
default:
|
||||||
|
TORCH_INTERNAL_ASSERT(
|
||||||
|
false,
|
||||||
|
"unknown node kind in get_cast_scalar_type: ",
|
||||||
|
node->kind().toQualString());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
static const register_formula_for cast_ops{
|
||||||
|
{
|
||||||
|
"aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
|
||||||
|
"aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
|
||||||
|
"aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
|
||||||
|
"aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
|
||||||
|
"aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
|
||||||
|
"aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
|
||||||
|
"aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
|
||||||
|
"aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
|
||||||
|
},
|
||||||
|
[](Node* node) -> type_vec_t {
|
||||||
|
if (auto type =
|
||||||
|
node->namedInput(attr::self)->type()->cast<TensorType>()) {
|
||||||
|
return {type->withScalarType(get_cast_scalar_type(node))};
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}};
|
||||||
|
|
||||||
// First, try to match one of the registered formulas to their operator
|
// First, try to match one of the registered formulas to their operator
|
||||||
// sets.
|
// sets.
|
||||||
for (auto& entry : shape_formulas) {
|
for (auto& entry : shape_formulas) {
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ const OperatorMap<std::string>& get_tensorexpr_elementwise_set() {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
static const OperatorMap<std::string> tensorexpr_elementwise_set{
|
static const OperatorMap<std::string> tensorexpr_elementwise_set{
|
||||||
{"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
|
{"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
|
||||||
|
{"aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor", "unary"},
|
||||||
{"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
|
{"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
|
||||||
{"aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
|
{"aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
|
||||||
{"aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
|
{"aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
|
||||||
|
|
|
||||||
|
|
@ -1554,6 +1554,22 @@ int nnc_lowerings_lazy_registration() {
|
||||||
[](const ExprHandle& a) { return trunc(a); });
|
[](const ExprHandle& a) { return trunc(a); });
|
||||||
});
|
});
|
||||||
|
|
||||||
|
RegisterNNCLoweringsFunction aten__cast_Float(
|
||||||
|
{"aten::_cast_Float(Tensor self, bool non_blocking=False) -> (Tensor)"},
|
||||||
|
[](const std::vector<ArgValue>& inputs,
|
||||||
|
const std::vector<ExprHandle>& outputShape,
|
||||||
|
const std::vector<ExprHandle>& outputStrides,
|
||||||
|
const std::optional<ScalarType>& outputType,
|
||||||
|
at::Device device) {
|
||||||
|
return computeOneOperand(
|
||||||
|
"aten_cast_float",
|
||||||
|
inputs,
|
||||||
|
outputShape,
|
||||||
|
outputStrides,
|
||||||
|
outputType,
|
||||||
|
[](const ExprHandle& a) { return cast<float>(a); });
|
||||||
|
});
|
||||||
|
|
||||||
RegisterNNCLoweringsFunction aten_to(
|
RegisterNNCLoweringsFunction aten_to(
|
||||||
{"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
|
{"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
|
||||||
"aten::to.dtype_layout(Tensor(a) self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
|
"aten::to.dtype_layout(Tensor(a) self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
|
||||||
def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
|
def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
|
||||||
if to_type is None:
|
if to_type is None:
|
||||||
return input
|
return input
|
||||||
return g.op("Cast", input, to_i=symbolic_helper.cast_pytorch_to_onnx[to_type])
|
return getattr(opset9, f"_cast_{to_type}")(g, input, False)
|
||||||
|
|
||||||
|
|
||||||
def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
|
def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ import math
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._C._onnx as _C_onnx
|
import torch._C._onnx as _C_onnx
|
||||||
|
|
@ -1965,8 +1966,8 @@ def wrap_logical_op_with_cast_to(to_type):
|
||||||
def decorator(fn):
|
def decorator(fn):
|
||||||
@functools.wraps(fn)
|
@functools.wraps(fn)
|
||||||
def wrap_with_cast(g, input, other):
|
def wrap_with_cast(g, input, other):
|
||||||
to_i = symbolic_helper.cast_pytorch_to_onnx[to_type]
|
to_cast_func = globals()[f"_cast_{to_type}"]
|
||||||
return fn(g, g.op("Cast", input, to_i=to_i), g.op("Cast", other, to_i=to_i))
|
return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False))
|
||||||
|
|
||||||
return wrap_with_cast
|
return wrap_with_cast
|
||||||
|
|
||||||
|
|
@ -3330,6 +3331,60 @@ def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_co
|
||||||
symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
|
symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
|
||||||
|
|
||||||
|
|
||||||
|
@_onnx_symbolic("aten::_cast_Byte")
|
||||||
|
@deprecated("Avoid using this function and create a Cast node instead")
|
||||||
|
def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
|
||||||
|
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8)
|
||||||
|
|
||||||
|
|
||||||
|
@_onnx_symbolic("aten::_cast_Char")
|
||||||
|
@deprecated("Avoid using this function and create a Cast node instead")
|
||||||
|
def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
|
||||||
|
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8)
|
||||||
|
|
||||||
|
|
||||||
|
@_onnx_symbolic("aten::_cast_Short")
|
||||||
|
@deprecated("Avoid using this function and create a Cast node instead")
|
||||||
|
def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
|
||||||
|
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16)
|
||||||
|
|
||||||
|
|
||||||
|
@_onnx_symbolic("aten::_cast_Int")
|
||||||
|
@deprecated("Avoid using this function and create a Cast node instead")
|
||||||
|
def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
|
||||||
|
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)
|
||||||
|
|
||||||
|
|
||||||
|
@_onnx_symbolic("aten::_cast_Long")
|
||||||
|
@deprecated("Avoid using this function and create a Cast node instead")
|
||||||
|
def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
|
||||||
|
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
|
||||||
|
|
||||||
|
|
||||||
|
@_onnx_symbolic("aten::_cast_Half")
|
||||||
|
@deprecated("Avoid using this function and create a Cast node instead")
|
||||||
|
def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
|
||||||
|
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
|
||||||
|
|
||||||
|
|
||||||
|
@_onnx_symbolic("aten::_cast_Float")
|
||||||
|
@deprecated("Avoid using this function and create a Cast node instead")
|
||||||
|
def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
|
||||||
|
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
|
||||||
|
|
||||||
|
|
||||||
|
@_onnx_symbolic("aten::_cast_Double")
|
||||||
|
@deprecated("Avoid using this function and create a Cast node instead")
|
||||||
|
def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
|
||||||
|
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
|
||||||
|
|
||||||
|
|
||||||
|
@_onnx_symbolic("aten::_cast_Bool")
|
||||||
|
@deprecated("Avoid using this function and create a Cast node instead")
|
||||||
|
def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
|
||||||
|
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
|
||||||
|
|
||||||
|
|
||||||
@_onnx_symbolic("aten::empty")
|
@_onnx_symbolic("aten::empty")
|
||||||
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
|
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
|
||||||
def empty(
|
def empty(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user