Fix NaN propagation in TE fuser's min/max implementation (#43609)
Summary: Per eager mode source-of-truth, NaNs shall be propagated by min/max.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43609

Reviewed By: ZolotukhinM
Differential Revision: D23349184
Pulled By: bertmaher
fbshipit-source-id: 094eb8b89a02b27d5ecf3988d0f473c0f91e4afb
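For context, the eager-mode behavior treated here as the source of truth can be reproduced in a few lines of Python; this is a minimal sketch with illustrative tensor values (mirroring the updated tests below), not part of the change itself:

    # If either operand of an elementwise min/max is NaN, eager mode returns NaN.
    import torch

    x = torch.tensor([float('nan')])
    y = torch.tensor([1.0])

    print(torch.max(x, y))  # tensor([nan])
    print(torch.min(y, x))  # tensor([nan])

The fuser's generated code must produce the same NaN-propagating result as this eager reference.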
This commit is contained in:
parent 820c4b05a9
commit c14a3613a8
@@ -13,6 +13,7 @@
 #include "torch/csrc/jit/tensorexpr/loopnest.h"
 #include "torch/csrc/jit/tensorexpr/tensor.h"
 
+#include <cmath>
 #include <numeric>
 
 namespace torch {
@@ -649,7 +650,7 @@ void testLLVMElemwiseMinInt() {
   assertAllEqual(c_buffer, 1);
 }
 
-void testLLVMElemwiseMaxNumFloat() {
+void testLLVMElemwiseMaxFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
   Buffer a(BufHandle("A", {N}, kFloat));
@@ -684,7 +685,7 @@ void testLLVMElemwiseMaxNumFloat() {
   assertAllEqual(c_buffer, 41.0f);
 }
 
-void testLLVMElemwiseMaxNumNaNFloat() {
+void testLLVMElemwiseMaxNaNFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
   Buffer a(BufHandle("A", {N}, kFloat));
@@ -715,10 +716,12 @@ void testLLVMElemwiseMaxNumNaNFloat() {
   ASSERT_EQ(b_buffer.size(), N);
   ASSERT_EQ(c_buffer.size(), N);
   assertAllEqual(b_buffer, 1.0f);
-  assertAllEqual(c_buffer, 1.0f);
+  for (auto const& elt : c_buffer) {
+    ASSERT_TRUE(std::isnan(elt));
+  }
 }
 
-void testLLVMElemwiseMinNumFloat() {
+void testLLVMElemwiseMinFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
   Buffer a(BufHandle("A", {N}, kFloat));
@@ -753,7 +756,7 @@ void testLLVMElemwiseMinNumFloat() {
   assertAllEqual(c_buffer, 1.0f);
 }
 
-void testLLVMElemwiseMinNumNaNFloat() {
+void testLLVMElemwiseMinNaNFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
   Buffer a(BufHandle("A", {N}, kFloat));
@@ -784,153 +787,11 @@ void testLLVMElemwiseMinNumNaNFloat() {
   ASSERT_EQ(b_buffer.size(), N);
   ASSERT_EQ(c_buffer.size(), N);
   assertAllEqual(b_buffer, 1.0f);
-  assertAllEqual(c_buffer, 1.0f);
-}
-
-#if 1 // LLVM doesn't currently have implementations for maximum/minimum on x86
-void testLLVMElemwiseMaximumFloat() {
-  KernelScope kernel_scope;
-  constexpr int N = 1024;
-  Buffer a(BufHandle("A", {N}, kFloat));
-  Buffer b(BufHandle("B", {N}, kFloat));
-  Buffer c(BufHandle("C", {N}, kFloat));
-  std::vector<float> a_buffer(N, 41);
-  std::vector<float> b_buffer(N, 1);
-  std::vector<float> c_buffer(N, 1);
-
-  auto mask = IntImm::make(1);
-  VarHandle i("i", kInt);
-  auto expr = For::make(
-      i,
-      0,
-      N,
-      Store::make(
-          c,
-          {i},
-          Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
-          mask));
-
-  LLVMCodeGen cg(expr, {a, b, c});
-
-  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
-  ASSERT_EQ(cg.value<int>(args), 0);
-
-  ASSERT_EQ(a_buffer.size(), N);
-  ASSERT_EQ(b_buffer.size(), N);
-  ASSERT_EQ(c_buffer.size(), N);
-  assertAllEqual(a_buffer, 41.0f);
-  assertAllEqual(b_buffer, 1.0f);
-  assertAllEqual(c_buffer, 41.0f);
-}
-
-void testLLVMElemwiseMaximumNaNFloat() {
-  KernelScope kernel_scope;
-  constexpr int N = 1024;
-  Buffer a(BufHandle("A", {N}, kFloat));
-  Buffer b(BufHandle("B", {N}, kFloat));
-  Buffer c(BufHandle("C", {N}, kFloat));
-  std::vector<float> a_buffer(N, NAN);
-  std::vector<float> b_buffer(N, 1);
-  std::vector<float> c_buffer(N, 1);
-
-  auto mask = IntImm::make(1);
-  VarHandle i("i", kInt);
-  auto expr = For::make(
-      i,
-      0,
-      N,
-      Store::make(
-          c,
-          {i},
-          Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
-          mask));
-
-  LLVMCodeGen cg(expr, {a, b, c});
-
-  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
-  ASSERT_EQ(cg.value<int>(args), 0);
-
-  ASSERT_EQ(a_buffer.size(), N);
-  ASSERT_EQ(b_buffer.size(), N);
-  ASSERT_EQ(c_buffer.size(), N);
-  for (int i = 0; i < N; ++i) {
-    ASSERT_TRUE(std::isnan(a_buffer[i]));
-    ASSERT_TRUE(std::isnan(c_buffer[i]));
+  for (auto const& elt : c_buffer) {
+    ASSERT_TRUE(std::isnan(elt));
   }
 }
-
-void testLLVMElemwiseMinimumFloat() {
-  KernelScope kernel_scope;
-  constexpr int N = 1024;
-  Buffer a(BufHandle("A", {N}, kFloat));
-  Buffer b(BufHandle("B", {N}, kFloat));
-  Buffer c(BufHandle("C", {N}, kFloat));
-  std::vector<float> a_buffer(N, 41);
-  std::vector<float> b_buffer(N, 1);
-  std::vector<float> c_buffer(N, 1);
-
-  auto mask = IntImm::make(1);
-  VarHandle i("i", kInt);
-  auto expr = For::make(
-      i,
-      0,
-      N,
-      Store::make(
-          c,
-          {i},
-          Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
-          mask));
-
-  LLVMCodeGen cg(expr, {a, b, c});
-
-  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
-  ASSERT_EQ(cg.value<int>(args), 0);
-
-  ASSERT_EQ(a_buffer.size(), N);
-  ASSERT_EQ(b_buffer.size(), N);
-  ASSERT_EQ(c_buffer.size(), N);
-  assertAllEqual(a_buffer, 41.0f);
-  assertAllEqual(b_buffer, 1.0f);
-  assertAllEqual(c_buffer, 1.0f);
-}
-
-void testLLVMElemwiseMinimumNaNFloat() {
-  KernelScope kernel_scope;
-  constexpr int N = 1024;
-  Buffer a(BufHandle("A", {N}, kFloat));
-  Buffer b(BufHandle("B", {N}, kFloat));
-  Buffer c(BufHandle("C", {N}, kFloat));
-  std::vector<float> a_buffer(N, NAN);
-  std::vector<float> b_buffer(N, 1);
-  std::vector<float> c_buffer(N, 1);
-
-  auto mask = IntImm::make(1);
-  VarHandle i("i", kInt);
-  auto expr = For::make(
-      i,
-      0,
-      N,
-      Store::make(
-          c,
-          {i},
-          Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
-          mask));
-
-  LLVMCodeGen cg(expr, {a, b, c});
-
-  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
-  ASSERT_EQ(cg.value<int>(args), 0);
-
-  ASSERT_EQ(a_buffer.size(), N);
-  ASSERT_EQ(b_buffer.size(), N);
-  ASSERT_EQ(c_buffer.size(), N);
-  for (int i = 0; i < N; ++i) {
-    ASSERT_TRUE(std::isnan(a_buffer[i]));
-    ASSERT_TRUE(std::isnan(c_buffer[i]));
-  }
-}
-#endif
 
 void testLLVMCompareSelectIntEQ() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
@@ -349,10 +349,10 @@ namespace jit {
   _(LLVMElemwiseLog10Float)     \
   _(LLVMElemwiseMaxInt)         \
   _(LLVMElemwiseMinInt)         \
-  _(LLVMElemwiseMaxNumFloat)    \
-  _(LLVMElemwiseMaxNumNaNFloat) \
-  _(LLVMElemwiseMinNumFloat)    \
-  _(LLVMElemwiseMinNumNaNFloat) \
+  _(LLVMElemwiseMaxFloat)       \
+  _(LLVMElemwiseMaxNaNFloat)    \
+  _(LLVMElemwiseMinFloat)       \
+  _(LLVMElemwiseMinNaNFloat)    \
   _(LLVMCompareSelectIntEQ)     \
   _(LLVMCompareSelectFloatEQ)   \
   _(LLVMStoreFloat)             \
@@ -362,6 +362,28 @@ class TestTEFuser(JitTestCase):
         ge = self.checkScript(fn, inputs)
         self.assertAllFused(ge.graph_for(*inputs))
 
+    def test_minmax(self):
+        def tmax(a, b):
+            return torch.max(2 * a, b)
+
+        def tmin(a, b):
+            return torch.min(2 * a, b)
+
+        a = torch.randn(4, 4, dtype=torch.float)
+        b = torch.randn(4, 4, dtype=torch.float)
+        nan = torch.tensor(float('nan'), dtype=torch.float)
+
+        devices = ["cpu"]
+        if torch.cuda.is_available():
+            devices.append("cuda")
+        for f, inputs, device in product(
+                (tmax, tmin),
+                ([a, b], [a, nan], [b, nan]),
+                devices):
+            inputs = [t.to(device) for t in inputs]
+            s = self.checkScript(f, inputs)
+            self.assertAllFused(s.graph_for(*inputs))
+
     # TODO: reenable the test after backwards passes start working in PE
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skip("temporarily disable")
@@ -874,11 +874,10 @@ class TestTensorExprFuser(BaseTestClass):
             test_lgamma,
             test_reciprocal,
             test_neg,
-            # TODO: properly handle NaNs in Max/Min and reenable these tests:
-            # test_threshold,
-            # test_relu,
-            # test_tanh,
-            # test_sigmoid,
+            test_threshold,
+            test_relu,
+            test_tanh,
+            test_sigmoid,
         }
         device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
 
@@ -939,9 +938,9 @@ class TestTensorExprFuser(BaseTestClass):
        x = torch.tensor([np.nan])
        y = torch.tensor([1.0])

-       assert not np.isnan(tmin(x, y).item())
+       assert np.isnan(tmin(x, y).item())
        assert np.isnan(tmin(y, x).item())
-       assert not np.isnan(tmax(x, y).item())
+       assert np.isnan(tmax(x, y).item())
        assert np.isnan(tmax(y, x).item())


@@ -388,21 +388,7 @@ void CudaPrinter::visit(const AtomicAdd* v) {
 }
 
 void CudaPrinter::visit(const Max* v) {
-  auto dtype = v->dtype().scalar_type();
-  switch (dtype) {
-    case ScalarType::Half:
-      // doing Half math in float.
-    case ScalarType::Float:
-      os() << "fmaxf";
-      break;
-    case ScalarType::Double:
-      os() << "fmax";
-      break;
-    default:
-      os() << "max";
-      break;
-  }
-  os() << "(";
+  os() << "maximum(";
   v->lhs()->accept(this);
   os() << ",";
   v->rhs()->accept(this);
@@ -410,21 +396,7 @@ void CudaPrinter::visit(const Max* v) {
 }
 
 void CudaPrinter::visit(const Min* v) {
-  auto dtype = v->dtype().scalar_type();
-  switch (dtype) {
-    case ScalarType::Half:
-      // doing Half math in float.
-    case ScalarType::Float:
-      os() << "fminf";
-      break;
-    case ScalarType::Double:
-      os() << "fmin";
-      break;
-    default:
-      os() << "min";
-      break;
-  }
-  os() << "(";
+  os() << "minimum(";
   v->lhs()->accept(this);
   os() << ",";
   v->rhs()->accept(this);
@@ -831,6 +803,23 @@ static std::ostream& operator<<(
   return out;
 }
 
+static const char* resource_string = R"(
+#define NAN __int_as_float(0x7fffffff)
+#define POS_INFINITY __int_as_float(0x7f800000)
+#define NEG_INFINITY __int_as_float(0xff800000)
+
+template<typename T>
+T maximum(T a, T b) {
+  return isnan(a) ? a : (a > b ? a : b);
+}
+
+template<typename T>
+T minimum(T a, T b) {
+  return isnan(a) ? a : (a < b ? a : b);
+}
+
+)";
+
 void CudaCodeGen::Initialize() {
   // TODO: handle multiple kernels.
   // TODO: handle dynamic dimension.
@@ -846,9 +835,8 @@ void CudaCodeGen::Initialize() {
       std::make_unique<CudaPrinter>(&oss_, cuda_analysis_.get(), has_random_);
   metavar_rewriter_ = std::make_unique<GPUMetaVarRewriter>();
 
-  os() << "#define NAN __int_as_float(0x7fffffff)\n"
-          "#define POS_INFINITY __int_as_float(0x7f800000)\n"
-          "#define NEG_INFINITY __int_as_float(0xff800000)\n";
+  os() << resource_string;
+
   if (has_random_) {
     os() << philox_random_string << std::endl;
   }
@@ -226,11 +226,35 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
   }
 
   template <typename T>
-  Value binary_op(
-      const Value& lhs,
-      const Value& rhs,
-      IRNodeType op_type,
-      bool option = false) {
+  typename std::enable_if_t<std::is_floating_point<T>::value, T> max_value(
+      T a,
+      T b) {
+    return std::isnan(a) ? a : (std::isnan(b) ? b : (a < b ? b : a));
+  }
+
+  template <typename T>
+  typename std::enable_if_t<!std::is_floating_point<T>::value, T> max_value(
+      T a,
+      T b) {
+    return a < b ? b : a;
+  }
+
+  template <typename T>
+  typename std::enable_if_t<std::is_floating_point<T>::value, T> min_value(
+      T a,
+      T b) {
+    return std::isnan(a) ? a : (std::isnan(b) ? b : (a < b ? a : b));
+  }
+
+  template <typename T>
+  typename std::enable_if_t<!std::is_floating_point<T>::value, T> min_value(
+      T a,
+      T b) {
+    return a < b ? a : b;
+  }
+
+  template <typename T>
+  Value binary_op(const Value& lhs, const Value& rhs, IRNodeType op_type) {
     std::vector<T> lhs_v = lhs.as_vec<T>();
     std::vector<T> rhs_v = rhs.as_vec<T>();
     std::vector<T> result_v(lhs_v.size());
@@ -252,30 +276,10 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
         result_v[i] = mod_value(lhs_v[i], rhs_v[i]);
         break;
       case IRNodeType::kMax:
-        if (option) {
-          // Propagate NaNs
-          if (is_floating_point(lhs.dtype().scalar_type()) &&
-              is_floating_point(rhs.dtype().scalar_type())) {
-            result_v[i] = lhs_v[i];
-          } else if (std::isnan((float)rhs_v[i])) {
-            result_v[i] = rhs_v[i];
-          }
-        } else {
-          result_v[i] = lhs_v[i] > rhs_v[i] ? lhs_v[i] : rhs_v[i];
-        }
+        result_v[i] = max_value(lhs_v[i], rhs_v[i]);
         break;
       case IRNodeType::kMin:
-        if (option) {
-          // Propagate NaNs
-          if (is_floating_point(lhs.dtype().scalar_type()) &&
-              is_floating_point(rhs.dtype().scalar_type())) {
-            result_v[i] = lhs_v[i];
-          } else if (std::isnan((float)rhs_v[i])) {
-            result_v[i] = rhs_v[i];
-          }
-        } else {
-          result_v[i] = lhs_v[i] < rhs_v[i] ? lhs_v[i] : rhs_v[i];
-        }
+        result_v[i] = min_value(lhs_v[i], rhs_v[i]);
         break;
       default:
         // TODO: change to a proper error report
@@ -853,7 +853,8 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
 
     case aten::relu: {
       return computeOneOperand("aten_relu", v, [](const ExprHandle& a) {
-        return Max::make(a, 0, false);
+        auto zero = Cast::make(a.dtype(), 0);
+        return ifThenElse(CompareSelect::make(a, zero, kLT), zero, a);
       });
     } break;
 
@@ -1081,7 +1082,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
           [](const ExprHandle& a,
              const ExprHandle& threshold,
              const ExprHandle& value) {
-            return ifThenElse(CompareSelect::make(a, threshold, kGT), a, value);
+            return ifThenElse(CompareSelect::make(a, threshold, kLE), value, a);
           });
     } break;
 
@@ -623,13 +623,14 @@ void LLVMCodeGenImpl::visit(const Max* v) {
     return;
   }
 
-  if (v->propagate_nans()) {
-    value_ = irb_.CreateBinaryIntrinsic(llvm::Intrinsic::maximum, lhs, rhs);
-    return;
-  }
-
   value_ = irb_.CreateSelect(
-      irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs);
+      irb_.CreateFCmp(
+          llvm::FCmpInst::FCMP_UNO,
+          lhs,
+          llvm::ConstantFP::get(lhs->getType(), 0.0)),
+      lhs,
+      irb_.CreateSelect(
+          irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs));
 }
 
 void LLVMCodeGenImpl::visit(const Min* v) {
@@ -644,13 +645,14 @@ void LLVMCodeGenImpl::visit(const Min* v) {
     return;
   }
 
-  if (v->propagate_nans()) {
-    value_ = irb_.CreateBinaryIntrinsic(llvm::Intrinsic::minimum, lhs, rhs);
-    return;
-  }
-
   value_ = irb_.CreateSelect(
-      irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs);
+      irb_.CreateFCmp(
+          llvm::FCmpInst::FCMP_UNO,
+          lhs,
+          llvm::ConstantFP::get(lhs->getType(), 0.0)),
+      lhs,
+      irb_.CreateSelect(
+          irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs));
 }
 
 void LLVMCodeGenImpl::visit(const CompareSelect* v) {
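For reference, the fcmp-uno/select sequence emitted above (and the maximum/minimum helpers in the CUDA resource string) reduces to the following scalar logic; this is a minimal Python sketch of that behavior, not code from the diff:

    import math

    # If the left operand is NaN (unordered compare with itself/zero), return it;
    # otherwise fall back to the ordinary ordered comparison.
    def propagating_max(a, b):
        return a if math.isnan(a) else (a if a > b else b)

    def propagating_min(a, b):
        return a if math.isnan(a) else (a if a < b else b)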