mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[easy][PyTorch] Use at::native::is_nonzero (#67195)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67195 Now that `is_nonzero` is part of `at::native` (see https://github.com/pytorch/pytorch/pull/66663), replace `TensorCompare::is_nonzero` with `at::native::is_nonzero`. ghstack-source-id: 141514416 Test Plan: CI Reviewed By: larryliu0820 Differential Revision: D31704041 fbshipit-source-id: 36813e5411d0aa2eb2d0442e2a195bbed417b33d
This commit is contained in:
parent
a33d3d84df
commit
1ce500f56f
|
|
@ -667,7 +667,7 @@ TEST(CustomAutogradTest, DeepReentrant) {
|
|||
}
|
||||
|
||||
static variable_list backward(AutogradContext*ctx, variable_list grad_output) {
|
||||
if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
|
||||
if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
|
||||
return grad_output;
|
||||
}
|
||||
{
|
||||
|
|
@ -708,7 +708,7 @@ TEST(CustomAutogradTest, ReentrantPriority) {
|
|||
|
||||
static variable_list backward(AutogradContext*ctx, variable_list grad_output) {
|
||||
order.push_back(1);
|
||||
if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
|
||||
if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
|
||||
return grad_output;
|
||||
}
|
||||
{
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ void call_setup_methods() {
|
|||
at::Tensor t2 = t1.fill_(3);
|
||||
at::narrow(t2, 1, 0, 1);
|
||||
at::eq(t1, t2);
|
||||
const volatile bool nz = at::zeros({1}).is_nonzero();
|
||||
const volatile bool nz = at::native::is_nonzero(at::zeros({1}));
|
||||
(void)nz;
|
||||
|
||||
// Create a byte tensor and copy it
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ void listIndex<at::Tensor>(Stack& stack) {
|
|||
auto pos =
|
||||
std::find_if(list.begin(), list.end(), [elem](const at::Tensor& b) {
|
||||
const auto cmp_result = elem.eq(b);
|
||||
return cmp_result.is_nonzero();
|
||||
return at::native::is_nonzero(cmp_result);
|
||||
});
|
||||
|
||||
if (pos != list.end()) {
|
||||
|
|
@ -38,7 +38,7 @@ void listCount<at::Tensor>(Stack& stack) {
|
|||
const int64_t count =
|
||||
std::count_if(list.begin(), list.end(), [&](const at::Tensor& b) {
|
||||
const auto cmp_result = elem.eq(b);
|
||||
return cmp_result.is_nonzero();
|
||||
return at::native::is_nonzero(cmp_result);
|
||||
});
|
||||
push(stack, count);
|
||||
}
|
||||
|
|
@ -69,7 +69,7 @@ void listSort<at::Tensor>(Stack& stack) {
|
|||
if (a.getIntrusivePtr() == b.getIntrusivePtr()) {
|
||||
return false;
|
||||
}
|
||||
return (a.lt(b).is_nonzero()) ^ reverse;
|
||||
return (at::native::is_nonzero(a.lt(b))) ^ reverse;
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -81,7 +81,7 @@ void listCopyAndSort<at::Tensor>(Stack& stack) {
|
|||
list_copied.begin(),
|
||||
list_copied.end(),
|
||||
[](const at::Tensor& a, const at::Tensor& b) {
|
||||
return a.lt(b).is_nonzero();
|
||||
return at::native::is_nonzero(a.lt(b));
|
||||
});
|
||||
push(stack, list_copied);
|
||||
}
|
||||
|
|
@ -93,7 +93,7 @@ void listRemove<at::Tensor>(Stack& stack) {
|
|||
|
||||
auto pos = std::find_if(list.begin(), list.end(), [&](const at::Tensor& b) {
|
||||
const auto cmp_result = elem.eq(b);
|
||||
return cmp_result.is_nonzero();
|
||||
return at::native::is_nonzero(cmp_result);
|
||||
});
|
||||
|
||||
if (pos != list.end()) {
|
||||
|
|
|
|||
|
|
@ -347,7 +347,7 @@ inline bool tensor_list_equal(
|
|||
// elements, then passes the result to bool().
|
||||
// see: https://docs.python.org/3.4/reference/datamodel.html#object.__ge__
|
||||
const auto cmp_result = a_element.eq(b_element);
|
||||
if (!cmp_result.is_nonzero()) {
|
||||
if (!at::native::is_nonzero(cmp_result)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user