To avoid polluting the global namespace, this PR moves the `from`/`to` APIs into `torch::stable::detail`. Following our normal deprecation cycle, we choose to continue exposing the global `from`/`to` for the time being, since extensions that onboarded onto 2.9 would otherwise fail to build against 2.10.

Note that this means that within libtorch, we do not get the luxury of tacking on a `using torch::stable::detail::from`: that would make calls ambiguous at build time (both the global and the namespaced APIs are exposed, and the compiler cannot tell which one is wanted). That is why every local call site is updated. The update is _not_ necessary from a custom op writer's point of view: FA3 can continue to build on torch nightlies without changing any code. (Since this is a header change, this PR has no runtime implications; a previously built ABI-stable FA3 wheel will continue to work fine with newer torch versions after this PR.)

Once TORCH_BOX lands, we would be free to remove these global APIs when the deprecation cycle is up (April 2026) and encourage people to use TORCH_BOX and avoid `from`/`to` entirely.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164956
Approved by: https://github.com/malfet
ghstack dependencies: #164882
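As a rough sketch (a hypothetical call site, not taken from this diff), the migration amounts to spelling out the namespace rather than relying on a `using` declaration:

    // Deprecated: the global helpers, slated for removal in April 2026
    StableIValue boxed = from(self);

    // Preferred: fully qualified, which sidesteps the ambiguity above
    StableIValue boxed2 = torch::stable::detail::from(self);
    auto unboxed = torch::stable::detail::to<torch::stable::Tensor>(boxed2);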
#pragma once

#include <torch/csrc/stable/stableivalue_conversions.h>
#include <array>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/headeronly/core/ScalarType.h>

namespace torch::stable {

// We expect this to be the stable version of the empty_like op that takes in
// no kwargs (device, dtype, layout, memory_format). We will add kwargs
// support in the future.
inline torch::stable::Tensor empty_like(const torch::stable::Tensor& self) {
  const auto num_args = 6;
  std::array<StableIValue, num_args> stack{
      torch::stable::detail::from(self),
      torch::stable::detail::from(std::nullopt),
      torch::stable::detail::from(std::nullopt),
      torch::stable::detail::from(std::nullopt),
      torch::stable::detail::from(std::nullopt),
      torch::stable::detail::from(std::nullopt)};
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::empty_like", "", stack.data()));
  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

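// Usage sketch (illustrative only; `t` is a hypothetical caller-provided
// Tensor, not part of this header):
//   torch::stable::Tensor out = empty_like(t);
//   // `out` matches t's shape, dtype, and device; its contents are
//   // uninitialized until written (e.g. via fill_ or zero_ below).
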
// We expect this to be the stable version of the fill_.Scalar op
// with identical semantics to the existing fill_.Scalar op.
// A subtle nuance is that `value` is typed as a double, but it is
// actually a Scalar. This is because Scalar.h is currently not
// header-only.
inline torch::stable::Tensor fill_(
    const torch::stable::Tensor& self,
    double value) {
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_fill__Scalar(self.get(), value));
  return self;
}

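// Usage sketch (illustrative; `t` is a hypothetical Tensor). The fill is
// in place, and the returned Tensor refers to the same storage:
//   fill_(t, 1.0); // every element of t becomes 1
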
// We expect this to be the stable version of the narrow.default op.
// narrow takes in a SymInt for start and length, but these are typed as
// int64_t as SymInt is not yet header-only.
inline torch::stable::Tensor narrow(
    torch::stable::Tensor& self,
    int64_t dim,
    int64_t start,
    int64_t length) {
  AtenTensorHandle ret0 = nullptr;

  TORCH_ERROR_CODE_CHECK(
      aoti_torch_aten_narrow(self.get(), dim, start, length, &ret0));
  return torch::stable::Tensor(ret0);
}

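// Usage sketch (illustrative; `t` is a hypothetical Tensor of shape [8, 4]):
//   torch::stable::Tensor s = narrow(t, /*dim=*/0, /*start=*/2, /*length=*/3);
//   // `s` has shape [3, 4] and, like aten::narrow, is a view of t.
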
// We expect this to be a stable version of the new_empty op that takes in
// only dtype information.
inline torch::stable::Tensor new_empty(
    const torch::stable::Tensor& self,
    std::vector<int64_t> size,
    std::optional<c10::ScalarType> dtype = std::nullopt) {
  int32_t device_type;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));

  int32_t device_index;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_device_index(self.get(), &device_index));

  int32_t target_dtype;
  if (dtype.has_value()) {
    target_dtype = torch::stable::detail::to<int32_t>(
        torch::stable::detail::from(dtype.value()));
  } else {
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
  }

  int32_t layout;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));

  AtenTensorHandle ret0;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
      self.get(),
      size.data(),
      static_cast<int64_t>(size.size()),
      &target_dtype,
      &layout,
      &device_type,
      device_index,
      nullptr, // pin_memory (nullptr for default)
      &ret0));

  return torch::stable::Tensor(ret0);
}

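// Usage sketch (illustrative; `t` is a hypothetical Tensor): allocate a
// fresh 2x3 tensor on t's device, overriding the dtype:
//   torch::stable::Tensor u = new_empty(t, {2, 3}, c10::ScalarType::Float);
//   // Omitting the dtype argument inherits t's dtype instead.
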
// We expect this to be a stable version of the new_zeros op that takes in
// only dtype information.
inline torch::stable::Tensor new_zeros(
    const torch::stable::Tensor& self,
    std::vector<int64_t> size,
    std::optional<c10::ScalarType> dtype = std::nullopt) {
  int32_t device_type;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));

  int32_t device_index;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_device_index(self.get(), &device_index));

  int32_t target_dtype;
  if (dtype.has_value()) {
    target_dtype = torch::stable::detail::to<int32_t>(
        torch::stable::detail::from(dtype.value()));
  } else {
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
  }

  int32_t layout;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));

  AtenTensorHandle ath;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_zeros(
      self.get(),
      size.data(),
      static_cast<int64_t>(size.size()),
      &target_dtype,
      &layout,
      &device_type,
      device_index,
      nullptr, // pin_memory (nullptr for default)
      &ath));

  return torch::stable::Tensor(ath);
}

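// Usage sketch (illustrative; `t` is a hypothetical Tensor):
//   torch::stable::Tensor z = new_zeros(t, {2, 3});
//   // Same as new_empty above, but the result is zero-initialized.
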
// We expect this to be the stable version of the pad.default op.
// pad.default takes in a SymInt[] as the pad argument; however, pad is
// typed as a std::vector<int64_t> here because
// (1) IntArrayRef is not yet header-only
// (2) SymInt is not yet header-only
inline torch::stable::Tensor pad(
    const torch::stable::Tensor& self,
    std::vector<int64_t> pad,
    const std::string& mode = "constant",
    double value = 0.0) {
  AtenTensorHandle ret0 = nullptr;

  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_pad(
      self.get(), pad.data(), pad.size(), mode.c_str(), &value, &ret0));
  return torch::stable::Tensor(ret0);
}

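// Usage sketch (illustrative; `t` is a hypothetical 2D Tensor). As with
// aten::pad, entries come in (before, after) pairs starting from the last
// dimension:
//   torch::stable::Tensor p = pad(t, {1, 1});
//   // The last dim grows by 2, constant-padded with the default value 0.0.
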
// We expect the following two functions to be stable versions of the
// amax.default op with identical semantics to the existing amax.default op.
// If `keepdim` is true, the result will have the same number of dimensions
// as `self`, with each specified dimension having size 1. Otherwise, the
// result will have fewer dimensions than `self`, with each specified
// dimension removed.

// This function is an overload to compute the maximum value along each slice
// of `self` along a single dimension `dim`.
inline torch::stable::Tensor amax(
    const torch::stable::Tensor& self,
    int64_t dim,
    bool keepdim = false) {
  AtenTensorHandle ret = nullptr;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_aten_amax(self.get(), &dim, 1, keepdim, &ret));
  return torch::stable::Tensor(ret);
}

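// Usage sketch (illustrative; `t` is a hypothetical Tensor of shape [2, 5]):
//   torch::stable::Tensor m = amax(t, /*dim=*/1);                   // shape [2]
//   torch::stable::Tensor k = amax(t, /*dim=*/1, /*keepdim=*/true); // shape [2, 1]
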
// This function is an overload to compute the maximum value along each slice
// of `self`, reducing over all the dimensions in the vector `dims`. The
// amax.default op takes in a SymInt[] as the dims argument; however, dims is
// typed as a std::vector<int64_t> here because (1) IntArrayRef is not yet
// header-only and (2) SymInt is not yet header-only.
inline torch::stable::Tensor amax(
    const torch::stable::Tensor& self,
    std::vector<int64_t> dims,
    bool keepdim = false) {
  AtenTensorHandle ret = nullptr;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
      self.get(),
      dims.data(),
      static_cast<int64_t>(dims.size()),
      keepdim,
      &ret));
  return torch::stable::Tensor(ret);
}

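// Usage sketch (illustrative; `t` is a hypothetical Tensor of shape
// [2, 3, 4]):
//   torch::stable::Tensor m = amax(t, {0, 2}); // reduces dims 0 and 2 -> shape [3]
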
// We expect this to be the stable version of the transpose op with identical
// semantics to the existing transpose.int op.
inline torch::stable::Tensor transpose(
    const torch::stable::Tensor& self,
    int64_t dim0,
    int64_t dim1) {
  const auto num_args = 3;
  std::array<StableIValue, num_args> stack{
      torch::stable::detail::from(self),
      torch::stable::detail::from(dim0),
      torch::stable::detail::from(dim1)};
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::transpose", "int", stack.data()));
  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

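// Usage sketch (illustrative; `t` is a hypothetical Tensor of shape [2, 3]):
//   torch::stable::Tensor tt = transpose(t, 0, 1);
//   // `tt` has shape [3, 2]; like aten::transpose, it is a view of t.
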
// We expect this to be the stable version of the zero_ op with identical
// semantics to the existing zero_ op (except that it will not be called as
// a tensor method but only as a function, i.e. zero_(t) rather than
// t.zero_()).
inline torch::stable::Tensor zero_(torch::stable::Tensor& self) {
  const auto num_args = 1;
  std::array<StableIValue, num_args> stack{torch::stable::detail::from(self)};
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::zero_", "", stack.data()));
  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

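// Usage sketch (illustrative; `t` is a hypothetical Tensor):
//   zero_(t); // zeroes t in place; the returned Tensor aliases t
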
// We expect this to be the stable version of the copy_ op with
// identical semantics to the existing copy_ op.
inline torch::stable::Tensor copy_(
    torch::stable::Tensor& self,
    const torch::stable::Tensor& src,
    std::optional<bool> non_blocking = std::nullopt) {
  const auto num_args = 3;
  std::array<StableIValue, num_args> stack{
      torch::stable::detail::from(self),
      torch::stable::detail::from(src),
      torch::stable::detail::from(non_blocking.value_or(false))};
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::copy_", "", stack.data()));
  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

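// Usage sketch (illustrative; `dst` and `src` are hypothetical Tensors with
// compatible shapes):
//   copy_(dst, src);                         // copies src into dst
//   copy_(dst, src, /*non_blocking=*/true);  // requests an asynchronous
//                                            // copy where supported
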
// We expect this to be the stable version of the clone op. We will
// add optional memory_format kwarg support in the future.
inline torch::stable::Tensor clone(const torch::stable::Tensor& self) {
  const auto num_args = 2;
  std::array<StableIValue, num_args> stack{
      torch::stable::detail::from(self),
      torch::stable::detail::from(std::nullopt)};
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::clone", "", stack.data()));
  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

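// Usage sketch (illustrative; `t` is a hypothetical Tensor):
//   torch::stable::Tensor c = clone(t);
//   // `c` owns its own copy of the data; writes to c do not affect t.
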
} // namespace torch::stable