Enable CPU fused kernel on Windows

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25578

Differential Revision: D17397156

Pulled By: ezyang

fbshipit-source-id: b243528c2bfd5a0d401897833048429e67fe40ef
Authored by peter on 2019-09-17 07:27:39 -07:00; committed by Facebook Github Bot
parent bebc3d6aad
commit 2ce8c83f67
8 changed files with 408 additions and 36 deletions

View File

@@ -479,11 +479,8 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
       ${TORCH_SRC_DIR}/csrc/jit/export.cpp
       ${TORCH_SRC_DIR}/csrc/jit/import_legacy.cpp
       ${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp
     )
-    if (NOT WIN32)
-      list(APPEND TORCH_SRCS
-        ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp)
-    endif()
   endif()
   if (USE_CUDA)

View File

@@ -386,6 +386,8 @@ def get_selected_tests(options):
         target_arch = os.environ.get('VSCMD_ARG_TGT_ARCH')
         if target_arch != 'x64':
             WINDOWS_BLACKLIST.append('cpp_extensions')
+            WINDOWS_BLACKLIST.append('jit')
+            WINDOWS_BLACKLIST.append('jit_fuser')

     selected_tests = exclude_tests(WINDOWS_BLACKLIST, selected_tests, 'on Windows')

View File

@@ -86,7 +86,6 @@ if torch.cuda.is_available():
     RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
 PY35 = sys.version_info >= (3, 5)
-WINDOWS = sys.platform == 'win32'

 def LSTMCellF(input, hx, cx, *params):
@@ -2023,7 +2022,6 @@ graph(%Ra, %Rb):
         for node in g.nodes():
             self.assertTrue(g2.findNode(node.kind()) is not None)

-    @unittest.skipIf(IS_WINDOWS, "NYI: JIT tests not yet supported on windows")
     @unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle")
     @unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda")
     @skipIfRocm
@@ -2033,7 +2031,6 @@ graph(%Ra, %Rb):
         torch._C._jit_run_cpp_tests(run_cuda=False)
         tests_setup.shutdown()

-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
     @skipIfRocm
     def test_cpp_cuda(self):
@@ -2058,7 +2055,6 @@ graph(%Ra, %Rb):
         self.assertEqual(outputs, m(*inputs))

     @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA")
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     def test_dropout_cuda(self):
         # Dropout AD is dispatched to _fused_dropout in CUDA case,
         # which is not included in TestJitGeneratedFunctional
@@ -2226,12 +2222,11 @@ graph(%Ra, %Rb):
     def test_ge_unoptimized(self):
         self.run_ge_tests(False, False)

-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
     @enable_cpu_fuser
     def test_ge_optimized(self):
         self.run_ge_tests(True, False)

-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     def test_ge_cuda(self):
         self.run_ge_tests(True, True)
@@ -2322,7 +2317,6 @@ graph(%Ra, %Rb):
         self.assertGraphContains(fn.graph, kind='aten::einsum')
         self.assertEqual(fn(x, y), outer(x, y))

-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
     def test_traced_module_cuda(self):
         class Model(nn.Module):
@@ -5062,7 +5056,7 @@ a")
         self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)

     @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
     @enable_cpu_fuser
     def test_batchnorm_fuser_cpu(self):
         code = '''
@@ -5091,7 +5085,7 @@ a")
         FileCheck().check('sqrtf').run(code)

     @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
     @enable_cpu_fuser
     def test_fuser_double_float_codegen(self):
         fns = ['log', 'log10', 'log1p', 'log2', 'lgamma', 'exp', 'expm1', 'erf',
@@ -5145,7 +5139,7 @@ a")
             test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float, binary=True)

     @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
     @enable_cpu_fuser
     def test_fuser_double_literal_precision(self):
         code = '''
@@ -9285,7 +9279,6 @@ a")
         m = M()
         self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))

-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     def test_trace_of_script(self):
         @torch.jit.script
         def foo(a, c):
@@ -16975,7 +16968,7 @@ def add_autograd_test(
         # We enable the CPU fuser during these checks for more consistent
         # behavior. Otherwise, we are going to have to analyze the graph to
         # see if producer values are Dimension
-        @enable_cpu_fuser_if(not (IS_SANDCASTLE or IS_WINDOWS))
+        @enable_cpu_fuser_if(not IS_SANDCASTLE)
         def check(name):
             set_rng_seed(2)
             is_magic_method = name[:2] == '__' and name[-2:] == '__'
@@ -17010,8 +17003,7 @@ def add_autograd_test(
                 check_against_reference(self, traced_fn,
                                         fn, (self_variable,) + args_variable, kwargs_variable,
                                         check_types=check_types)
-                # Fuser not supported on windows
-                if IS_SANDCASTLE or IS_WINDOWS:
+                if IS_SANDCASTLE:
                     autodiff_nodes = autodiff_nodes + fusible_nodes
                     fusible_nodes = []
                 self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
@@ -17022,8 +17014,7 @@ def add_autograd_test(
                                         fn, (self_variable,) + args_variable, kwargs_variable,
                                         check_types=check_types)
-                # Fuser not supported on windows
-                if IS_SANDCASTLE or IS_WINDOWS:
+                if IS_SANDCASTLE:
                     autodiff_nodes = autodiff_nodes + fusible_nodes
                     fusible_nodes = []
                 self.assertAutodiffNode(script_fn.last_graph,

View File

@@ -9,7 +9,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.testing import FileCheck
-from common_utils import run_tests, IS_WINDOWS, IS_SANDCASTLE
+from common_utils import run_tests, IS_SANDCASTLE
 from textwrap import dedent
 from itertools import product, permutations
@@ -37,7 +37,7 @@ class TestFuser(JitTestCase):
         self.assertEqual(func(a), a.abs() * 2)
         self.assertAllFused(func.graph_for(a))

-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @enable_cpu_fuser
     def test_abs_cpu(self):
         self._test_fused_abs()
@@ -180,7 +180,7 @@ class TestFuser(JitTestCase):
         for fn in fns:
             self.checkScript(fn, [tensor])

-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @enable_cpu_fuser
     def test_chunk_correctness(self):
         return self._test_chunk_correctness(self, 'cpu')
@@ -502,7 +502,7 @@ class TestFuser(JitTestCase):
         self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes",
                                                                   "aten::_size_if_not_equal"))

-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @enable_cpu_fuser
     def test_fuser_deduplication(self):
         # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation
@@ -522,7 +522,7 @@ class TestFuser(JitTestCase):
         # check that a, b share storage, i.e. were generated as a single output in the fuser
         self.assertEqual(ga.data_ptr(), gb.data_ptr())

-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @enable_cpu_fuser
     @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
     def test_fuser_iou(self):
@@ -683,7 +683,7 @@ class TestFuser(JitTestCase):
             .check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \
             .check_next("return").check_not("FusionGroup_1").run(str(graph))

-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
     @enable_cpu_fuser
     def test_lstm_traced_cpu(self):
@@ -782,7 +782,7 @@ class TestFuser(JitTestCase):
         out = script_f(x, y)
         self.assertEqual(out[0], out[1])

-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @enable_cpu_fuser
     def test_scalar(self):
         def fn(x, y):
@@ -828,7 +828,7 @@ class TestFuser(JitTestCase):
         self.assertGraphContainsExactly(
             ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)

-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @enable_cpu_fuser
     def test_where_and_typing(self):
         def f(x, y):

View File

@@ -1,10 +1,15 @@
 #include <torch/csrc/jit/fuser/cpu/fused_kernel.h>
+#include <c10/util/Exception.h>
+#include <c10/util/Optional.h>
 #include <torch/csrc/jit/code_template.h>
 #include <torch/csrc/jit/fuser/compiler.h>
 #include <torch/csrc/jit/fuser/cpu/temp_file.h>
 #include <torch/csrc/utils/memory.h>
+#ifdef _MSC_VER
+#include <torch/csrc/jit/fuser/cpu/msvc_arch.h>
+#endif
 #include <cstdlib>
 #include <iostream>
 #include <sstream>
@@ -16,9 +21,32 @@ namespace jit {
 namespace fuser {
 namespace cpu {

+#ifdef _MSC_VER
+static const std::string getTempPath() {
+  char lpTempPathBuffer[MAX_PATH];
+  DWORD dwRetVal = GetTempPath(
+      MAX_PATH, // length of the buffer
+      lpTempPathBuffer); // buffer for path
+  TORCH_CHECK(dwRetVal < MAX_PATH && dwRetVal != 0, "GetTempPath failed.");
+  return std::string(lpTempPathBuffer);
+}
+static const std::string temp_dir = getTempPath();
+static const std::string so_template = temp_dir + "pytorch_fuserXXXXXX.dll";
+static const std::string cpp_template = temp_dir + "pytorch_fuserXXXXXX.cpp";
+static const std::string check_exists_string = "where \"${program}\" > nul 2> nul";
+static std::vector<std::string> env_list;
+constexpr int so_suffix_len = 4;
+constexpr int cpp_suffix_len = 4;
+#else
 static const std::string so_template = "/tmp/pytorch_fuserXXXXXX.so";
 static const std::string cpp_template = "/tmp/pytorch_fuserXXXXXX.cpp";
 static const std::string check_exists_string = "which '${program}' > /dev/null";
+constexpr int so_suffix_len = 3;
+constexpr int cpp_suffix_len = 4;
+#endif

 static bool programExists(const std::string& program) {
   TemplateEnv env;
@@ -27,6 +55,117 @@ static bool programExists(const std::string& program) {
   return (system(cmd.c_str()) == 0);
 }

+#ifdef _MSC_VER
+c10::optional<std::string> exec(const std::string& cmd) {
+  std::array<char, 128> buffer;
+  std::string result;
+  std::unique_ptr<FILE, decltype(&_pclose)> pipe(
+      _popen(cmd.c_str(), "r"), _pclose);
+  if (!pipe) {
+    return c10::nullopt;
+  }
+  while (fgets(buffer.data(), static_cast<int>(buffer.size()), pipe.get()) != nullptr) {
+    result += buffer.data();
+  }
+  return result;
+}
+
+inline std::string& rtrim(std::string& s, const char* t = " \t\n\r\f\v") {
+  s.erase(s.find_last_not_of(t) + 1);
+  return s;
+}
+
+void activate() {
+  char* root = nullptr;
+  std::string cmd;
+  c10::optional<std::string> exec_out;
+  std::string path;
+  std::string vcruntime_plat;
+  std::string envvars;
+
+  // Checking whether the environment is already activated
+  if (getenv("VSCMD_ARG_TGT_ARCH")) {
+    return;
+  }
+
+  // Getting `ProgramFiles` through environment variable queries
+  root = getenv("ProgramFiles(x86)");
+  if (!root) {
+    root = getenv("ProgramFiles");
+  }
+  if (!root) {
+    return;
+  }
+
+  // Getting VS 2017 installation path through `vswhere`
+  cmd = "\"" + std::string(root) +
+      "\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
+      " -latest -prerelease -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath";
+  exec_out = exec(cmd);
+  if (!exec_out) {
+    return;
+  }
+  path = *exec_out;
+  rtrim(path);
+
+  // Checking whether the activation script `vcvarsall.bat` exists
+  path += "\\VC\\Auxiliary\\Build";
+  struct stat st;
+  if (stat(path.c_str(), &st) == -1 || !(st.st_mode & _S_IFDIR)) {
+    return;
+  }
+  path += "\\vcvarsall.bat";
+  if (_access(path.c_str(), 0) == -1) {
+    return;
+  }
+
+  // Determining current platform
+  if (sizeof(void*) == 8) {
+    vcruntime_plat = "x64";
+  } else {
+    vcruntime_plat = "x86";
+  }
+
+  // Getting environment variables after activating VS development shell
+  cmd = "\"" + path + "\" " + vcruntime_plat + ">NUL && set";
+  exec_out = exec(cmd);
+  if (!exec_out) {
+    return;
+  }
+  envvars = *exec_out;
+
+  // Setting environment variables to the current environment
+  std::istringstream f(envvars);
+  std::string envvar;
+  while (getline(f, envvar, '\n')) {
+    env_list.push_back(envvar);
+  }
+}
+
+intptr_t run(const std::string& cmd) {
+  // Getting the path of `cmd.exe`
+  char* comspec = getenv("COMSPEC");
+  if (!comspec) {
+    comspec = "C:\\Windows\\System32\\cmd.exe";
+  }
+  // Constructing the command line
+  const char* a[] = {"/c", cmd.c_str()};
+  // Constructing the env array
+  // If `env_list` is not empty, then add char pointers ending with nullptr.
+  // Otherwise, it will be nullptr, which implies the default env.
+  std::vector<const char*> e;
+  if (!env_list.empty()) {
+    for (auto& s : env_list) {
+      e.push_back(s.c_str());
+    }
+    e.push_back(nullptr);
+  }
+  // Running the command
+  intptr_t r = _spawnve(_P_WAIT, comspec, a, e.data());
+  return r;
+}
+#endif
+
 // A single compiler config is accessed through getConfig() (below)
 // Controls compilation options and may be updated based on the result
 // of compilation attempts.
@@ -37,6 +176,10 @@ struct CompilerConfig {
       cxx = cxx_env;
     }

+#ifdef _MSC_VER
+    activate();
+#endif
+
     if (!programExists(cxx)) {
       cxx = "";
     }
@@ -44,7 +187,13 @@ struct CompilerConfig {

   ~CompilerConfig() = default;

-  std::string cxx = "g++"; // compiler location
+#ifdef _MSC_VER
+  std::string cxx = "cl";
+  const std::string openmp_flags = "/openmp";
+#else
+  std::string cxx = "g++";
+  const std::string openmp_flags = "-fopenmp";
+#endif
   bool openmp = true;
 };
@@ -63,24 +212,46 @@ static CompilerConfig& getConfig() {
 // understand for AVX512. When we need better CPU performance this
 // optimization can be re-enabled by tracking down the platforms where
 // this error occurs and only selectively disabling it.
+#ifdef _MSC_VER
+static std::string getArchFlags() {
+  if (InstructionSet::AVX512F()) {
+    return "/arch:AVX512";
+  } else if (InstructionSet::AVX2()) {
+    return "/arch:AVX2";
+  } else if (InstructionSet::AVX()) {
+    return "/arch:AVX";
+  } else {
+    return "";
+  }
+}
+static const std::string arch_flags = getArchFlags();
+static const std::string compile_string =
+    "cd /D \"" + temp_dir + "\" && "
+    "${cxx} /nologo /MD /Ox " + arch_flags + " /LD /EHsc "
+    "${fopenmp} \"${cpp_file}\" /link /out:\"${so_file}\"";
+#else
 static const std::string compile_string =
     "\"${cxx}\" -O3 -g "
 #ifndef __PPC64__
 //  "-march=native "
 #endif
     "-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm";
+#endif
+
 static void runCompiler(
     const std::string& cpp_file,
     const std::string& so_file) {
   auto& config = getConfig();
   TemplateEnv env;
   env.s("cxx", config.cxx);
-  env.s("fopenmp", config.openmp ? "-fopenmp" : "");
+  env.s("fopenmp", config.openmp ? config.openmp_flags : "");
   env.s("cpp_file", cpp_file);
   env.s("so_file", so_file);
   std::string result = format(compile_string, env);
+#ifdef _MSC_VER
+  intptr_t r = run(result);
+#else
   int r = system(result.c_str());
+#endif
   if (config.openmp && r != 0) {
     std::cerr
         << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n";
@@ -90,7 +261,11 @@ static void runCompiler(
   TORCH_CHECK(r == 0, "Failed to compile a fused CPU kernel");
 }

+#ifdef _MSC_VER
+static const std::string disas_string = "dumpbin /DISASM:NOBYTES \"${so_file}\"";
+#else
 static const std::string disas_string = "objdump -M intel -d \"${so_file}\"";
+#endif
 static void disas(const std::string& so_file) {
   TemplateEnv env;
   env.s("so_file", so_file);
@@ -115,10 +290,14 @@ FusedKernelCPU::FusedKernelCPU(
           std::move(chunk_desc),
           std::move(concat_desc),
           has_random) {
-  TempFile so_file(so_template, 3);
-  TempFile cpp_file(cpp_template, 4);
+  TempFile so_file(so_template, so_suffix_len);
+  TempFile cpp_file(cpp_template, cpp_suffix_len);
   cpp_file.write(code_);
   cpp_file.sync();
+#ifdef _MSC_VER
+  so_file.close();
+  cpp_file.close();
+#endif
   runCompiler(cpp_file.name(), so_file.name());
   if (debugFuser() >= 2)
     disas(so_file.name());
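
For orientation: after TemplateEnv substitution, the MSVC branch of compile_string above expands to a command of roughly this shape (the temp directory and random file names here are illustrative only):

  cd /D "C:\Temp\" && cl /nologo /MD /Ox /arch:AVX2 /LD /EHsc /openmp "C:\Temp\pytorch_fuserab12cd.cpp" /link /out:"C:\Temp\pytorch_fuserab12cd.dll"

run() then executes it via cmd.exe /c under the environment block captured by activate(), so cl.exe is found even when the process was not started from a VS developer prompt.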

View File

@@ -0,0 +1,108 @@
+// Example code extracted from MSDN page of __cpuidex
+#include <intrin.h>
+
+#include <array>
+#include <bitset>
+#include <string>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
+
+class InstructionSet {
+  // forward declarations
+  class InstructionSet_Internal;
+
+ public:
+  // getters
+  static bool AVX(void) {
+    return CPU_Rep.f_1_ECX_[28];
+  }
+  static bool AVX2(void) {
+    return CPU_Rep.f_7_EBX_[5];
+  }
+  static bool AVX512F(void) {
+    return CPU_Rep.f_7_EBX_[16];
+  }
+
+ private:
+  static const InstructionSet_Internal CPU_Rep;
+
+  class InstructionSet_Internal {
+   public:
+    InstructionSet_Internal()
+        : nIds_{0},
+          nExIds_{0},
+          f_1_ECX_{0},
+          f_1_EDX_{0},
+          f_7_EBX_{0},
+          f_7_ECX_{0},
+          f_81_ECX_{0},
+          f_81_EDX_{0},
+          data_{},
+          extdata_{} {
+      // int cpuInfo[4] = {-1};
+      std::array<int, 4> cpui;
+
+      // Calling __cpuid with 0x0 as the function_id argument
+      // gets the number of the highest valid function ID.
+      __cpuid(cpui.data(), 0);
+      nIds_ = cpui[0];
+
+      for (int i = 0; i <= nIds_; ++i) {
+        __cpuidex(cpui.data(), i, 0);
+        data_.push_back(cpui);
+      }
+
+      // load bitset with flags for function 0x00000001
+      if (nIds_ >= 1) {
+        f_1_ECX_ = data_[1][2];
+        f_1_EDX_ = data_[1][3];
+      }
+
+      // load bitset with flags for function 0x00000007
+      if (nIds_ >= 7) {
+        f_7_EBX_ = data_[7][1];
+        f_7_ECX_ = data_[7][2];
+      }
+
+      // Calling __cpuid with 0x80000000 as the function_id argument
+      // gets the number of the highest valid extended ID.
+      __cpuid(cpui.data(), 0x80000000);
+      nExIds_ = cpui[0];
+
+      for (int i = 0x80000000; i <= nExIds_; ++i) {
+        __cpuidex(cpui.data(), i, 0);
+        extdata_.push_back(cpui);
+      }
+
+      // load bitset with flags for function 0x80000001
+      if (nExIds_ >= 0x80000001) {
+        f_81_ECX_ = extdata_[1][2];
+        f_81_EDX_ = extdata_[1][3];
+      }
+    };
+
+    int nIds_;
+    int nExIds_;
+    std::bitset<32> f_1_ECX_;
+    std::bitset<32> f_1_EDX_;
+    std::bitset<32> f_7_EBX_;
+    std::bitset<32> f_7_ECX_;
+    std::bitset<32> f_81_ECX_;
+    std::bitset<32> f_81_EDX_;
+    std::vector<std::array<int, 4>> data_;
+    std::vector<std::array<int, 4>> extdata_;
+  };
+};
+
+// Initialize static member data
+const InstructionSet::InstructionSet_Internal InstructionSet::CPU_Rep;
+
+} // namespace cpu
+} // namespace fuser
+} // namespace jit
+} // namespace torch
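
A minimal standalone sketch of how this header is consumed (a hypothetical test program; the real consumer in this commit is getArchFlags() in fused_kernel.cpp above):

  #include <torch/csrc/jit/fuser/cpu/msvc_arch.h>
  #include <iostream>

  int main() {
    using torch::jit::fuser::cpu::InstructionSet;
    // Each predicate reads a feature bit cached by the CPUID queries above.
    std::cout << "AVX:     " << InstructionSet::AVX() << "\n"
              << "AVX2:    " << InstructionSet::AVX2() << "\n"
              << "AVX512F: " << InstructionSet::AVX512F() << "\n";
  }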

View File

@@ -53,11 +53,34 @@ float fracf(float x) {
 ${type_declarations}

+#ifdef _MSC_VER
+template<size_t n> struct int_of_size;
+
+#define DEFINE_INT_OF_SIZE(int_t) \
+template<> struct int_of_size<sizeof(int_t)> { using type = int_t; }
+
+DEFINE_INT_OF_SIZE(int64_t);
+DEFINE_INT_OF_SIZE(int32_t);
+DEFINE_INT_OF_SIZE(int16_t);
+DEFINE_INT_OF_SIZE(int8_t);
+
+#undef DEFINE_INT_OF_SIZE
+
+template <typename T>
+using int_same_size_t = typename int_of_size<sizeof(T)>::type;
+
+#define IndexTypeLoop int_same_size_t<IndexType>
+#define ToIndexTypeLoop(x) static_cast<IndexTypeLoop>(x)
+#else
+#define IndexTypeLoop IndexType
+#define ToIndexTypeLoop(x) x
+#endif
+
 #define OMP_THRESHOLD 100000
 static void ${kernelName}_kernel(IndexType totalElements, ${formals}) {
   #pragma omp parallel for if(totalElements > OMP_THRESHOLD)
-  for (IndexType linearIndex = 0;
-       linearIndex < totalElements;
+  for (IndexTypeLoop linearIndex = 0;
+       linearIndex < ToIndexTypeLoop(totalElements);
        linearIndex += 1) {
     // Convert `linearIndex` into an offset of tensor:
     ${tensorOffsets}
@@ -66,8 +89,14 @@ static void ${kernelName}_kernel(IndexType totalElements, ${formals}) {
   }
 }

+#ifdef _WIN32
+#define JIT_API __declspec(dllexport)
+#else
+#define JIT_API
+#endif
+
 extern "C"
-void ${kernelName}(IndexType totalElements, void ** args) {
+JIT_API void ${kernelName}(IndexType totalElements, void ** args) {
   ${kernelName}_kernel(totalElements ${,argument_loads});
 }
 )");
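
To make the template concrete: on MSVC the induction variable is rewritten to a signed integer of the same width as IndexType, since MSVC's OpenMP 2.0 implementation rejects unsigned loop indices. A hypothetical expansion for a trivial one-input kernel (example_kernel and the float* formals are invented for illustration) would read:

  #define OMP_THRESHOLD 100000
  static void example_kernel_kernel(IndexType totalElements, float* out, float* in) {
    #pragma omp parallel for if(totalElements > OMP_THRESHOLD)
    for (IndexTypeLoop linearIndex = 0;                // signed counterpart of IndexType on MSVC
         linearIndex < ToIndexTypeLoop(totalElements);
         linearIndex += 1) {
      out[linearIndex] = in[linearIndex] * 2.f;
    }
  }

  extern "C"
  JIT_API void example_kernel(IndexType totalElements, void** args) {
    example_kernel_kernel(totalElements, (float*)args[0], (float*)args[1]);
  }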

View File

@@ -5,7 +5,18 @@
 #include <c10/util/Exception.h>
 #include <torch/csrc/utils/disallow_copy.h>

+#ifdef _WIN32
+#include <Windows.h>
+#include <io.h>
+#include <stdio.h>
+#include <fcntl.h>
+#include <random>
+#include <process.h>
+#include <WinError.h>
+#include <sys/stat.h>
+#else
 #include <unistd.h>
+#endif

 #include <string>
 #include <vector>
@@ -15,6 +26,39 @@ namespace jit {
 namespace fuser {
 namespace cpu {

+#ifdef _MSC_VER
+int mkstemps(char* tmpl, int suffix_len) {
+  int len;
+  char* name;
+  int fd = -1;
+  int save_errno = errno;
+
+  len = strlen(tmpl);
+  if (len < 6 + suffix_len ||
+      strncmp(&tmpl[len - 6 - suffix_len], "XXXXXX", 6)) {
+    return -1;
+  }
+
+  name = &tmpl[len - 6 - suffix_len];
+
+  std::random_device rd;
+  do {
+    for (unsigned i = 0; i < 6; ++i) {
+      name[i] = "abcdefghijklmnopqrstuvwxyz0123456789"[rd() % 36];
+    }
+    fd = _open(tmpl, _O_RDWR | _O_CREAT | _O_EXCL, _S_IWRITE | _S_IREAD);
+  } while (errno == EEXIST);
+
+  if (fd >= 0) {
+    errno = save_errno;
+    return fd;
+  } else {
+    return -1;
+  }
+}
+#endif
+
 struct TempFile {
   TH_DISALLOW_COPY_AND_ASSIGN(TempFile);
@@ -24,7 +68,11 @@ struct TempFile {
     std::vector<char> tt(t.c_str(), t.c_str() + t.size() + 1);
     int fd = mkstemps(tt.data(), suffix);
     AT_ASSERT(fd != -1);
+#ifdef _MSC_VER
+    file_ = _fdopen(fd, "r+");
+#else
     file_ = fdopen(fd, "r+");
+#endif

     // - 1 because tt.size() includes the null terminator,
     // but std::string does not expect one
@@ -44,17 +92,35 @@ struct TempFile {
     AT_ASSERT(str.size() == result);
   }

+#ifdef _MSC_VER
+  void close() {
+    if (file_ != nullptr) {
+      fclose(file_);
+    }
+    file_ = nullptr;
+  }
+#endif
+
   FILE* file() {
     return file_;
   }

   ~TempFile() {
+#ifdef _MSC_VER
+    if (file_ != nullptr) {
+      fclose(file_);
+    }
+    if (!name_.empty() && _access(name_.c_str(), 0) != -1) {
+      _unlink(name_.c_str());
+    }
+#else
     if (file_ != nullptr) {
       // unlink first to ensure another mkstemps doesn't
       // race between close and unlink
       unlink(name_.c_str());
       fclose(file_);
     }
+#endif
   }

 private: