[easy][PyTorch] Use at::native::is_nonzero (#67195)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67195

Now that `is_nonzero` is part of `at::native` refer https://github.com/pytorch/pytorch/pull/66663, replacing `TensorCompare::is_nonzero` to `at::native::is_nonzero`

ghstack-source-id: 141514416

Test Plan: CI

Reviewed By: larryliu0820

Differential Revision: D31704041

fbshipit-source-id: 36813e5411d0aa2eb2d0442e2a195bbed417b33d
This commit is contained in:
Pavithran Ramachandran 2021-10-26 12:34:47 -07:00 committed by Facebook GitHub Bot
parent a33d3d84df
commit 1ce500f56f
4 changed files with 9 additions and 9 deletions

View File

@ -667,7 +667,7 @@ TEST(CustomAutogradTest, DeepReentrant) {
}
static variable_list backward(AutogradContext*ctx, variable_list grad_output) {
if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
return grad_output;
}
{
@ -708,7 +708,7 @@ TEST(CustomAutogradTest, ReentrantPriority) {
static variable_list backward(AutogradContext*ctx, variable_list grad_output) {
order.push_back(1);
if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
return grad_output;
}
{

View File

@ -56,7 +56,7 @@ void call_setup_methods() {
at::Tensor t2 = t1.fill_(3);
at::narrow(t2, 1, 0, 1);
at::eq(t1, t2);
const volatile bool nz = at::zeros({1}).is_nonzero();
const volatile bool nz = at::native::is_nonzero(at::zeros({1}));
(void)nz;
// Create a byte tensor and copy it

View File

@ -20,7 +20,7 @@ void listIndex<at::Tensor>(Stack& stack) {
auto pos =
std::find_if(list.begin(), list.end(), [elem](const at::Tensor& b) {
const auto cmp_result = elem.eq(b);
return cmp_result.is_nonzero();
return at::native::is_nonzero(cmp_result);
});
if (pos != list.end()) {
@ -38,7 +38,7 @@ void listCount<at::Tensor>(Stack& stack) {
const int64_t count =
std::count_if(list.begin(), list.end(), [&](const at::Tensor& b) {
const auto cmp_result = elem.eq(b);
return cmp_result.is_nonzero();
return at::native::is_nonzero(cmp_result);
});
push(stack, count);
}
@ -69,7 +69,7 @@ void listSort<at::Tensor>(Stack& stack) {
if (a.getIntrusivePtr() == b.getIntrusivePtr()) {
return false;
}
return (a.lt(b).is_nonzero()) ^ reverse;
return (at::native::is_nonzero(a.lt(b))) ^ reverse;
});
}
@ -81,7 +81,7 @@ void listCopyAndSort<at::Tensor>(Stack& stack) {
list_copied.begin(),
list_copied.end(),
[](const at::Tensor& a, const at::Tensor& b) {
return a.lt(b).is_nonzero();
return at::native::is_nonzero(a.lt(b));
});
push(stack, list_copied);
}
@ -93,7 +93,7 @@ void listRemove<at::Tensor>(Stack& stack) {
auto pos = std::find_if(list.begin(), list.end(), [&](const at::Tensor& b) {
const auto cmp_result = elem.eq(b);
return cmp_result.is_nonzero();
return at::native::is_nonzero(cmp_result);
});
if (pos != list.end()) {

View File

@ -347,7 +347,7 @@ inline bool tensor_list_equal(
// elements, then passes the result to bool().
// see: https://docs.python.org/3.4/reference/datamodel.html#object.__ge__
const auto cmp_result = a_element.eq(b_element);
if (!cmp_result.is_nonzero()) {
if (!at::native::is_nonzero(cmp_result)) {
return false;
}
}