mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Enable CPU fused kernel on Windows
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25578 Differential Revision: D17397156 Pulled By: ezyang fbshipit-source-id: b243528c2bfd5a0d401897833048429e67fe40ef
This commit is contained in:
parent
bebc3d6aad
commit
2ce8c83f67
|
|
@ -479,11 +479,8 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
|||
${TORCH_SRC_DIR}/csrc/jit/export.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/import_legacy.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp
|
||||
)
|
||||
if (NOT WIN32)
|
||||
list(APPEND TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (USE_CUDA)
|
||||
|
|
|
|||
|
|
@ -386,6 +386,8 @@ def get_selected_tests(options):
|
|||
target_arch = os.environ.get('VSCMD_ARG_TGT_ARCH')
|
||||
if target_arch != 'x64':
|
||||
WINDOWS_BLACKLIST.append('cpp_extensions')
|
||||
WINDOWS_BLACKLIST.append('jit')
|
||||
WINDOWS_BLACKLIST.append('jit_fuser')
|
||||
|
||||
selected_tests = exclude_tests(WINDOWS_BLACKLIST, selected_tests, 'on Windows')
|
||||
|
||||
|
|
|
|||
|
|
@ -86,7 +86,6 @@ if torch.cuda.is_available():
|
|||
RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
|
||||
|
||||
PY35 = sys.version_info >= (3, 5)
|
||||
WINDOWS = sys.platform == 'win32'
|
||||
|
||||
|
||||
def LSTMCellF(input, hx, cx, *params):
|
||||
|
|
@ -2023,7 +2022,6 @@ graph(%Ra, %Rb):
|
|||
for node in g.nodes():
|
||||
self.assertTrue(g2.findNode(node.kind()) is not None)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "NYI: JIT tests not yet supported on windows")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle")
|
||||
@unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda")
|
||||
@skipIfRocm
|
||||
|
|
@ -2033,7 +2031,6 @@ graph(%Ra, %Rb):
|
|||
torch._C._jit_run_cpp_tests(run_cuda=False)
|
||||
tests_setup.shutdown()
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
||||
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
|
||||
@skipIfRocm
|
||||
def test_cpp_cuda(self):
|
||||
|
|
@ -2058,7 +2055,6 @@ graph(%Ra, %Rb):
|
|||
self.assertEqual(outputs, m(*inputs))
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA")
|
||||
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
||||
def test_dropout_cuda(self):
|
||||
# Dropout AD is dispatched to _fused_dropout in CUDA case,
|
||||
# which is not included in TestJitGeneratedFunctional
|
||||
|
|
@ -2226,12 +2222,11 @@ graph(%Ra, %Rb):
|
|||
def test_ge_unoptimized(self):
|
||||
self.run_ge_tests(False, False)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_ge_optimized(self):
|
||||
self.run_ge_tests(True, False)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
||||
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
||||
def test_ge_cuda(self):
|
||||
self.run_ge_tests(True, True)
|
||||
|
|
@ -2322,7 +2317,6 @@ graph(%Ra, %Rb):
|
|||
self.assertGraphContains(fn.graph, kind='aten::einsum')
|
||||
self.assertEqual(fn(x, y), outer(x, y))
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
||||
@unittest.skipIf(not RUN_CUDA, "calls .cuda()")
|
||||
def test_traced_module_cuda(self):
|
||||
class Model(nn.Module):
|
||||
|
|
@ -5062,7 +5056,7 @@ a")
|
|||
self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
|
||||
|
||||
@unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_batchnorm_fuser_cpu(self):
|
||||
code = '''
|
||||
|
|
@ -5091,7 +5085,7 @@ a")
|
|||
FileCheck().check('sqrtf').run(code)
|
||||
|
||||
@unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_fuser_double_float_codegen(self):
|
||||
fns = ['log', 'log10', 'log1p', 'log2', 'lgamma', 'exp', 'expm1', 'erf',
|
||||
|
|
@ -5145,7 +5139,7 @@ a")
|
|||
test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float, binary=True)
|
||||
|
||||
@unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_fuser_double_literal_precision(self):
|
||||
code = '''
|
||||
|
|
@ -9285,7 +9279,6 @@ a")
|
|||
m = M()
|
||||
self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
||||
def test_trace_of_script(self):
|
||||
@torch.jit.script
|
||||
def foo(a, c):
|
||||
|
|
@ -16975,7 +16968,7 @@ def add_autograd_test(
|
|||
# We enable the CPU fuser during these checks for more consistent
|
||||
# behavior. Otherwise, we are going to have to analyze the graph to
|
||||
# see if producer values are Dimension
|
||||
@enable_cpu_fuser_if(not (IS_SANDCASTLE or IS_WINDOWS))
|
||||
@enable_cpu_fuser_if(not IS_SANDCASTLE)
|
||||
def check(name):
|
||||
set_rng_seed(2)
|
||||
is_magic_method = name[:2] == '__' and name[-2:] == '__'
|
||||
|
|
@ -17010,8 +17003,7 @@ def add_autograd_test(
|
|||
check_against_reference(self, traced_fn,
|
||||
fn, (self_variable,) + args_variable, kwargs_variable,
|
||||
check_types=check_types)
|
||||
# Fuser not supported on windows
|
||||
if IS_SANDCASTLE or IS_WINDOWS:
|
||||
if IS_SANDCASTLE:
|
||||
autodiff_nodes = autodiff_nodes + fusible_nodes
|
||||
fusible_nodes = []
|
||||
self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
|
||||
|
|
@ -17022,8 +17014,7 @@ def add_autograd_test(
|
|||
fn, (self_variable,) + args_variable, kwargs_variable,
|
||||
check_types=check_types)
|
||||
|
||||
# Fuser not supported on windows
|
||||
if IS_SANDCASTLE or IS_WINDOWS:
|
||||
if IS_SANDCASTLE:
|
||||
autodiff_nodes = autodiff_nodes + fusible_nodes
|
||||
fusible_nodes = []
|
||||
self.assertAutodiffNode(script_fn.last_graph,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch.testing import FileCheck
|
||||
|
||||
from common_utils import run_tests, IS_WINDOWS, IS_SANDCASTLE
|
||||
from common_utils import run_tests, IS_SANDCASTLE
|
||||
from textwrap import dedent
|
||||
from itertools import product, permutations
|
||||
|
||||
|
|
@ -37,7 +37,7 @@ class TestFuser(JitTestCase):
|
|||
self.assertEqual(func(a), a.abs() * 2)
|
||||
self.assertAllFused(func.graph_for(a))
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_abs_cpu(self):
|
||||
self._test_fused_abs()
|
||||
|
|
@ -180,7 +180,7 @@ class TestFuser(JitTestCase):
|
|||
for fn in fns:
|
||||
self.checkScript(fn, [tensor])
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_chunk_correctness(self):
|
||||
return self._test_chunk_correctness(self, 'cpu')
|
||||
|
|
@ -502,7 +502,7 @@ class TestFuser(JitTestCase):
|
|||
self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal"))
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_fuser_deduplication(self):
|
||||
# See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation
|
||||
|
|
@ -522,7 +522,7 @@ class TestFuser(JitTestCase):
|
|||
# check that a, b share storage, i.e. were generated as a single output in the fuser
|
||||
self.assertEqual(ga.data_ptr(), gb.data_ptr())
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
@unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
|
||||
def test_fuser_iou(self):
|
||||
|
|
@ -683,7 +683,7 @@ class TestFuser(JitTestCase):
|
|||
.check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \
|
||||
.check_next("return").check_not("FusionGroup_1").run(str(graph))
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
|
||||
@enable_cpu_fuser
|
||||
def test_lstm_traced_cpu(self):
|
||||
|
|
@ -782,7 +782,7 @@ class TestFuser(JitTestCase):
|
|||
out = script_f(x, y)
|
||||
self.assertEqual(out[0], out[1])
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_scalar(self):
|
||||
def fn(x, y):
|
||||
|
|
@ -828,7 +828,7 @@ class TestFuser(JitTestCase):
|
|||
self.assertGraphContainsExactly(
|
||||
ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_where_and_typing(self):
|
||||
def f(x, y):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,15 @@
|
|||
#include <torch/csrc/jit/fuser/cpu/fused_kernel.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <torch/csrc/jit/code_template.h>
|
||||
#include <torch/csrc/jit/fuser/compiler.h>
|
||||
#include <torch/csrc/jit/fuser/cpu/temp_file.h>
|
||||
#include <torch/csrc/utils/memory.h>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <torch/csrc/jit/fuser/cpu/msvc_arch.h>
|
||||
#endif
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
|
@ -16,9 +21,32 @@ namespace jit {
|
|||
namespace fuser {
|
||||
namespace cpu {
|
||||
|
||||
#ifdef _MSC_VER
|
||||
static const std::string getTempPath() {
|
||||
char lpTempPathBuffer[MAX_PATH];
|
||||
|
||||
DWORD dwRetVal = GetTempPath(
|
||||
MAX_PATH, // length of the buffer
|
||||
lpTempPathBuffer); // buffer for path
|
||||
|
||||
TORCH_CHECK(dwRetVal < MAX_PATH && dwRetVal != 0, "GetTempPath failed.");
|
||||
|
||||
return std::string(lpTempPathBuffer);
|
||||
}
|
||||
static const std::string temp_dir = getTempPath();
|
||||
static const std::string so_template = temp_dir + "pytorch_fuserXXXXXX.dll";
|
||||
static const std::string cpp_template = temp_dir + "pytorch_fuserXXXXXX.cpp";
|
||||
static const std::string check_exists_string = "where \"${program}\" > nul 2> nul";
|
||||
static std::vector<std::string> env_list;
|
||||
constexpr int so_suffix_len = 4;
|
||||
constexpr int cpp_suffix_len = 4;
|
||||
#else
|
||||
static const std::string so_template = "/tmp/pytorch_fuserXXXXXX.so";
|
||||
static const std::string cpp_template = "/tmp/pytorch_fuserXXXXXX.cpp";
|
||||
static const std::string check_exists_string = "which '${program}' > /dev/null";
|
||||
constexpr int so_suffix_len = 3;
|
||||
constexpr int cpp_suffix_len = 4;
|
||||
#endif
|
||||
|
||||
static bool programExists(const std::string& program) {
|
||||
TemplateEnv env;
|
||||
|
|
@ -27,6 +55,117 @@ static bool programExists(const std::string& program) {
|
|||
return (system(cmd.c_str()) == 0);
|
||||
}
|
||||
|
||||
#ifdef _MSC_VER
|
||||
c10::optional<std::string> exec(const std::string& cmd) {
|
||||
std::array<char, 128> buffer;
|
||||
std::string result;
|
||||
std::unique_ptr<FILE, decltype(&_pclose)> pipe(
|
||||
_popen(cmd.c_str(), "r"), _pclose);
|
||||
if (!pipe) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
while (fgets(buffer.data(), static_cast<int>(buffer.size()), pipe.get()) != nullptr) {
|
||||
result += buffer.data();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
inline std::string& rtrim(std::string& s, const char* t = " \t\n\r\f\v") {
|
||||
s.erase(s.find_last_not_of(t) + 1);
|
||||
return s;
|
||||
}
|
||||
|
||||
void activate() {
|
||||
char* root = nullptr;
|
||||
std::string cmd;
|
||||
c10::optional<std::string> exec_out;
|
||||
std::string path;
|
||||
std::string vcruntime_plat;
|
||||
std::string envvars;
|
||||
|
||||
// Checking whether the environment is already activated
|
||||
if (getenv("VSCMD_ARG_TGT_ARCH")) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Getting `ProgramFiles` through environment variable queries
|
||||
root = getenv("ProgramFiles(x86)");
|
||||
if (!root) {
|
||||
root = getenv("ProgramFiles");
|
||||
}
|
||||
if (!root) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Getting VS 2017 installation path through `vswhere`
|
||||
cmd = "\"" + std::string(root) +
|
||||
"\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
|
||||
" -latest -prerelease -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath";
|
||||
exec_out = exec(cmd);
|
||||
if (!exec_out) {
|
||||
return;
|
||||
}
|
||||
path = *exec_out;
|
||||
rtrim(path);
|
||||
|
||||
// Checking whether the activation script `vcvarsall.bat` exists
|
||||
path += "\\VC\\Auxiliary\\Build";
|
||||
struct stat st;
|
||||
if (stat(path.c_str(), &st) == -1 || !(st.st_mode & _S_IFDIR)) {
|
||||
return;
|
||||
}
|
||||
path += "\\vcvarsall.bat";
|
||||
if (_access(path.c_str(), 0) == -1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Determining current platform
|
||||
if (sizeof(void*) == 8) {
|
||||
vcruntime_plat = "x64";
|
||||
} else {
|
||||
vcruntime_plat = "x86";
|
||||
}
|
||||
|
||||
// Getting environment variables after activating VS development shell
|
||||
cmd = "\"" + path + "\" " + vcruntime_plat + ">NUL && set";
|
||||
exec_out = exec(cmd);
|
||||
if (!exec_out) {
|
||||
return;
|
||||
}
|
||||
envvars = *exec_out;
|
||||
|
||||
// Setting environment variables to the current environment
|
||||
std::istringstream f(envvars);
|
||||
std::string envvar;
|
||||
while (getline(f, envvar, '\n')) {
|
||||
env_list.push_back(envvar);
|
||||
}
|
||||
}
|
||||
|
||||
intptr_t run(const std::string& cmd) {
|
||||
// Getting the path of `cmd.exe`
|
||||
char* comspec = getenv("COMSPEC");
|
||||
if (!comspec) {
|
||||
comspec = "C:\\Windows\\System32\\cmd.exe";
|
||||
}
|
||||
// Constructing the command line
|
||||
const char* a[] = {"/c", cmd.c_str()};
|
||||
// Constructing the env array
|
||||
// If `env_list` is not empty, then add char pointers ending with nullptr.
|
||||
// Otherwise, it will be nullptr, which implies the default env.
|
||||
std::vector<const char*> e;
|
||||
if (!env_list.empty()) {
|
||||
for (auto& s : env_list) {
|
||||
e.push_back(s.c_str());
|
||||
}
|
||||
e.push_back(nullptr);
|
||||
}
|
||||
// Running the command
|
||||
intptr_t r = _spawnve(_P_WAIT, comspec, a, e.data());
|
||||
return r;
|
||||
}
|
||||
#endif
|
||||
|
||||
// A single compiler config is accessed through getConfig() (below)
|
||||
// Controls compilation options and may be updated based on the result
|
||||
// of compilation attempts.
|
||||
|
|
@ -37,6 +176,10 @@ struct CompilerConfig {
|
|||
cxx = cxx_env;
|
||||
}
|
||||
|
||||
#ifdef _MSC_VER
|
||||
activate();
|
||||
#endif
|
||||
|
||||
if (!programExists(cxx)) {
|
||||
cxx = "";
|
||||
}
|
||||
|
|
@ -44,7 +187,13 @@ struct CompilerConfig {
|
|||
|
||||
~CompilerConfig() = default;
|
||||
|
||||
std::string cxx = "g++"; // compiler location
|
||||
#ifdef _MSC_VER
|
||||
std::string cxx = "cl";
|
||||
const std::string openmp_flags = "/openmp";
|
||||
#else
|
||||
std::string cxx = "g++";
|
||||
const std::string openmp_flags = "-fopenmp";
|
||||
#endif
|
||||
bool openmp = true;
|
||||
};
|
||||
|
||||
|
|
@ -63,24 +212,46 @@ static CompilerConfig& getConfig() {
|
|||
// understand for AVX512. When we need better CPU performance this
|
||||
// optimization can be re-enabled by tracking down the platforms where
|
||||
// this error occurs and only selectively disabling it.
|
||||
#ifdef _MSC_VER
|
||||
static std::string getArchFlags() {
|
||||
if (InstructionSet::AVX512F()) {
|
||||
return "/arch:AVX512";
|
||||
} else if (InstructionSet::AVX2()) {
|
||||
return "/arch:AVX2";
|
||||
} else if (InstructionSet::AVX()) {
|
||||
return "/arch:AVX";
|
||||
} else {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
static const std::string arch_flags = getArchFlags();
|
||||
static const std::string compile_string =
|
||||
"cd /D \"" + temp_dir + "\" && "
|
||||
"${cxx} /nologo /MD /Ox " + arch_flags + " /LD /EHsc "
|
||||
"${fopenmp} \"${cpp_file}\" /link /out:\"${so_file}\"";
|
||||
#else
|
||||
static const std::string compile_string =
|
||||
"\"${cxx}\" -O3 -g "
|
||||
#ifndef __PPC64__
|
||||
// "-march=native "
|
||||
#endif
|
||||
"-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm";
|
||||
|
||||
#endif
|
||||
static void runCompiler(
|
||||
const std::string& cpp_file,
|
||||
const std::string& so_file) {
|
||||
auto& config = getConfig();
|
||||
TemplateEnv env;
|
||||
env.s("cxx", config.cxx);
|
||||
env.s("fopenmp", config.openmp ? "-fopenmp" : "");
|
||||
env.s("fopenmp", config.openmp ? config.openmp_flags : "");
|
||||
env.s("cpp_file", cpp_file);
|
||||
env.s("so_file", so_file);
|
||||
std::string result = format(compile_string, env);
|
||||
#ifdef _MSC_VER
|
||||
intptr_t r = run(result);
|
||||
#else
|
||||
int r = system(result.c_str());
|
||||
#endif
|
||||
if (config.openmp && r != 0) {
|
||||
std::cerr
|
||||
<< "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n";
|
||||
|
|
@ -90,7 +261,11 @@ static void runCompiler(
|
|||
TORCH_CHECK(r == 0, "Failed to compile a fused CPU kernel");
|
||||
}
|
||||
|
||||
#ifdef _MSC_VER
|
||||
static const std::string disas_string = "dumpbin /DISASM:NOBYTES \"${so_file}\"";
|
||||
#else
|
||||
static const std::string disas_string = "objdump -M intel -d \"${so_file}\"";
|
||||
#endif
|
||||
static void disas(const std::string& so_file) {
|
||||
TemplateEnv env;
|
||||
env.s("so_file", so_file);
|
||||
|
|
@ -115,10 +290,14 @@ FusedKernelCPU::FusedKernelCPU(
|
|||
std::move(chunk_desc),
|
||||
std::move(concat_desc),
|
||||
has_random) {
|
||||
TempFile so_file(so_template, 3);
|
||||
TempFile cpp_file(cpp_template, 4);
|
||||
TempFile so_file(so_template, so_suffix_len);
|
||||
TempFile cpp_file(cpp_template, cpp_suffix_len);
|
||||
cpp_file.write(code_);
|
||||
cpp_file.sync();
|
||||
#ifdef _MSC_VER
|
||||
so_file.close();
|
||||
cpp_file.close();
|
||||
#endif
|
||||
runCompiler(cpp_file.name(), so_file.name());
|
||||
if (debugFuser() >= 2)
|
||||
disas(so_file.name());
|
||||
|
|
|
|||
108
torch/csrc/jit/fuser/cpu/msvc_arch.h
Normal file
108
torch/csrc/jit/fuser/cpu/msvc_arch.h
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
// Example code extracted from MSDN page of __cpuidex
|
||||
|
||||
#include <intrin.h>
|
||||
#include <array>
|
||||
#include <bitset>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace fuser {
|
||||
namespace cpu {
|
||||
|
||||
class InstructionSet {
|
||||
// forward declarations
|
||||
class InstructionSet_Internal;
|
||||
|
||||
public:
|
||||
// getters
|
||||
|
||||
static bool AVX(void) {
|
||||
return CPU_Rep.f_1_ECX_[28];
|
||||
}
|
||||
static bool AVX2(void) {
|
||||
return CPU_Rep.f_7_EBX_[5];
|
||||
}
|
||||
static bool AVX512F(void) {
|
||||
return CPU_Rep.f_7_EBX_[16];
|
||||
}
|
||||
|
||||
private:
|
||||
static const InstructionSet_Internal CPU_Rep;
|
||||
|
||||
class InstructionSet_Internal {
|
||||
public:
|
||||
InstructionSet_Internal()
|
||||
: nIds_{0},
|
||||
nExIds_{0},
|
||||
f_1_ECX_{0},
|
||||
f_1_EDX_{0},
|
||||
f_7_EBX_{0},
|
||||
f_7_ECX_{0},
|
||||
f_81_ECX_{0},
|
||||
f_81_EDX_{0},
|
||||
data_{},
|
||||
extdata_{} {
|
||||
// int cpuInfo[4] = {-1};
|
||||
std::array<int, 4> cpui;
|
||||
|
||||
// Calling __cpuid with 0x0 as the function_id argument
|
||||
// gets the number of the highest valid function ID.
|
||||
__cpuid(cpui.data(), 0);
|
||||
nIds_ = cpui[0];
|
||||
|
||||
for (int i = 0; i <= nIds_; ++i) {
|
||||
__cpuidex(cpui.data(), i, 0);
|
||||
data_.push_back(cpui);
|
||||
}
|
||||
|
||||
// load bitset with flags for function 0x00000001
|
||||
if (nIds_ >= 1) {
|
||||
f_1_ECX_ = data_[1][2];
|
||||
f_1_EDX_ = data_[1][3];
|
||||
}
|
||||
|
||||
// load bitset with flags for function 0x00000007
|
||||
if (nIds_ >= 7) {
|
||||
f_7_EBX_ = data_[7][1];
|
||||
f_7_ECX_ = data_[7][2];
|
||||
}
|
||||
|
||||
// Calling __cpuid with 0x80000000 as the function_id argument
|
||||
// gets the number of the highest valid extended ID.
|
||||
__cpuid(cpui.data(), 0x80000000);
|
||||
nExIds_ = cpui[0];
|
||||
|
||||
for (int i = 0x80000000; i <= nExIds_; ++i) {
|
||||
__cpuidex(cpui.data(), i, 0);
|
||||
extdata_.push_back(cpui);
|
||||
}
|
||||
|
||||
// load bitset with flags for function 0x80000001
|
||||
if (nExIds_ >= 0x80000001) {
|
||||
f_81_ECX_ = extdata_[1][2];
|
||||
f_81_EDX_ = extdata_[1][3];
|
||||
}
|
||||
};
|
||||
|
||||
int nIds_;
|
||||
int nExIds_;
|
||||
std::bitset<32> f_1_ECX_;
|
||||
std::bitset<32> f_1_EDX_;
|
||||
std::bitset<32> f_7_EBX_;
|
||||
std::bitset<32> f_7_ECX_;
|
||||
std::bitset<32> f_81_ECX_;
|
||||
std::bitset<32> f_81_EDX_;
|
||||
std::vector<std::array<int, 4>> data_;
|
||||
std::vector<std::array<int, 4>> extdata_;
|
||||
};
|
||||
};
|
||||
|
||||
// Initialize static member data
|
||||
const InstructionSet::InstructionSet_Internal InstructionSet::CPU_Rep;
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -53,11 +53,34 @@ float fracf(float x) {
|
|||
|
||||
${type_declarations}
|
||||
|
||||
#ifdef _MSC_VER
|
||||
template<size_t n> struct int_of_size;
|
||||
|
||||
#define DEFINE_INT_OF_SIZE(int_t) \
|
||||
template<> struct int_of_size<sizeof(int_t)> { using type = int_t; }
|
||||
|
||||
DEFINE_INT_OF_SIZE(int64_t);
|
||||
DEFINE_INT_OF_SIZE(int32_t);
|
||||
DEFINE_INT_OF_SIZE(int16_t);
|
||||
DEFINE_INT_OF_SIZE(int8_t);
|
||||
|
||||
#undef DEFINE_INT_OF_SIZE
|
||||
|
||||
template <typename T>
|
||||
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
|
||||
|
||||
#define IndexTypeLoop int_same_size_t<IndexType>
|
||||
#define ToIndexTypeLoop(x) static_cast<IndexTypeLoop>(x)
|
||||
#else
|
||||
#define IndexTypeLoop IndexType
|
||||
#define ToIndexTypeLoop(x) x
|
||||
#endif
|
||||
|
||||
#define OMP_THRESHOLD 100000
|
||||
static void ${kernelName}_kernel(IndexType totalElements, ${formals}) {
|
||||
#pragma omp parallel for if(totalElements > OMP_THRESHOLD)
|
||||
for (IndexType linearIndex = 0;
|
||||
linearIndex < totalElements;
|
||||
for (IndexTypeLoop linearIndex = 0;
|
||||
linearIndex < ToIndexTypeLoop(totalElements);
|
||||
linearIndex += 1) {
|
||||
// Convert `linearIndex` into an offset of tensor:
|
||||
${tensorOffsets}
|
||||
|
|
@ -66,8 +89,14 @@ static void ${kernelName}_kernel(IndexType totalElements, ${formals}) {
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
#define JIT_API __declspec(dllexport)
|
||||
#else
|
||||
#define JIT_API
|
||||
#endif
|
||||
|
||||
extern "C"
|
||||
void ${kernelName}(IndexType totalElements, void ** args) {
|
||||
JIT_API void ${kernelName}(IndexType totalElements, void ** args) {
|
||||
${kernelName}_kernel(totalElements ${,argument_loads});
|
||||
}
|
||||
)");
|
||||
|
|
|
|||
|
|
@ -5,7 +5,18 @@
|
|||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/utils/disallow_copy.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <Windows.h>
|
||||
#include <io.h>
|
||||
#include <stdio.h>
|
||||
#include <fcntl.h>
|
||||
#include <random>
|
||||
#include <process.h>
|
||||
#include <WinError.h>
|
||||
#include <sys/stat.h>
|
||||
#else
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
|
@ -15,6 +26,39 @@ namespace jit {
|
|||
namespace fuser {
|
||||
namespace cpu {
|
||||
|
||||
#ifdef _MSC_VER
|
||||
int mkstemps(char* tmpl, int suffix_len) {
|
||||
int len;
|
||||
char* name;
|
||||
int fd = -1;
|
||||
int save_errno = errno;
|
||||
|
||||
len = strlen(tmpl);
|
||||
if (len < 6 + suffix_len ||
|
||||
strncmp(&tmpl[len - 6 - suffix_len], "XXXXXX", 6)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
name = &tmpl[len - 6 - suffix_len];
|
||||
|
||||
std::random_device rd;
|
||||
do {
|
||||
for (unsigned i = 0; i < 6; ++i) {
|
||||
name[i] = "abcdefghijklmnopqrstuvwxyz0123456789"[rd() % 36];
|
||||
}
|
||||
|
||||
fd = _open(tmpl, _O_RDWR | _O_CREAT | _O_EXCL, _S_IWRITE | _S_IREAD);
|
||||
} while (errno == EEXIST);
|
||||
|
||||
if (fd >= 0) {
|
||||
errno = save_errno;
|
||||
return fd;
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
struct TempFile {
|
||||
TH_DISALLOW_COPY_AND_ASSIGN(TempFile);
|
||||
|
||||
|
|
@ -24,7 +68,11 @@ struct TempFile {
|
|||
std::vector<char> tt(t.c_str(), t.c_str() + t.size() + 1);
|
||||
int fd = mkstemps(tt.data(), suffix);
|
||||
AT_ASSERT(fd != -1);
|
||||
#ifdef _MSC_VER
|
||||
file_ = _fdopen(fd, "r+");
|
||||
#else
|
||||
file_ = fdopen(fd, "r+");
|
||||
#endif
|
||||
|
||||
// - 1 becuase tt.size() includes the null terminator,
|
||||
// but std::string does not expect one
|
||||
|
|
@ -44,17 +92,35 @@ struct TempFile {
|
|||
AT_ASSERT(str.size() == result);
|
||||
}
|
||||
|
||||
#ifdef _MSC_VER
|
||||
void close() {
|
||||
if (file_ != nullptr) {
|
||||
fclose(file_);
|
||||
}
|
||||
file_ = nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
FILE* file() {
|
||||
return file_;
|
||||
}
|
||||
|
||||
~TempFile() {
|
||||
#ifdef _MSC_VER
|
||||
if (file_ != nullptr) {
|
||||
fclose(file_);
|
||||
}
|
||||
if (!name_.empty() && _access(name_.c_str(), 0) != -1) {
|
||||
_unlink(name_.c_str());
|
||||
}
|
||||
#else
|
||||
if (file_ != nullptr) {
|
||||
// unlink first to ensure another mkstemps doesn't
|
||||
// race between close and unlink
|
||||
unlink(name_.c_str());
|
||||
fclose(file_);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user