Fix NaN propagation in TE fuser's min/max implementation (#43609)

Summary:
Per eager mode source-of-truth, NaNs shall be propagated by min/max.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43609

Reviewed By: ZolotukhinM

Differential Revision: D23349184

Pulled By: bertmaher

fbshipit-source-id: 094eb8b89a02b27d5ecf3988d0f473c0f91e4afb
This commit is contained in:
Bert Maher 2020-09-01 02:07:02 -07:00 committed by Facebook GitHub Bot
parent 820c4b05a9
commit c14a3613a8
8 changed files with 111 additions and 234 deletions

View File

@ -13,6 +13,7 @@
#include "torch/csrc/jit/tensorexpr/loopnest.h"
#include "torch/csrc/jit/tensorexpr/tensor.h"
#include <cmath>
#include <numeric>
namespace torch {
@ -649,7 +650,7 @@ void testLLVMElemwiseMinInt() {
assertAllEqual(c_buffer, 1);
}
void testLLVMElemwiseMaxNumFloat() {
void testLLVMElemwiseMaxFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
Buffer a(BufHandle("A", {N}, kFloat));
@ -684,7 +685,7 @@ void testLLVMElemwiseMaxNumFloat() {
assertAllEqual(c_buffer, 41.0f);
}
void testLLVMElemwiseMaxNumNaNFloat() {
void testLLVMElemwiseMaxNaNFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
Buffer a(BufHandle("A", {N}, kFloat));
@ -715,10 +716,12 @@ void testLLVMElemwiseMaxNumNaNFloat() {
ASSERT_EQ(b_buffer.size(), N);
ASSERT_EQ(c_buffer.size(), N);
assertAllEqual(b_buffer, 1.0f);
assertAllEqual(c_buffer, 1.0f);
for (auto const& elt : c_buffer) {
ASSERT_TRUE(std::isnan(elt));
}
}
void testLLVMElemwiseMinNumFloat() {
void testLLVMElemwiseMinFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
Buffer a(BufHandle("A", {N}, kFloat));
@ -753,7 +756,7 @@ void testLLVMElemwiseMinNumFloat() {
assertAllEqual(c_buffer, 1.0f);
}
void testLLVMElemwiseMinNumNaNFloat() {
void testLLVMElemwiseMinNaNFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
Buffer a(BufHandle("A", {N}, kFloat));
@ -784,153 +787,11 @@ void testLLVMElemwiseMinNumNaNFloat() {
ASSERT_EQ(b_buffer.size(), N);
ASSERT_EQ(c_buffer.size(), N);
assertAllEqual(b_buffer, 1.0f);
assertAllEqual(c_buffer, 1.0f);
}
#if 1 // LLVM doesn't currently have implementations for maximum/minimum on x86
void testLLVMElemwiseMaximumFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
Buffer a(BufHandle("A", {N}, kFloat));
Buffer b(BufHandle("B", {N}, kFloat));
Buffer c(BufHandle("C", {N}, kFloat));
std::vector<float> a_buffer(N, 41);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
auto mask = IntImm::make(1);
VarHandle i("i", kInt);
auto expr = For::make(
i,
0,
N,
Store::make(
c,
{i},
Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
mask));
LLVMCodeGen cg(expr, {a, b, c});
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
ASSERT_EQ(a_buffer.size(), N);
ASSERT_EQ(b_buffer.size(), N);
ASSERT_EQ(c_buffer.size(), N);
assertAllEqual(a_buffer, 41.0f);
assertAllEqual(b_buffer, 1.0f);
assertAllEqual(c_buffer, 41.0f);
}
void testLLVMElemwiseMaximumNaNFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
Buffer a(BufHandle("A", {N}, kFloat));
Buffer b(BufHandle("B", {N}, kFloat));
Buffer c(BufHandle("C", {N}, kFloat));
std::vector<float> a_buffer(N, NAN);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
auto mask = IntImm::make(1);
VarHandle i("i", kInt);
auto expr = For::make(
i,
0,
N,
Store::make(
c,
{i},
Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
mask));
LLVMCodeGen cg(expr, {a, b, c});
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
ASSERT_EQ(a_buffer.size(), N);
ASSERT_EQ(b_buffer.size(), N);
ASSERT_EQ(c_buffer.size(), N);
for (int i = 0; i < N; ++i) {
ASSERT_TRUE(std::isnan(a_buffer[i]));
ASSERT_TRUE(std::isnan(c_buffer[i]));
for (auto const& elt : c_buffer) {
ASSERT_TRUE(std::isnan(elt));
}
}
void testLLVMElemwiseMinimumFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
Buffer a(BufHandle("A", {N}, kFloat));
Buffer b(BufHandle("B", {N}, kFloat));
Buffer c(BufHandle("C", {N}, kFloat));
std::vector<float> a_buffer(N, 41);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
auto mask = IntImm::make(1);
VarHandle i("i", kInt);
auto expr = For::make(
i,
0,
N,
Store::make(
c,
{i},
Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
mask));
LLVMCodeGen cg(expr, {a, b, c});
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
ASSERT_EQ(a_buffer.size(), N);
ASSERT_EQ(b_buffer.size(), N);
ASSERT_EQ(c_buffer.size(), N);
assertAllEqual(a_buffer, 41.0f);
assertAllEqual(b_buffer, 1.0f);
assertAllEqual(c_buffer, 1.0f);
}
void testLLVMElemwiseMinimumNaNFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
Buffer a(BufHandle("A", {N}, kFloat));
Buffer b(BufHandle("B", {N}, kFloat));
Buffer c(BufHandle("C", {N}, kFloat));
std::vector<float> a_buffer(N, NAN);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
auto mask = IntImm::make(1);
VarHandle i("i", kInt);
auto expr = For::make(
i,
0,
N,
Store::make(
c,
{i},
Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
mask));
LLVMCodeGen cg(expr, {a, b, c});
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
ASSERT_EQ(a_buffer.size(), N);
ASSERT_EQ(b_buffer.size(), N);
ASSERT_EQ(c_buffer.size(), N);
for (int i = 0; i < N; ++i) {
ASSERT_TRUE(std::isnan(a_buffer[i]));
ASSERT_TRUE(std::isnan(c_buffer[i]));
}
}
#endif
void testLLVMCompareSelectIntEQ() {
KernelScope kernel_scope;
constexpr int N = 1024;

View File

@ -349,10 +349,10 @@ namespace jit {
_(LLVMElemwiseLog10Float) \
_(LLVMElemwiseMaxInt) \
_(LLVMElemwiseMinInt) \
_(LLVMElemwiseMaxNumFloat) \
_(LLVMElemwiseMaxNumNaNFloat) \
_(LLVMElemwiseMinNumFloat) \
_(LLVMElemwiseMinNumNaNFloat) \
_(LLVMElemwiseMaxFloat) \
_(LLVMElemwiseMaxNaNFloat) \
_(LLVMElemwiseMinFloat) \
_(LLVMElemwiseMinNaNFloat) \
_(LLVMCompareSelectIntEQ) \
_(LLVMCompareSelectFloatEQ) \
_(LLVMStoreFloat) \

View File

@ -362,6 +362,28 @@ class TestTEFuser(JitTestCase):
ge = self.checkScript(fn, inputs)
self.assertAllFused(ge.graph_for(*inputs))
def test_minmax(self):
def tmax(a, b):
return torch.max(2 * a, b)
def tmin(a, b):
return torch.min(2 * a, b)
a = torch.randn(4, 4, dtype=torch.float)
b = torch.randn(4, 4, dtype=torch.float)
nan = torch.tensor(float('nan'), dtype=torch.float)
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
for f, inputs, device in product(
(tmax, tmin),
([a, b], [a, nan], [b, nan]),
devices):
inputs = [t.to(device) for t in inputs]
s = self.checkScript(f, inputs)
self.assertAllFused(s.graph_for(*inputs))
# TODO: reenable the test after backwards passes start working in PE
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skip("temporarily disable")

View File

@ -874,11 +874,10 @@ class TestTensorExprFuser(BaseTestClass):
test_lgamma,
test_reciprocal,
test_neg,
# TODO: properly handle NaNs in Max/Min and reenable these tests:
# test_threshold,
# test_relu,
# test_tanh,
# test_sigmoid,
test_threshold,
test_relu,
test_tanh,
test_sigmoid,
}
device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
@ -939,9 +938,9 @@ class TestTensorExprFuser(BaseTestClass):
x = torch.tensor([np.nan])
y = torch.tensor([1.0])
assert not np.isnan(tmin(x, y).item())
assert np.isnan(tmin(x, y).item())
assert np.isnan(tmin(y, x).item())
assert not np.isnan(tmax(x, y).item())
assert np.isnan(tmax(x, y).item())
assert np.isnan(tmax(y, x).item())

View File

@ -388,21 +388,7 @@ void CudaPrinter::visit(const AtomicAdd* v) {
}
void CudaPrinter::visit(const Max* v) {
auto dtype = v->dtype().scalar_type();
switch (dtype) {
case ScalarType::Half:
// doing Half math in float.
case ScalarType::Float:
os() << "fmaxf";
break;
case ScalarType::Double:
os() << "fmax";
break;
default:
os() << "max";
break;
}
os() << "(";
os() << "maximum(";
v->lhs()->accept(this);
os() << ",";
v->rhs()->accept(this);
@ -410,21 +396,7 @@ void CudaPrinter::visit(const Max* v) {
}
void CudaPrinter::visit(const Min* v) {
auto dtype = v->dtype().scalar_type();
switch (dtype) {
case ScalarType::Half:
// doing Half math in float.
case ScalarType::Float:
os() << "fminf";
break;
case ScalarType::Double:
os() << "fmin";
break;
default:
os() << "min";
break;
}
os() << "(";
os() << "minimum(";
v->lhs()->accept(this);
os() << ",";
v->rhs()->accept(this);
@ -831,6 +803,23 @@ static std::ostream& operator<<(
return out;
}
static const char* resource_string = R"(
#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)
template<typename T>
T maximum(T a, T b) {
return isnan(a) ? a : (a > b ? a : b);
}
template<typename T>
T minimum(T a, T b) {
return isnan(a) ? a : (a < b ? a : b);
}
)";
void CudaCodeGen::Initialize() {
// TODO: handle multiple kernels.
// TODO: handle dynamic dimension.
@ -846,9 +835,8 @@ void CudaCodeGen::Initialize() {
std::make_unique<CudaPrinter>(&oss_, cuda_analysis_.get(), has_random_);
metavar_rewriter_ = std::make_unique<GPUMetaVarRewriter>();
os() << "#define NAN __int_as_float(0x7fffffff)\n"
"#define POS_INFINITY __int_as_float(0x7f800000)\n"
"#define NEG_INFINITY __int_as_float(0xff800000)\n";
os() << resource_string;
if (has_random_) {
os() << philox_random_string << std::endl;
}

View File

@ -226,11 +226,35 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
}
template <typename T>
Value binary_op(
const Value& lhs,
const Value& rhs,
IRNodeType op_type,
bool option = false) {
typename std::enable_if_t<std::is_floating_point<T>::value, T> max_value(
T a,
T b) {
return std::isnan(a) ? a : (std::isnan(b) ? b : (a < b ? b : a));
}
template <typename T>
typename std::enable_if_t<!std::is_floating_point<T>::value, T> max_value(
T a,
T b) {
return a < b ? b : a;
}
template <typename T>
typename std::enable_if_t<std::is_floating_point<T>::value, T> min_value(
T a,
T b) {
return std::isnan(a) ? a : (std::isnan(b) ? b : (a < b ? a : b));
}
template <typename T>
typename std::enable_if_t<!std::is_floating_point<T>::value, T> min_value(
T a,
T b) {
return a < b ? a : b;
}
template <typename T>
Value binary_op(const Value& lhs, const Value& rhs, IRNodeType op_type) {
std::vector<T> lhs_v = lhs.as_vec<T>();
std::vector<T> rhs_v = rhs.as_vec<T>();
std::vector<T> result_v(lhs_v.size());
@ -252,30 +276,10 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
result_v[i] = mod_value(lhs_v[i], rhs_v[i]);
break;
case IRNodeType::kMax:
if (option) {
// Propagate NaNs
if (is_floating_point(lhs.dtype().scalar_type()) &&
is_floating_point(rhs.dtype().scalar_type())) {
result_v[i] = lhs_v[i];
} else if (std::isnan((float)rhs_v[i])) {
result_v[i] = rhs_v[i];
}
} else {
result_v[i] = lhs_v[i] > rhs_v[i] ? lhs_v[i] : rhs_v[i];
}
result_v[i] = max_value(lhs_v[i], rhs_v[i]);
break;
case IRNodeType::kMin:
if (option) {
// Propagate NaNs
if (is_floating_point(lhs.dtype().scalar_type()) &&
is_floating_point(rhs.dtype().scalar_type())) {
result_v[i] = lhs_v[i];
} else if (std::isnan((float)rhs_v[i])) {
result_v[i] = rhs_v[i];
}
} else {
result_v[i] = lhs_v[i] < rhs_v[i] ? lhs_v[i] : rhs_v[i];
}
result_v[i] = min_value(lhs_v[i], rhs_v[i]);
break;
default:
// TODO: change to a proper error report

View File

@ -853,7 +853,8 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
case aten::relu: {
return computeOneOperand("aten_relu", v, [](const ExprHandle& a) {
return Max::make(a, 0, false);
auto zero = Cast::make(a.dtype(), 0);
return ifThenElse(CompareSelect::make(a, zero, kLT), zero, a);
});
} break;
@ -1081,7 +1082,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
[](const ExprHandle& a,
const ExprHandle& threshold,
const ExprHandle& value) {
return ifThenElse(CompareSelect::make(a, threshold, kGT), a, value);
return ifThenElse(CompareSelect::make(a, threshold, kLE), value, a);
});
} break;

View File

@ -623,13 +623,14 @@ void LLVMCodeGenImpl::visit(const Max* v) {
return;
}
if (v->propagate_nans()) {
value_ = irb_.CreateBinaryIntrinsic(llvm::Intrinsic::maximum, lhs, rhs);
return;
}
value_ = irb_.CreateSelect(
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs);
irb_.CreateFCmp(
llvm::FCmpInst::FCMP_UNO,
lhs,
llvm::ConstantFP::get(lhs->getType(), 0.0)),
lhs,
irb_.CreateSelect(
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs));
}
void LLVMCodeGenImpl::visit(const Min* v) {
@ -644,13 +645,14 @@ void LLVMCodeGenImpl::visit(const Min* v) {
return;
}
if (v->propagate_nans()) {
value_ = irb_.CreateBinaryIntrinsic(llvm::Intrinsic::minimum, lhs, rhs);
return;
}
value_ = irb_.CreateSelect(
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs);
irb_.CreateFCmp(
llvm::FCmpInst::FCMP_UNO,
lhs,
llvm::ConstantFP::get(lhs->getType(), 0.0)),
lhs,
irb_.CreateSelect(
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs));
}
void LLVMCodeGenImpl::visit(const CompareSelect* v) {