Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63195

This helps us to later switch from using KernelArena with raw pointers to shared pointers without having to change all our source files at once. The changes are mechanical and should not affect any functionality.

With this PR, we're changing the following:
* `Add*` --> `AddPtr`
* `new Add(...)` --> `alloc<Add>(...)`
* `dynamic_cast<Add*>` --> `to<Add>`
* `static_cast<Add*>` --> `static_to<Add>`

Due to some complications with args forwarding, some places became more verbose, e.g.:
* `new Block({})` --> `new Block(std::vector<ExprPtr>())`

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D30292779

Pulled By: ZolotukhinM

fbshipit-source-id: 150301c7d2df56b608b035827b6a9a87f5e2d9e9
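For context on the renaming scheme above, here is a minimal sketch of the kind of indirection layer the summary describes. This is not the PR's code: the real aliases and helpers live in `torch/csrc/jit/tensorexpr/` and their exact definitions may differ; the sketch only assumes that the `*Ptr` aliases start out as raw pointers and that `alloc`, `to`, and `static_to` are thin wrappers over `new`, `dynamic_cast`, and `static_cast`.

```cpp
// Hypothetical sketch only -- the real aliases and helpers are defined in
// torch/csrc/jit/tensorexpr/ and may differ in signature and detail.
#include <utility>

class Expr {
 public:
  virtual ~Expr() = default;
};
class Add : public Expr {};

// For now the aliases are raw pointers; later they can be switched to
// shared-pointer-based aliases without touching call sites.
using ExprPtr = Expr*;
using AddPtr = Add*;

// `new Add(...)` --> `alloc<Add>(...)`
template <class T, class... Args>
T* alloc(Args&&... args) {
  return new T(std::forward<Args>(args)...);
}

// `dynamic_cast<Add*>(e)` --> `to<Add>(e)`
template <class T>
T* to(ExprPtr e) {
  return dynamic_cast<T*>(e);
}

// `static_cast<Add*>(e)` --> `static_to<Add>(e)`
template <class T>
T* static_to(ExprPtr e) {
  return static_cast<T*>(e);
}

// A call site after the mechanical rewrite:
//   ExprPtr e = alloc<Add>();
//   if (AddPtr a = to<Add>(e)) { /* use a */ }
```

With every call site routed through the aliases and helpers, the later move from KernelArena-owned raw pointers to shared pointers becomes a localized change to these few definitions, which is the migration path the summary is aiming for.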
#ifdef TORCH_ENABLE_LLVM

#include <gtest/gtest.h>

#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

#include <cmath>
#include <cstring>
#include <limits>
#include <sstream>

using namespace torch::indexing;
namespace te = torch::jit::tensorexpr;

// Split the outermost loop of `target` by `width` (with a tail loop) and
// vectorize the resulting inner loop; the vectorization must succeed.
static void vectorize(te::LoopNest* ln, te::Tensor* target, int width) {
  auto loops = ln->getLoopStmtsFor(target);
  te::ForPtr inner, tail;
  ln->splitWithTail(loops[0], width, &inner, &tail);
  ASSERT_TRUE(te::LoopNest::vectorize(inner));
}

// Summarize how two tensors differ: the number of unequal elements and the
// location and magnitude of the largest absolute difference.
std::string diffs(const at::Tensor& a, const at::Tensor& b) {
  auto diff = torch::abs(a.flatten() - b.flatten());
  auto count_diffs = torch::sum(diff > 0.f);
  auto greatest_diff_index = torch::argmax(diff);
  std::stringstream ss;
  ss << "Found " << count_diffs << " unequal element(s). "
     << "The greatest difference was " << diff.index({greatest_diff_index})
     << " at index " << greatest_diff_index;
  return ss.str();
}

TEST(Approx, log_vml) {
  te::KernelScope ks;
  te::VarHandle N("N", te::kInt);
  te::Placeholder A("A", te::kFloat, {N});
  te::Tensor* B = te::Compute(
      "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); });

  // Lower B: prepare the loop nest for codegen, vectorize by 8, simplify,
  // and JIT-compile the kernel with the LLVM backend.
  te::LoopNest ln({B});
  ln.prepareForCodegen();
  vectorize(&ln, B, 8);
  te::StmtPtr s = ln.root_stmt();
  s = te::IRSimplifier::simplify(s);
  te::LLVMCodeGen cg(s, {A, B, N});

  auto eps = std::numeric_limits<float>::epsilon();
  // Run the compiled kernel on A_t and compare elementwise against at::log.
  auto test = [&](const at::Tensor& A_t) {
    at::Tensor B_ref = at::log(A_t);
    at::Tensor B_t = at::empty_like(A_t);
    auto ap = A_t.data_ptr<float>();
    auto bp = B_t.data_ptr<float>();
    cg.call({ap, bp, A_t.numel()});
    // Results should be bit-identical.
    ASSERT_TRUE(torch::allclose(
        B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true))
        << "Input[:8]\n"
        << A_t.index({Slice(0, 8)}) << "\n"
        << "Test[:8]\n"
        << B_t.index({Slice(0, 8)}) << "\n"
        << "Ref[:8]\n"
        << B_ref.index({Slice(0, 8)}) << diffs(B_t, B_ref);
  };

  // Generate every single-precision FP value in [1.0, 2.0).
  at::Tensor A_t = torch::arange(1.0f, 2.0f, eps);
  ASSERT_EQ(A_t.numel(), 1 << 23);

  test(A_t);

  // Rescale the sweep by powers of two to cover a wide range of exponents.
  test(A_t * 2.0f);
  test(A_t * 0.5f);

  test(A_t * 4.0f);
  test(A_t * 0.25f);

  test(A_t * powf(2.0f, 16));
  test(A_t * powf(2.0f, -16));

  test(A_t * powf(2.0f, 126));
  test(A_t * powf(2.0f, -126));

  // Special values should propagate through the approximation.
  test(torch::full({32}, INFINITY));
  test(torch::full({32}, NAN));

  auto min = std::numeric_limits<float>::min();
  auto denorm_min = std::numeric_limits<float>::denorm_min();

  // Denormals aren't bit precise, because sleef isn't bit-precise either.
  A_t = torch::arange(0.0f, min, denorm_min);
  ASSERT_EQ(A_t.numel(), 1 << 23);
  auto B_ref = at::log(A_t);
  auto B_t = at::empty_like(B_ref);
  cg.call({A_t.data_ptr<float>(), B_t.data_ptr<float>(), A_t.numel()});
  ASSERT_TRUE(torch::allclose(B_t, B_ref));
}

#endif // TORCH_ENABLE_LLVM