mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add zero_() and empty_like(t) to torch/csrc/stable/ops.h (#158866)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158866 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
76be282e3a
commit
fef236da69
|
|
@ -269,10 +269,39 @@ void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_out
|
|||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
Tensor my_empty_like(Tensor t) {
|
||||
return empty_like(t);
|
||||
}
|
||||
|
||||
void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_empty_like(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
|
||||
m.def("my_empty_like(Tensor t) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_transpose", &boxed_my_transpose);
|
||||
m.impl("my_empty_like", &boxed_empty_like);
|
||||
}
|
||||
|
||||
|
||||
Tensor my_zero_(Tensor t) {
|
||||
return zero_(t);
|
||||
}
|
||||
|
||||
void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_zero_(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
|
||||
m.impl("my_zero_", &boxed_my_zero_);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,3 +128,27 @@ def my_transpose(t, dim0, dim1) -> Tensor:
|
|||
Returns: my_transpose(t, dim0, dim1)
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.my_transpose.default(t, dim0, dim1)
|
||||
|
||||
|
||||
def my_empty_like(t) -> Tensor:
|
||||
"""
|
||||
Returns t.empty_like()
|
||||
|
||||
Args:
|
||||
t: Tensor
|
||||
|
||||
Returns: my_empty_like(t)
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.my_empty_like.default(t)
|
||||
|
||||
|
||||
def my_zero_(t) -> Tensor:
|
||||
"""
|
||||
Returns t.zero_()
|
||||
|
||||
Args:
|
||||
t: Tensor
|
||||
|
||||
Returns: my_zero_(t)
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.my_zero_.default(t)
|
||||
|
|
|
|||
|
|
@ -183,6 +183,30 @@ if not IS_WINDOWS:
|
|||
with self.assertRaisesRegex(RuntimeError, "API call failed"):
|
||||
libtorch_agnostic.ops.my_transpose(t, 1, 2)
|
||||
|
||||
def test_my_empty_like(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
deterministic = torch.are_deterministic_algorithms_enabled()
|
||||
try:
|
||||
# set use_deterministic_algorithms to fill unintialized memory
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
t = torch.rand(2, 7, device=device)
|
||||
out = libtorch_agnostic.ops.my_empty_like(t)
|
||||
self.assertTrue(id(out != id(t)))
|
||||
self.assertEqual(out, torch.empty_like(t))
|
||||
finally:
|
||||
torch.use_deterministic_algorithms(deterministic)
|
||||
|
||||
@onlyCPU
|
||||
def test_my_zero_(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
t = torch.rand(2, 7, device=device)
|
||||
out = libtorch_agnostic.ops.my_zero_(t)
|
||||
self.assertEqual(id(out), id(t))
|
||||
self.assertEqual(out, torch.zeros_like(t))
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -3,9 +3,27 @@
|
|||
#include <torch/csrc/stable/library.h>
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
|
||||
using torch::stable::Tensor;
|
||||
|
||||
// We expect this to be the stable version of the empty_like op that takes in
|
||||
// no kwargs (device, dtype, layout, memory_format). We will add kwargs
|
||||
// support in the future.
|
||||
inline Tensor empty_like(const Tensor& self) {
|
||||
const auto num_args = 6;
|
||||
std::array<StableIValue, num_args> stack{
|
||||
from(self),
|
||||
from(std::nullopt),
|
||||
from(std::nullopt),
|
||||
from(std::nullopt),
|
||||
from(std::nullopt),
|
||||
from(std::nullopt)};
|
||||
AOTI_TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::empty_like", "", stack.data()));
|
||||
return to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
// We expect this to be the stable version of the transpose op with identical
|
||||
// semantics to the existing transpose.int op.
|
||||
inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
|
||||
|
|
@ -15,3 +33,14 @@ inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
|
|||
aoti_torch_call_dispatcher("aten::transpose", "int", stack.data()));
|
||||
return to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
// We expect this to be the stable version of the zero_ op with identical
|
||||
// semantics to the existing zero_ op (except that it will not be called as
|
||||
// a tensor method but only as a function i.e. zero_(t) not t.zero_()).
|
||||
inline Tensor zero_(Tensor& self) {
|
||||
const auto num_args = 1;
|
||||
std::array<StableIValue, num_args> stack{from(self)};
|
||||
AOTI_TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::zero_", "", stack.data()));
|
||||
return to<Tensor>(stack[0]);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user