mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add new_empty (with dtype argument only) to torch::stable (#159508)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159508 Approved by: https://github.com/janeyx99 ghstack dependencies: #160557
This commit is contained in:
parent
543896fcf3
commit
78a8e6a671
|
|
@ -4,6 +4,7 @@
|
||||||
#include <torch/csrc/stable/tensor.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
#include <torch/csrc/stable/ops.h>
|
#include <torch/csrc/stable/ops.h>
|
||||||
#include <torch/headeronly/util/Exception.h>
|
#include <torch/headeronly/util/Exception.h>
|
||||||
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
|
||||||
#ifdef LAE_USE_CUDA
|
#ifdef LAE_USE_CUDA
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
@ -340,12 +341,24 @@ void boxed_my_narrow(
|
||||||
stack[0] = from(res);
|
stack[0] = from(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||||
|
std::vector<int64_t> sizes = {2, 5};
|
||||||
|
auto dtype = std::make_optional(at::ScalarType::BFloat16);
|
||||||
|
return new_empty(t, sizes, dtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||||
|
auto res = my_new_empty_dtype_variant(to<Tensor>(stack[0]));
|
||||||
|
stack[0] = from(res);
|
||||||
|
}
|
||||||
|
|
||||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||||
m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
|
m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
|
||||||
m.def("my_empty_like(Tensor t) -> Tensor");
|
m.def("my_empty_like(Tensor t) -> Tensor");
|
||||||
m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)");
|
m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)");
|
||||||
m.def("my_pad(Tensor t) -> Tensor");
|
m.def("my_pad(Tensor t) -> Tensor");
|
||||||
m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
|
m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
|
||||||
|
m.def("my_new_empty_dtype_variant(Tensor t) -> Tensor");
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||||
|
|
@ -353,6 +366,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||||
m.impl("my_empty_like", &boxed_empty_like);
|
m.impl("my_empty_like", &boxed_empty_like);
|
||||||
m.impl("fill_infinity", &boxed_fill_infinity);
|
m.impl("fill_infinity", &boxed_fill_infinity);
|
||||||
m.impl("my_is_cpu", &boxed_my_is_cpu);
|
m.impl("my_is_cpu", &boxed_my_is_cpu);
|
||||||
|
m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
|
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
|
||||||
|
|
|
||||||
|
|
@ -283,3 +283,15 @@ def test_get_current_device_index() -> int:
|
||||||
Returns: Current device index as an integer
|
Returns: Current device index as an integer
|
||||||
"""
|
"""
|
||||||
return torch.ops.libtorch_agnostic.test_get_current_device_index.default()
|
return torch.ops.libtorch_agnostic.test_get_current_device_index.default()
|
||||||
|
|
||||||
|
|
||||||
|
def my_new_empty_dtype_variant(t) -> Tensor:
|
||||||
|
"""
|
||||||
|
Returns a new empty tensor with shape [2, 5] and dtype bfloat16
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t: Input tensor used as a reference for device and other properties
|
||||||
|
|
||||||
|
Returns: New empty tensor with shape [2, 5] and dtype bfloat16
|
||||||
|
"""
|
||||||
|
return torch.ops.libtorch_agnostic.my_new_empty_dtype_variant.default(t)
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,7 @@ if not IS_WINDOWS:
|
||||||
|
|
||||||
deterministic = torch.are_deterministic_algorithms_enabled()
|
deterministic = torch.are_deterministic_algorithms_enabled()
|
||||||
try:
|
try:
|
||||||
# set use_deterministic_algorithms to fill unintialized memory
|
# set use_deterministic_algorithms to fill uninitialized memory
|
||||||
torch.use_deterministic_algorithms(True)
|
torch.use_deterministic_algorithms(True)
|
||||||
|
|
||||||
t = torch.rand(2, 7, device=device)
|
t = torch.rand(2, 7, device=device)
|
||||||
|
|
@ -322,6 +322,21 @@ if not IS_WINDOWS:
|
||||||
finally:
|
finally:
|
||||||
torch.cuda.set_device(prev_device)
|
torch.cuda.set_device(prev_device)
|
||||||
|
|
||||||
|
def test_my_new_empty_dtype_variant(self, device):
|
||||||
|
import libtorch_agnostic
|
||||||
|
|
||||||
|
deterministic = torch.are_deterministic_algorithms_enabled()
|
||||||
|
try:
|
||||||
|
# set use_deterministic_algorithms to fill uninitialized memory
|
||||||
|
torch.use_deterministic_algorithms(True)
|
||||||
|
t = torch.randn(3, 4, device=device)
|
||||||
|
out = libtorch_agnostic.ops.my_new_empty_dtype_variant(t)
|
||||||
|
ref_out = t.new_empty((2, 5), dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
self.assertEqual(out, ref_out, exact_device=True)
|
||||||
|
finally:
|
||||||
|
torch.use_deterministic_algorithms(deterministic)
|
||||||
|
|
||||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -220,6 +220,9 @@ aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t* ret_device_type);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError
|
AOTI_TORCH_EXPORT AOTITorchError
|
||||||
aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t* ret_device_index);
|
aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t* ret_device_index);
|
||||||
|
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError
|
||||||
|
aoti_torch_get_layout(AtenTensorHandle tensor, int32_t* ret_layout);
|
||||||
|
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
|
||||||
AtenTensorHandle tensor,
|
AtenTensorHandle tensor,
|
||||||
int64_t* ret_storage_offset);
|
int64_t* ret_storage_offset);
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ extern "C" {
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_amax(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int32_t keepdim, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_amax(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int32_t keepdim, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|
|
||||||
|
|
@ -389,6 +389,15 @@ AOTITorchError aoti_torch_get_device_index(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AOTITorchError aoti_torch_get_layout(
|
||||||
|
AtenTensorHandle tensor,
|
||||||
|
int32_t* ret_layout) {
|
||||||
|
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||||
|
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
|
||||||
|
*ret_layout = static_cast<int32_t>(t->layout());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
AOTITorchError aoti_torch_get_storage_offset(
|
AOTITorchError aoti_torch_get_storage_offset(
|
||||||
AtenTensorHandle tensor,
|
AtenTensorHandle tensor,
|
||||||
int64_t* ret_storage_offset) {
|
int64_t* ret_storage_offset) {
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
|
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
|
||||||
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
|
||||||
using torch::stable::Tensor;
|
using torch::stable::Tensor;
|
||||||
|
|
||||||
|
|
@ -51,6 +52,44 @@ inline Tensor narrow(Tensor& self, int64_t dim, int64_t start, int64_t length) {
|
||||||
return Tensor(ret0);
|
return Tensor(ret0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We expect this to be a stable version of the new_empty op that takes in
|
||||||
|
// only dtype information.
|
||||||
|
inline Tensor new_empty(
|
||||||
|
const Tensor& self,
|
||||||
|
std::vector<int64_t> size,
|
||||||
|
std::optional<c10::ScalarType> dtype = std::nullopt) {
|
||||||
|
int32_t device_type;
|
||||||
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
|
||||||
|
|
||||||
|
int32_t device_index;
|
||||||
|
TORCH_ERROR_CODE_CHECK(
|
||||||
|
aoti_torch_get_device_index(self.get(), &device_index));
|
||||||
|
|
||||||
|
int32_t target_dtype;
|
||||||
|
if (dtype.has_value()) {
|
||||||
|
target_dtype = to<int32_t>(from(dtype.value()));
|
||||||
|
} else {
|
||||||
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t layout;
|
||||||
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));
|
||||||
|
|
||||||
|
AtenTensorHandle ret0;
|
||||||
|
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
|
||||||
|
self.get(),
|
||||||
|
size.data(),
|
||||||
|
static_cast<int64_t>(size.size()),
|
||||||
|
&target_dtype,
|
||||||
|
&layout,
|
||||||
|
&device_type,
|
||||||
|
device_index,
|
||||||
|
nullptr, // pin_memory (nullptr for default)
|
||||||
|
&ret0));
|
||||||
|
|
||||||
|
return Tensor(ret0);
|
||||||
|
}
|
||||||
|
|
||||||
// We expect this to be the stable version of the pad.default op.
|
// We expect this to be the stable version of the pad.default op.
|
||||||
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
|
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
|
||||||
// use std::vector<int64_t> because
|
// use std::vector<int64_t> because
|
||||||
|
|
|
||||||
|
|
@ -186,4 +186,5 @@ aten_shimified_ops: dict[str, dict[str, list[str]]] = {
|
||||||
"aten.pad.default": {},
|
"aten.pad.default": {},
|
||||||
"aten.narrow.default": {},
|
"aten.narrow.default": {},
|
||||||
"aten.amax.default": {},
|
"aten.amax.default": {},
|
||||||
|
"aten.new_empty.default": {},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from torchgen.model import (
|
||||||
OperatorName,
|
OperatorName,
|
||||||
OptionalType,
|
OptionalType,
|
||||||
Type,
|
Type,
|
||||||
|
Variant,
|
||||||
)
|
)
|
||||||
from torchgen.utils import FileManager, mapMaybe
|
from torchgen.utils import FileManager, mapMaybe
|
||||||
|
|
||||||
|
|
@ -396,7 +397,22 @@ def gen_static_dispatch_backend_call(
|
||||||
) -> str:
|
) -> str:
|
||||||
sig = DispatcherSignature.from_schema(f.func)
|
sig = DispatcherSignature.from_schema(f.func)
|
||||||
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
|
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
|
||||||
|
|
||||||
if backend_index is None:
|
if backend_index is None:
|
||||||
|
# Check if this is a symint function and if the function only has method variants
|
||||||
|
if sig.symint and f.func.has_symint():
|
||||||
|
has_function_variant = Variant.function in f.variants
|
||||||
|
|
||||||
|
if not has_function_variant:
|
||||||
|
# Functions with both function and method variants can use the at::{*}_symint version
|
||||||
|
# (e.g., narrow -> at::narrow_symint), BUT
|
||||||
|
# Method-only functions with symint parameters should use at::symint:: namespace
|
||||||
|
# Remove the _symint suffix since at::symint:: namespace uses the base name
|
||||||
|
# (e.g., new_empty -> at::symint::new_empty<c10::SymInt>)
|
||||||
|
base_name = cpp_sig.name()
|
||||||
|
base_name = base_name.removesuffix("_symint") # Remove "_symint" suffix
|
||||||
|
return f"at::symint::{base_name}<c10::SymInt>"
|
||||||
|
|
||||||
return f"at::{cpp_sig.name()}"
|
return f"at::{cpp_sig.name()}"
|
||||||
else:
|
else:
|
||||||
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
|
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user