Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Remove remaining THC code (#69039)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69039

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D32872476

Pulled By: ngimel

fbshipit-source-id: 7972aacc24aef9450fb59b707ed6396c501bcb31
This commit is contained in: parent 7407e3d6fd, commit e279963eef
BUILD.bazel: 24 changed lines
@@ -389,27 +389,6 @@ filegroup(
    ],
)

filegroup(
    name = "thc_srcs_cu",
    srcs = [
        "aten/src/THC/THCReduceApplyUtils.cu.cc",
        "aten/src/THC/THCSortUtils.cu.cc",
        "aten/src/THC/THCTensor.cu.cc",
        "aten/src/THC/THCTensorCopy.cu.cc",
        "aten/src/THC/THCTensorMathScan.cu.cc",
        "aten/src/THC/THCTensorScatterGather.cu.cc",
        "aten/src/THC/THCTensorSort.cu.cc",
        "aten/src/THC/generated/THCTensorSortByte.cu.cc",
        "aten/src/THC/generated/THCTensorSortChar.cu.cc",
        "aten/src/THC/generated/THCTensorSortDouble.cu.cc",
        "aten/src/THC/generated/THCTensorSortFloat.cu.cc",
        "aten/src/THC/generated/THCTensorSortHalf.cu.cc",
        "aten/src/THC/generated/THCTensorSortInt.cu.cc",
        "aten/src/THC/generated/THCTensorSortLong.cu.cc",
        "aten/src/THC/generated/THCTensorSortShort.cu.cc",
    ],
)

filegroup(
    name = "aten_srcs_cu",
    srcs = [

@@ -550,9 +529,7 @@ cc_library(
        "aten/src/**/*.h",
        "aten/src/**/*.hpp",
        "aten/src/TH/**/*.cpp",
        "aten/src/THC/**/*.cpp",
        "aten/src/THC/*.cuh",
        "aten/src/THC/generic/*.cu.cc",
    ],
    exclude = [
        "aten/src/ATen/Config.h",

@@ -692,7 +669,6 @@ cu_library(
    name = "aten_cuda",
    srcs = [
        ":aten_srcs_cu",
        ":thc_srcs_cu",
    ],
    copts = ATEN_COPTS + torch_cuda_half_options,
    visibility = ["//visibility:public"],

@@ -243,9 +243,6 @@ into the repo directory.
* [aten](aten) - C++ tensor library for PyTorch (no autograd support)
  * [src](aten/src) - [README](aten/src/README.md)
    * [TH](aten/src/TH)
      [THC](aten/src/THC) - Legacy library code from the original
      Torch. Try not to add things here; we're slowly porting these to
      [native](aten/src/ATen/native).
      * generic - Contains actual implementations of operators,
        parametrized over `scalar_t`. Files here get compiled N times
        per supported scalar type in PyTorch.

@@ -20,9 +20,7 @@
namespace at {

Context::Context()
    : thc_state(nullptr, [](THCState* p) { /* no-op */ }),
      thh_state(nullptr, [](THHState* p) { /* no-op */ }) {}
Context::Context() = default;

// TODO: This could be bad juju if someone calls globalContext() in the
// destructor of an object with static lifetime.

@@ -91,28 +91,19 @@ class TORCH_API Context {
  }
  // defined in header so that getNonVariableType has ability to inline
  // call_once check. getNonVariableType is called fairly frequently
  THCState* lazyInitCUDA() {
  void lazyInitCUDA() {
    std::call_once(thc_init,[&] {
      thc_state = detail::getCUDAHooks().initCUDA();
      detail::getCUDAHooks().initCUDA();
    });
    return thc_state.get();
  }
  THHState* lazyInitHIP() {
  void lazyInitHIP() {
    std::call_once(thh_init,[&] {
      thh_state = detail::getHIPHooks().initHIP();
      detail::getHIPHooks().initHIP();
    });
    return thh_state.get();
  }
  static const at::cuda::NVRTC& getNVRTC() {
    return detail::getCUDAHooks().nvrtc();
  }
  THCState* getTHCState() {
    // AT_ASSERT(thc_state);
    return thc_state.get();
  }
  THHState* getTHHState() {
    return thh_state.get();
  }

  static bool setFlushDenormal(bool on);

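With this change, lazyInitCUDA() and lazyInitHIP() return void, so callers that used to capture the returned THCState* now invoke the initializer purely for its side effect and reach per-device state through the usual ATen accessors. A minimal sketch of the new calling pattern (the helper name ensure_cuda_ready is illustrative, not part of the diff, and assumes a CUDA build):

    #include <ATen/Context.h>
    #include <ATen/cuda/CUDAContext.h>

    void ensure_cuda_ready() {
      // lazyInitCUDA() no longer returns a THCState*; call it for its side effect
      at::globalContext().lazyInitCUDA();
      // per-device state is reached through ATen instead, e.g. the current stream
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      (void)stream;
    }
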
@@ -261,8 +252,6 @@ class TORCH_API Context {
#endif
  bool display_vmap_fallback_warnings_ = false;
  c10::optional<at::QEngine> quantized_engine = c10::nullopt;
  std::unique_ptr<THCState, void(*)(THCState*)> thc_state;
  std::unique_ptr<THHState, void(*)(THHState*)> thh_state;

  Allocator* prev_allocator_ptr_{nullptr};
};

@@ -16,8 +16,6 @@
#include <c10/util/Exception.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include <THC/THCGeneral.h> // For THCState

#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/cudnn-wrapper.h>
#endif

@@ -57,16 +55,8 @@ void set_magma_init_fn(void (*fn)()) {
// NB: deleter is dynamic, because we need it to live in a separate
// compilation unit (alt is to have another method in hooks, but
// let's not if we don't need to!)
std::unique_ptr<THCState, void (*)(THCState*)> CUDAHooks::initCUDA() const {
void CUDAHooks::initCUDA() const {
  C10_LOG_API_USAGE_ONCE("aten.init.cuda");
  // NOTE: THCState is now an empty struct but this pointer is passed
  // to every THC function. So we can't remove it before the rest of THC.
  auto thc_state = std::unique_ptr<THCState, void (*)(THCState*)>(
      new THCState(),
      [](THCState* p) {
        delete p;
      });

  // Force the update to enable unit testing. This code get executed before unit tests
  // have a chance to enable vitals.
  at::vitals::VitalsAPI.setVital("CUDA", "used", "true", /* force = */ true);

@@ -79,8 +69,6 @@ std::unique_ptr<THCState, void (*)(THCState*)> CUDAHooks::initCUDA() const {
  TORCH_INTERNAL_ASSERT(magma_init_fn != nullptr, "Cannot initilaize magma, init routine not set");
  magma_init_fn();
#endif

  return thc_state;
}

const Generator& CUDAHooks::getDefaultCUDAGenerator(DeviceIndex device_index) const {

@@ -19,7 +19,7 @@ TORCH_CUDA_CPP_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
// The real implementation of CUDAHooksInterface
struct CUDAHooks : public at::CUDAHooksInterface {
  CUDAHooks(at::CUDAHooksArgs) {}
  std::unique_ptr<THCState, void(*)(THCState*)> initCUDA() const override;
  void initCUDA() const override;
  Device getDeviceFromPtr(void* data) const override;
  bool isPinnedPtr(void* data) const override;
  const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;

@@ -10,9 +10,6 @@
#include <functional>
#include <memory>

// Forward-declares THCState
struct THCState;

// Forward-declares at::cuda::NVRTC
namespace at { namespace cuda {
struct NVRTC;

@@ -73,7 +70,7 @@ struct TORCH_API CUDAHooksInterface {
  virtual ~CUDAHooksInterface() {}

  // Initialize THCState and, transitively, the CUDA state
  virtual std::unique_ptr<THCState, void (*)(THCState*)> initCUDA() const {
  virtual void initCUDA() const {
    TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
  }

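Anything implementing the CUDA hooks interface now overrides a void-returning initializer instead of handing back a THCState handle. A minimal sketch of an out-of-tree implementation under that assumption (the struct name MyCUDAHooks is hypothetical, and a real backend would also register itself through the hooks registry):

    #include <ATen/detail/CUDAHooksInterface.h>

    struct MyCUDAHooks : public at::CUDAHooksInterface {
      void initCUDA() const override {
        // set up the backend here; nothing is returned any more
      }
    };
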
@@ -10,9 +10,6 @@
#include <functional>
#include <memory>

// Forward-declares THHState
struct THHState;

namespace at {
class Context;
}

@@ -29,8 +26,8 @@ struct TORCH_API HIPHooksInterface {
  // squelch -Werror=non-virtual-dtor
  virtual ~HIPHooksInterface() {}

  // Initialize THHState and, transitively, the HIP state
  virtual std::unique_ptr<THHState, void (*)(THHState*)> initHIP() const {
  // Initialize the HIP library state
  virtual void initHIP() const {
    AT_ERROR("Cannot initialize HIP without ATen_hip library.");
  }

@@ -5,8 +5,6 @@
#include <stdexcept>
#include <sstream>

struct THCState;

namespace at { namespace native {

class miopen_exception : public std::runtime_error {

@@ -1,7 +1,6 @@
#pragma once

#include <ATen/ATen.h>
#include <aten/src/THH/THH.h>
#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/miopen/Handle.h>

@@ -114,7 +114,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_depthwise_convolution_backwa
#include <ATen/native/ConvUtils.h>
#include <c10/util/irange.h>

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/hip/HIPCachingAllocator.h>

#include <functional>
#include <iterator>

@@ -37,8 +37,6 @@ namespace at { namespace native {

#else // AT_ROCM_ENABLED()

#include <aten/src/THH/THH.h>

#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/miopen/Descriptors.h>
#include <ATen/miopen/Types.h>

@@ -2,22 +2,8 @@ set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE}
  "${CMAKE_CURRENT_SOURCE_DIR}"
PARENT_SCOPE)

set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
  ${CMAKE_CURRENT_SOURCE_DIR}/THCTensor.cpp
PARENT_SCOPE)

install(FILES
  THC.h
  THCGeneral.h
  THCGeneral.hpp
  THCTensor.h
  THCAtomics.cuh
  THCDeviceUtils.cuh
  THCGenerateByteType.h
  # See Note [TH abstraction violation]
  THCTensor.hpp
  DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/THC")

install(FILES
  generic/THCTensor.cpp
  generic/THCTensor.h
  DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/THC/generic")

@@ -1,8 +0,0 @@
#ifndef THC_INC
#define THC_INC

#include <THC/THCGeneral.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <THC/THCTensor.h>

#endif

@@ -1,18 +0,0 @@
#ifndef THC_GENERAL_INC
#define THC_GENERAL_INC

#include <TH/THGeneral.h>

#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAStream.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>

#include <cusparse.h>

/* Global state of THC. */
struct THCState {};

#endif

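By this point the deleted header's only real payload was the empty THCState struct; the streams, BLAS handles, and device information that THC callers once reached through it come from ATen's CUDA context instead. A small sketch of the replacement accessors (the function name is illustrative and a CUDA build is assumed):

    #include <ATen/cuda/CUDAContext.h>

    void query_cuda_context() {
      // what THCGeneral.h used to broker is exposed by ATen's CUDA context
      cublasHandle_t blas = at::cuda::getCurrentCUDABlasHandle();
      cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
      (void)blas;
      (void)prop;
    }
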
@@ -1,3 +0,0 @@
#pragma once

#include <THC/THCGeneral.h>

@@ -1,14 +0,0 @@
#include <THC/THCGeneral.h>
#include <THC/THCTensor.hpp>

#include <new>

#include <THC/generic/THCTensor.cpp>
#include <THC/THCGenerateByteType.h>

void THCTensor_setStorage(THCState *state, THCTensor *self, c10::StorageImpl *storage_, ptrdiff_t storageOffset_, at::IntArrayRef size_, at::IntArrayRef stride_)
{
  c10::raw::intrusive_ptr::incref(storage_);
  THTensor_wrap(self).set_(at::Storage(c10::intrusive_ptr<at::StorageImpl>::reclaim(storage_)),
                           storageOffset_, size_, stride_);
}

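As the deleted body shows, THCTensor_setStorage was already a thin wrapper over ATen: it just re-pointed the tensor at a storage via Tensor::set_. A hedged sketch of the direct equivalent (the wrapper name is made up for illustration):

    #include <ATen/ATen.h>

    // re-point `self` at `storage`, mirroring what THCTensor_setStorage did
    void set_storage_like_thc(at::Tensor& self,
                              const at::Storage& storage,
                              int64_t storage_offset,
                              at::IntArrayRef sizes,
                              at::IntArrayRef strides) {
      self.set_(storage, storage_offset, sizes, strides);
    }
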
@@ -1,18 +0,0 @@
#ifndef THC_TENSOR_INC
#define THC_TENSOR_INC

#include <TH/THTensor.h>
#include <THC/THCGeneral.h>

#define THCTensor_(NAME) TH_CONCAT_4(TH,CReal,Tensor_,NAME)

#define THC_DESC_BUFF_LEN 64

typedef struct TORCH_CUDA_CU_API THCDescBuff {
  char str[THC_DESC_BUFF_LEN];
} THCDescBuff;

#include <THC/generic/THCTensor.h>
#include <THC/THCGenerateByteType.h>

#endif

@@ -1,18 +0,0 @@
#pragma once

// STOP!!! Thinking of including this header directly? Please
// read Note [TH abstraction violation]

#include <THC/THCTensor.h>
#include <TH/THTensor.hpp>
#include <THC/THCGeneral.hpp>

#include <ATen/ATen.h>

TORCH_CUDA_CU_API void THCTensor_setStorage(
    THCState* state,
    THCTensor* self,
    c10::StorageImpl* storage_,
    ptrdiff_t storageOffset_,
    at::IntArrayRef size_,
    at::IntArrayRef stride_);

@@ -1,24 +0,0 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "THC/generic/THCTensor.cpp"
#else

#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>

/**** creation methods ****/

THCTensor *THCTensor_(newWithStorage1d)(THCState *state, c10::StorageImpl *storage, ptrdiff_t storageOffset,
                                        int64_t size0, int64_t stride0)
{
  c10::raw::intrusive_ptr::incref(storage);
  THTensor* self = c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
                       c10::intrusive_ptr<at::StorageImpl>::reclaim(storage),
                       at::DispatchKey::CUDA,
                       caffe2::TypeMeta::Make<scalar_t>())
                       .release();
  THCTensor_setStorage(state, self, storage, storageOffset, {size0}, {stride0});

  return self;
}

#endif

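Likewise, the deleted newWithStorage1d can be reproduced with plain ATen calls by creating an empty CUDA tensor and re-targeting it at the existing storage. A rough sketch under that assumption (the helper name and dtype handling are illustrative, not part of the PR):

    #include <ATen/ATen.h>

    // roughly what newWithStorage1d provided: a CUDA tensor viewing `storage`
    at::Tensor tensor_from_storage_1d(at::Storage storage,
                                      int64_t storage_offset,
                                      int64_t size0,
                                      int64_t stride0,
                                      at::ScalarType dtype) {
      auto t = at::empty({0}, at::TensorOptions().device(at::kCUDA).dtype(dtype));
      t.set_(storage, storage_offset, {size0}, {stride0});
      return t;
    }
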
@@ -1,33 +0,0 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "THC/generic/THCTensor.h"
#else

#define THCTensor THTensor

// These used to be distinct types; for some measure of backwards compatibility and documentation
// alias these to the single THCTensor type.
#define THCudaTensor THCTensor
#define THCudaDoubleTensor THCTensor
#define THCudaHalfTensor THCTensor
#define THCudaByteTensor THCTensor
#define THCudaCharTensor THCTensor
#define THCudaShortTensor THCTensor
#define THCudaIntTensor THCTensor
#define THCudaLongTensor THCTensor
#define THCudaBoolTensor THCTensor
#define THCudaBFloat16Tensor THCTensor
#define THCudaComplexFloatTensor THCTensor
#define THCudaComplexDoubleTensor THCTensor

/**** access methods ****/
TORCH_CUDA_CU_API c10::StorageImpl* THCTensor_(
    storage)(THCState* state, const THCTensor* self);
/**** creation methods ****/
TORCH_CUDA_CU_API THCTensor* THCTensor_(newWithStorage1d)(
    THCState* state,
    c10::StorageImpl* storage_,
    ptrdiff_t storageOffset_,
    int64_t size0_,
    int64_t stride0_);

#endif

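For context on the generic-file machinery being deleted here: headers guarded by THC_GENERIC_FILE were textually re-included once per scalar type, and THCTensor_(NAME) pasted the type name into every symbol. A small stand-alone illustration of roughly how that expansion worked, assuming the byte specialization that THCGenerateByteType.h produced (the TH_CONCAT_4 definitions below are a simplified reconstruction, not code from this diff):

    // simplified reconstruction of the TH concatenation macros
    #define TH_CONCAT_4_EXPAND(a, b, c, d) a##b##c##d
    #define TH_CONCAT_4(a, b, c, d) TH_CONCAT_4_EXPAND(a, b, c, d)

    #define CReal CudaByte  // set by THCGenerateByteType.h for the uint8_t pass
    #define THCTensor_(NAME) TH_CONCAT_4(TH, CReal, Tensor_, NAME)

    // THCTensor_(newWithStorage1d) therefore expands to:
    //   THCudaByteTensor_newWithStorage1d
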
@@ -63,11 +63,9 @@ Extension
CFFI Extension
^^^^^^^^^^^^^^

The support for CFFI Extension is very experimental. There're
generally two steps to enable it under Windows.

First, specify additional ``libraries`` in ``Extension``
object to make it build on Windows.
The support for CFFI Extension is very experimental. You must specify
additional ``libraries`` in ``Extension`` object to make it build on
Windows.

.. code-block:: python

@@ -82,35 +80,6 @@ object to make it build on Windows.
       libraries=['ATen', '_C'] # Append cuda libraries when necessary, like cudart
   )

Second, here is a workground for "unresolved external symbol
state caused by ``extern THCState *state;``"

Change the source code from C to C++. An example is listed below.

.. code-block:: cpp

   #include <THC/THC.h>
   #include <ATen/ATen.h>

   THCState *state = at::globalContext().thc_state;

   extern "C" int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
                                          THCudaTensor *output)
   {
     if (!THCudaTensor_isSameSizeAs(state, input1, input2))
       return 0;
     THCudaTensor_resizeAs(state, output, input1);
     THCudaTensor_cadd(state, output, input1, 1.0, input2);
     return 1;
   }

   extern "C" int my_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input)
   {
     THCudaTensor_resizeAs(state, grad_input, grad_output);
     THCudaTensor_fill(state, grad_input, 1);
     return 1;
   }

Cpp Extension
^^^^^^^^^^^^^

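The deleted workaround sample was written against THCudaTensor_* calls and the global thc_state that this PR removes. The same two functions can be expressed directly on at::Tensor; a rough modern sketch, not text from the original docs (function names are kept only to mirror the removed example):

    #include <ATen/ATen.h>

    int my_lib_add_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& output) {
      // same-size check that THCudaTensor_isSameSizeAs used to perform
      if (!input1.sizes().equals(input2.sizes()))
        return 0;
      output.resize_as_(input1);
      at::add_out(output, input1, input2);  // output = input1 + input2
      return 1;
    }

    int my_lib_add_backward_cuda(at::Tensor& grad_output, at::Tensor& grad_input) {
      grad_input.resize_as_(grad_output);
      grad_input.fill_(1);
      return 1;
    }
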
@@ -68,9 +68,9 @@ class AsyncInputIsOutputTest : public AsyncTest {
  AsyncInputIsOutputTest(const std::string& path, int numTensors)
      : AsyncTest(path),
        numTensors_(numTensors),
        numDevices_(cudaNumDevices()),
        state_(::at::globalContext().lazyInitCUDA()) {
        numDevices_(cudaNumDevices()) {
    // Allocate inputs on available devices in a round robin fashion.
    ::at::globalContext().lazyInitCUDA();
    inputs_.resize(numTensors_);
    for (const auto i : c10::irange(numTensors_)) {
      inputs_[i] = at::empty(

@@ -121,7 +121,6 @@ class AsyncInputIsOutputTest : public AsyncTest {
 protected:
  const int numTensors_;
  const int numDevices_;
  THCState* state_;
  std::vector<at::Tensor> inputs_;
  std::vector<CUDAStream> streams_;
};

@@ -57,9 +57,9 @@ class NCCLTest : public NCCLTestBase {
  NCCLTest(const std::string& path, int worldSize, std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout)
      : NCCLTestBase(path, pgTimeout),
        numDevices_(cudaNumDevices()),
        state_(::at::globalContext().lazyInitCUDA()),
        worldSize_(worldSize) {
    // Each device has a single tensor to perf the NCCL op
    ::at::globalContext().lazyInitCUDA();
    tensors_.resize(numDevices_);
    inputs_.resize(numDevices_);
    outputs_.resize(numDevices_);

@@ -163,7 +163,6 @@ class NCCLTest : public NCCLTestBase {
  }

  const int numDevices_;
  THCState* state_;
  int worldSize_;
  std::vector<at::Tensor> tensors_;
  std::vector<std::vector<at::Tensor>> inputs_;

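Both test fixtures now follow the same pattern: the constructor calls lazyInitCUDA() for its side effect inside the body, and the THCState* member disappears. A condensed sketch of that pattern (the class name CudaTestFixture is made up for illustration):

    #include <ATen/ATen.h>
    #include <ATen/Context.h>
    #include <vector>

    class CudaTestFixture {
     public:
      explicit CudaTestFixture(int numTensors) : numTensors_(numTensors) {
        // initialize CUDA for its side effect only; nothing is stored
        ::at::globalContext().lazyInitCUDA();
        inputs_.resize(numTensors_);
      }

     protected:
      const int numTensors_;
      std::vector<at::Tensor> inputs_;
    };
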
@@ -422,8 +422,6 @@ class TestCppExtensionJIT(common.TestCase):
        for the corresponding issue.
        """
        cuda_source = """
        #include <THC/THCGeneral.h>

        template<typename T, typename U>
        __global__ void half_test_kernel(const T* input, U* output) {
            if (input[0] < input[1] || input[0] >= input[1]) {

@@ -20,10 +20,6 @@
#include <unordered_map>
#include <vector>

#ifdef USE_CUDA
#include <THC/THC.h>
#endif

namespace torch {
namespace {
std::unordered_map<at::DeprecatedTypeProperties*, PyTypeObject*> attype_to_py_storage_type;

@@ -34,7 +34,6 @@

using namespace torch;

THCState *state = nullptr;
static bool in_bad_fork = false; // True for children forked after cuda init

#ifndef WIN32

@@ -42,7 +41,6 @@ static bool in_bad_fork = false; // True for children forked after cuda init
static void forked_child() {
  in_bad_fork = true;
  torch::utils::set_run_yet_variable_to_false();
  state = nullptr;
}
#endif

@@ -522,7 +520,7 @@ static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs)
  HANDLE_TH_ERRORS
  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
  poison_fork();
  state = at::globalContext().lazyInitCUDA();
  at::globalContext().lazyInitCUDA();

  auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
  if (!m) throw python_error();

@@ -542,10 +540,6 @@ static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs)
  set_module_attr("has_magma", at::hasMAGMA() ? Py_True : Py_False);
  set_module_attr("has_half", has_half ? Py_True : Py_False);

  auto _state_cdata = THPObjectPtr(PyLong_FromVoidPtr(state));
  if (!_state_cdata) throw python_error();
  set_module_attr("_state_cdata", _state_cdata.get());

  auto num_gpus = c10::cuda::device_count();
  auto default_cuda_generators = PyTuple_New(static_cast<Py_ssize_t>(num_gpus));
  for(const auto i : c10::irange(num_gpus)) {

@@ -1,10 +1,6 @@
#ifndef THCP_CUDA_MODULE_INC
#define THCP_CUDA_MODULE_INC

#include <ATen/Context.h>

extern THCState *state;

void THCPModule_setDevice(int idx);
PyObject * THCPModule_getDevice_wrap(PyObject *self);
PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg);

@@ -7,7 +7,6 @@
// See Note [TH abstraction violation]
// - Used to get at allocator from storage
#include <TH/THTensor.hpp>
#include <THC/THCTensor.hpp>
#include <torch/csrc/cuda/THCP.h>

#include <torch/csrc/cuda/override_macros.h>

@@ -2,7 +2,6 @@
#define THCP_H

#include <torch/csrc/python_headers.h>

#include <torch/csrc/THP.h>
#include <torch/csrc/cuda/serialization.h>
#include <torch/csrc/cuda/Module.h>

@@ -9,8 +9,6 @@
#include <tuple>
#include <unordered_set>

#include <THC/THC.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/core/DeviceType.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>

@@ -11,7 +11,6 @@
struct THPStorage;
struct THSPTensor;

typedef class THPPointer<THWTensor> THWTensorPtr;
typedef class THPPointer<THPStorage> THPStoragePtr;

#if (!defined(THC_GENERIC_FILE)) && \

@@ -7,7 +7,6 @@
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <THC/THC.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/jit/resource_guard.h>

@@ -179,7 +179,7 @@ bool maybeThrowBackCompatKeepdimWarn(char *func) {
template<>
void THPPointer<THTensor>::free() {
  if (ptr) {
    THTensor_free(LIBRARY_STATE ptr);
    c10::raw::intrusive_ptr::decref(ptr);
  }
}

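The new THPPointer<THTensor>::free body drops the reference it owns through c10's raw intrusive-pointer helpers instead of THTensor_free. A minimal sketch of that idiom (the helper function below is hypothetical):

    #include <ATen/ATen.h>

    void release_tensor_impl(at::TensorImpl* ptr) {
      if (ptr) {
        // drops one reference; the impl is destroyed when the count reaches zero
        c10::raw::intrusive_ptr::decref(ptr);
      }
    }
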
@@ -10,7 +10,6 @@
#include <torch/csrc/utils/python_compat.h>

#ifdef USE_CUDA
#include <THC/THC.h>
#include <c10/cuda/CUDAStream.h>
#endif