mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 00:20:18 +01:00)
[PyTorch] Autoformat c10 (#56830)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56830
Opt into formatting on GitHub and format everything. This is a trial run before turning on formatting for more, and eventually all, of the codebase.
Test Plan: CI
Reviewed By: zertosh
Differential Revision: D27979080
fbshipit-source-id: a80f0c48691c08ae8ca0af06377b87e6a2351151
This commit is contained in:
parent 3c4d57c18b
commit 44cc873fba
@@ -18,7 +18,6 @@ class Foo : public intrusive_ptr_target {
   int param;
 };
 
-
 class Bar : public std::enable_shared_from_this<Bar> {
  public:
   Bar(int param_) : param(param_) {}

@@ -48,7 +47,7 @@ BENCHMARK(BM_SharedPtrCtorDtor);
 static void BM_IntrusivePtrArray(benchmark::State& state) {
   intrusive_ptr<Foo> var = make_intrusive<Foo>(0);
   const size_t kLength = state.range(0);
-  std::vector<intrusive_ptr<Foo> > vararray(kLength);
+  std::vector<intrusive_ptr<Foo>> vararray(kLength);
   while (state.KeepRunning()) {
     for (const auto i : c10::irange(kLength)) {
       vararray[i] = var;

@@ -64,7 +63,7 @@ BENCHMARK(BM_IntrusivePtrArray)->RangeMultiplier(2)->Range(16, 4096);
 static void BM_SharedPtrArray(benchmark::State& state) {
   std::shared_ptr<Bar> var = std::make_shared<Bar>(0);
   const size_t kLength = state.range(0);
-  std::vector<std::shared_ptr<Bar> > vararray(kLength);
+  std::vector<std::shared_ptr<Bar>> vararray(kLength);
   while (state.KeepRunning()) {
     for (const auto i : c10::irange(kLength)) {
       vararray[i] = var;

@@ -78,5 +77,4 @@ static void BM_SharedPtrArray(benchmark::State& state) {
 BENCHMARK(BM_SharedPtrArray)->RangeMultiplier(2)->Range(16, 4096);
 } // namespace
 
-
 BENCHMARK_MAIN();
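For context on what the benchmark above measures: c10::intrusive_ptr keeps its refcount inside the pointee, so copying a handle is a single atomic increment with no separate control block, unlike std::shared_ptr. A minimal usage sketch of that pattern (illustrative only; the Payload type is made up for this example):

#include <c10/util/intrusive_ptr.h>

// Any class deriving from intrusive_ptr_target can be managed this way.
struct Payload : c10::intrusive_ptr_target {
  int value;
  explicit Payload(int v) : value(v) {}
};

void copy_handles() {
  c10::intrusive_ptr<Payload> p = c10::make_intrusive<Payload>(42);
  c10::intrusive_ptr<Payload> q = p; // bumps the embedded refcount to 2
} // both handles drop; the object is freed when the count reaches 0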
@@ -12,10 +12,11 @@ at::DataPtr InefficientStdFunctionContext::makeDataPtr(
     void* ptr,
     const std::function<void(void*)>& deleter,
     Device device) {
-  return {ptr,
-          new InefficientStdFunctionContext({ptr, deleter}),
-          &deleteInefficientStdFunctionContext,
-          device};
+  return {
+      ptr,
+      new InefficientStdFunctionContext({ptr, deleter}),
+      &deleteInefficientStdFunctionContext,
+      device};
 }
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
@@ -94,7 +94,9 @@ class C10_API DataPtr {
    * be; be sure to read the source code of the Allocator
    * in question to confirm this.
    */
-  C10_NODISCARD bool compare_exchange_deleter(DeleterFnPtr expected_deleter, DeleterFnPtr new_deleter) {
+  C10_NODISCARD bool compare_exchange_deleter(
+      DeleterFnPtr expected_deleter,
+      DeleterFnPtr new_deleter) {
     return ptr_.compare_exchange_deleter(expected_deleter, new_deleter);
   }
   Device device() const {

@@ -215,8 +217,8 @@ struct AllocatorRegisterer {
   }
 };
 
-#define REGISTER_ALLOCATOR(t, f)                  \
-  namespace {                                     \
+#define REGISTER_ALLOCATOR(t, f)                    \
+  namespace {                                       \
   static AllocatorRegisterer<t> g_allocator_d(f); \
   }

@@ -227,12 +229,18 @@ struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase {
   virtual ~MemoryReportingInfoBase() {}
 
   // Negative alloc_size corresponds to freeing of the memory
-  virtual void reportMemoryUsage(void* ptr, int64_t alloc_size, Device device) = 0;
+  virtual void reportMemoryUsage(
+      void* ptr,
+      int64_t alloc_size,
+      Device device) = 0;
 
   virtual bool memoryProfilingEnabled() const = 0;
 };
 
 C10_API bool memoryProfilingEnabled();
-C10_API void reportMemoryUsageToProfiler(void* ptr, int64_t alloc_size, Device device);
+C10_API void reportMemoryUsageToProfiler(
+    void* ptr,
+    int64_t alloc_size,
+    Device device);
 
 } // namespace c10
@@ -253,7 +253,7 @@ static inline bool isSparse(Backend b) {
 }
 
 static inline bool isSparseCsr(Backend b) {
-  switch(b) {
+  switch (b) {
     case Backend::SparseCsrCPU:
     case Backend::SparseCsrCUDA:
       return true;
@@ -29,9 +29,11 @@ namespace c10 {
  * }
  * EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
  */
-template<class FuncType_, FuncType_* func_ptr_>
+template <class FuncType_, FuncType_* func_ptr_>
 struct CompileTimeFunctionPointer final {
-  static_assert(guts::is_function_type<FuncType_>::value, "TORCH_FN can only wrap function types.");
+  static_assert(
+      guts::is_function_type<FuncType_>::value,
+      "TORCH_FN can only wrap function types.");
   using FuncType = FuncType_;
 
   static constexpr FuncType* func_ptr() {

@@ -39,11 +41,16 @@ struct CompileTimeFunctionPointer final {
   }
 };
 
-template<class T> struct is_compile_time_function_pointer : std::false_type {};
-template<class FuncType, FuncType* func_ptr>
-struct is_compile_time_function_pointer<CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
+template <class T>
+struct is_compile_time_function_pointer : std::false_type {};
+template <class FuncType, FuncType* func_ptr>
+struct is_compile_time_function_pointer<
+    CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
 
-}
+} // namespace c10
 
-#define TORCH_FN_TYPE(func) ::c10::CompileTimeFunctionPointer<std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, func>
+#define TORCH_FN_TYPE(func)                                           \
+  ::c10::CompileTimeFunctionPointer<                                  \
+      std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
+      func>
 #define TORCH_FN(func) TORCH_FN_TYPE(func)()
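The TORCH_FN machinery above encodes a function pointer as a non-type template argument, so a callee can recover it as a compile-time constant instead of carrying a runtime pointer. A sketch modeled on the example in the header comment; 'execute' here is a hypothetical caller, not part of the diff:

int add(int a, int b) {
  return a + b;
}

template <class FuncTag, class... Args>
auto execute(FuncTag, Args... args) {
  // func_ptr() is a constexpr member of the tag type, so the compiler can
  // inline the call to 'add' directly.
  return FuncTag::func_ptr()(args...);
}

// execute(TORCH_FN(add), 1, 2) == 3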
@@ -47,4 +47,4 @@ void CopyBytes(
   ptr(nbytes, src, src_device, dst, dst_device);
 }
 
-}
+} // namespace c10
@@ -1,5 +1,5 @@
-#include <c10/util/typeid.h>
 #include <c10/core/DefaultDtype.h>
+#include <c10/util/typeid.h>
 
 namespace c10 {
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

@@ -7,7 +7,8 @@ static auto default_dtype = caffe2::TypeMeta::Make<float>();
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 static auto default_dtype_as_scalartype = default_dtype.toScalarType();
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-static auto default_complex_dtype = caffe2::TypeMeta::Make<c10::complex<float>>();
+static auto default_complex_dtype =
+    caffe2::TypeMeta::Make<c10::complex<float>>();
 
 void set_default_dtype(caffe2::TypeMeta dtype) {
   default_dtype = dtype;
@@ -1,7 +1,7 @@
 #pragma once
 
-#include <c10/macros/Macros.h>
 #include <c10/core/ScalarType.h>
+#include <c10/macros/Macros.h>
 
 namespace caffe2 {
 class TypeMeta;
@@ -13,19 +13,27 @@ struct TensorOptions;
 struct DefaultTensorOptions {
   DefaultTensorOptions() = default;
 
-  caffe2::TypeMeta dtype() const noexcept { return dtype_; }
-  Device device() const noexcept { return device_; }
-  Layout layout() const noexcept { return layout_; }
-  bool requires_grad() const noexcept { return requires_grad_; }
+  caffe2::TypeMeta dtype() const noexcept {
+    return dtype_;
+  }
+  Device device() const noexcept {
+    return device_;
+  }
+  Layout layout() const noexcept {
+    return layout_;
+  }
+  bool requires_grad() const noexcept {
+    return requires_grad_;
+  }
 
   // Defined in TensorOptions.h
   inline DefaultTensorOptions& merge(const TensorOptions& options);
 
  private:
   caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 64-bit
-  Device device_ = at::kCPU;                                 // 32-bit
-  Layout layout_ = at::kStrided;                             // 8-bit
-  bool requires_grad_ = false;                               // 8-bit
+  Device device_ = at::kCPU; // 32-bit
+  Layout layout_ = at::kStrided; // 8-bit
+  bool requires_grad_ = false; // 8-bit
 };
 
 inline const DefaultTensorOptions& getDefaultTensorOptions() {
@@ -6,25 +6,24 @@
 #include <array>
 #include <exception>
 #include <ostream>
+#include <regex>
 #include <string>
 #include <tuple>
 #include <vector>
-#include <regex>
 
 // Check if compiler has working std::regex implementation
 //
 // Test below is adapted from https://stackoverflow.com/a/41186162
 #if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L
-  // Compiler has working regex. MSVC has erroneous __cplusplus.
+// Compiler has working regex. MSVC has erroneous __cplusplus.
 #elif __cplusplus >= 201103L && \
     (!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \
-      (defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \
-       defined(_GLIBCXX_REGEX_STATE_LIMIT) || \
-       (defined(_GLIBCXX_RELEASE) && \
-        _GLIBCXX_RELEASE > 4)))
-  // Compiler has working regex.
+     (defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \
+      defined(_GLIBCXX_REGEX_STATE_LIMIT) || \
+      (defined(_GLIBCXX_RELEASE) && _GLIBCXX_RELEASE > 4)))
+// Compiler has working regex.
 #else
-  static_assert(false, "Compiler does not have proper regex support.");
+static_assert(false, "Compiler does not have proper regex support.");
 #endif
 
 namespace c10 {

@@ -58,7 +57,8 @@ DeviceType parse_type(const std::string& device_string) {
   if (device != types.end()) {
     return device->second;
   }
-  TORCH_CHECK(false,
+  TORCH_CHECK(
+      false,
       "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan, meta device type at start of device string: ",
       device_string);
 }

@@ -71,16 +71,22 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) {
   static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
   std::smatch match;
   TORCH_CHECK(
-    std::regex_match(device_string, match, regex),
-    "Invalid device string: '", device_string, "'");
+      std::regex_match(device_string, match, regex),
+      "Invalid device string: '",
+      device_string,
+      "'");
   type_ = parse_type(match[1].str());
   if (match[2].matched) {
     try {
       index_ = c10::stoi(match[2].str());
-    } catch (const std::exception &) {
-      TORCH_CHECK(false,
-        "Could not parse device index '", match[2].str(),
-        "' in device string '", device_string, "'");
+    } catch (const std::exception&) {
+      TORCH_CHECK(
+          false,
+          "Could not parse device index '",
+          match[2].str(),
+          "' in device string '",
+          device_string,
+          "'");
     }
   }
   validate();
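For reference, the device-string grammar accepted by the constructor above is "<type>" or "<type>:<index>", where the index is "0" or a digit string that does not start with zero. A short usage sketch:

#include <c10/core/Device.h>

void device_strings() {
  c10::Device cpu("cpu");    // type CPU, index -1 (unset)
  c10::Device gpu("cuda:1"); // type CUDA, index 1
  // "cuda:01" fails the regex match; "foo:1" fails parse_type with the
  // TORCH_CHECK message shown in the diff above.
}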
@@ -107,16 +107,18 @@ struct C10_API Device final {
     // performance in micro-benchmarks.
     // This is safe to do, because backends that use the DeviceIndex
     // have a later check when we actually try to switch to that device.
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(index_ == -1 || index_ >= 0,
-      "Device index must be -1 or non-negative, got ", (int)index_);
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_cpu() || index_ <= 0,
-      "CPU device index must be -1 or zero, got ", (int)index_);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        index_ == -1 || index_ >= 0,
+        "Device index must be -1 or non-negative, got ",
+        (int)index_);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        !is_cpu() || index_ <= 0,
+        "CPU device index must be -1 or zero, got ",
+        (int)index_);
   }
 };
 
-C10_API std::ostream& operator<<(
-    std::ostream& stream,
-    const Device& device);
+C10_API std::ostream& operator<<(std::ostream& stream, const Device& device);
 
 } // namespace c10
 

@@ -136,10 +138,11 @@ struct hash<c10::Device> {
     // half of the resulting integer.
     //
     // Technically, by C/C++ integer promotion rules, we only need one of the
-    // uint32_t casts to the result type, but we put in both for explicitness's sake.
-    uint32_t bits =
-        static_cast<uint32_t>(static_cast<uint8_t>(d.type())) << 16
-      | static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
+    // uint32_t casts to the result type, but we put in both for explicitness's
+    // sake.
+    uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
+            << 16 |
+        static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
     return std::hash<uint32_t>{}(bits);
   }
 };
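A worked example of the packing done by std::hash<c10::Device> above, under the assumption that DeviceType::CUDA has enum value 1:

// d.type() = CUDA (1), d.index() = 2:
//   bits = (uint32_t)1 << 16 | (uint32_t)2 == 0x00010002
// The type lands in bits 16..23 and the index in bits 0..7, so distinct
// (type, index) pairs map to distinct 32-bit values before hashing.
uint32_t bits = (uint32_t)1 << 16 | (uint32_t)2; // == 0x00010002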
@@ -17,7 +17,7 @@ namespace c10 {
 /// want to setup a guard (i.e., are looking for the moral equivalent
 /// of optional<DeviceGuard>), see OptionalDeviceGuard.
 class DeviceGuard {
-public:
+ public:
   /// No default constructor; see Note [Omitted default constructor from RAII]
   explicit DeviceGuard() = delete;
 

@@ -25,7 +25,10 @@ public:
   explicit DeviceGuard(Device device) : guard_(device) {}
 
   /// This constructor is for testing only.
-  explicit DeviceGuard(Device device, const impl::DeviceGuardImplInterface* impl) : guard_(device, impl) {}
+  explicit DeviceGuard(
+      Device device,
+      const impl::DeviceGuardImplInterface* impl)
+      : guard_(device, impl) {}
 
   /// Copy is disallowed
   DeviceGuard(const DeviceGuard&) = delete;

@@ -48,7 +51,9 @@ public:
   }
 
   /// This method is for testing only.
-  void reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl) {
+  void reset_device(
+      at::Device device,
+      const impl::DeviceGuardImplInterface* impl) {
     guard_.reset_device(device, impl);
   }
 

@@ -69,7 +74,7 @@ public:
     return guard_.current_device();
   }
 
-private:
+ private:
   impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_;
 };
 

@@ -79,8 +84,8 @@ private:
 * Morally, a OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but
 * with extra constructors and methods as appropriate.
 *
- * Besides its obvious use (optionally applying a DeviceGuard), OptionalDeviceGuard
- * is often also used for the following idiom:
+ * Besides its obvious use (optionally applying a DeviceGuard),
+ * OptionalDeviceGuard is often also used for the following idiom:
 *
 *   OptionalDeviceGuard g;
 *   for (const auto& t : tensors) {

@@ -117,7 +122,7 @@ private:
 * DeviceGuard will still reset the device to original_device_.
 */
 class OptionalDeviceGuard {
-public:
+ public:
   /// Create an uninitialized guard. Set the guard later using reset_device.
   explicit OptionalDeviceGuard() : guard_() {}
 

@@ -129,7 +134,10 @@ public:
   explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {}
 
   /// Constructor for testing only.
-  explicit OptionalDeviceGuard(Device device, const impl::DeviceGuardImplInterface* impl) : guard_(device, impl) {}
+  explicit OptionalDeviceGuard(
+      Device device,
+      const impl::DeviceGuardImplInterface* impl)
+      : guard_(device, impl) {}
 
   /// Copy is disallowed
   OptionalDeviceGuard(const OptionalDeviceGuard&) = delete;

@@ -149,7 +157,9 @@ public:
   }
 
   /// For testing only
-  void reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl) {
+  void reset_device(
+      at::Device device,
+      const impl::DeviceGuardImplInterface* impl) {
     guard_.reset_device(device, impl);
   }
 

@@ -164,7 +174,7 @@ public:
     return guard_.current_device();
   }
 
-private:
+ private:
   impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_;
 };
 

@@ -173,7 +183,8 @@ private:
 // Design note: in principle, we could avoid these wrappers using:
 //
 //    using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
-//    using OptionalDeviceGuard = impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
+//    using OptionalDeviceGuard =
+//        impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
 //
 // But the error messages are worse, and our users can't just look at the
 // header file to find out what's going on. Furthermore, for specializations
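Usage sketch for the RAII guards above, assuming <c10/core/DeviceGuard.h>; the destructor restores the original device even if the body throws:

#include <c10/core/DeviceGuard.h>

void run_on(c10::Device device) {
  c10::DeviceGuard guard(device); // switch the current device to 'device'
  // ... enqueue work on 'device' ...
} // original device restored here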
@@ -38,7 +38,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
     case DeviceType::Meta:
       return lower_case ? "meta" : "META";
     default:
-      TORCH_CHECK(false,
+      TORCH_CHECK(
+          false,
           "Unknown device: ",
           static_cast<int16_t>(d),
           ". If you have recently updated the caffe2.proto file to add a new "
@@ -7,8 +7,8 @@
 
 #include <c10/macros/Macros.h>
 
-#include <ostream>
 #include <functional>
+#include <ostream>
 
 namespace c10 {
 

@@ -51,7 +51,8 @@ constexpr DeviceType kXPU = DeviceType::XPU;
 constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
     static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
 
-static_assert(COMPILE_TIME_MAX_DEVICE_TYPES <= 16,
+static_assert(
+    COMPILE_TIME_MAX_DEVICE_TYPES <= 16,
     "Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
     "for this constant to reflect the actual number of DeviceTypes we support "
     "in PyTorch; it's important that this number is not too large as we "

@@ -61,9 +62,7 @@ static_assert(COMPILE_TIME_MAX_DEVICE_TYPES <= 16,
     "types registration, please be aware that you are affecting code that "
     "this number is small. Try auditing uses of this constant.");
 
-C10_API std::string DeviceTypeName(
-    DeviceType d,
-    bool lower_case = false);
+C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);
 
 C10_API bool isValidDeviceType(DeviceType d);
 

@@ -72,7 +71,8 @@ C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);
 } // namespace c10
 
 namespace std {
-template <> struct hash<c10::DeviceType> {
+template <>
+struct hash<c10::DeviceType> {
   std::size_t operator()(c10::DeviceType k) const {
     return std::hash<int>()(static_cast<int>(k));
   }

@@ -80,5 +80,5 @@ template <> struct hash<c10::DeviceType> {
 } // namespace std
 
 namespace torch {
-  using c10::DeviceType;
+using c10::DeviceType;
 }
@@ -1,11 +1,11 @@
 #pragma once
 
-#include <vector>
-#include <iostream>
-#include <string>
 #include <c10/macros/Macros.h>
 #include <c10/util/ArrayRef.h>
 #include <c10/util/Exception.h>
+#include <iostream>
+#include <string>
+#include <vector>
 
 namespace c10 {
 

@@ -58,7 +58,7 @@ enum class DispatchKey : uint8_t {
   HIP, // NB: I think this is not actually used, due to Note [Masquerading as
        // CUDA]
   FPGA, // Xilinx support lives out of tree at
-  //   https://gitlab.com/pytorch-complex/vitis_kernels
+        // https://gitlab.com/pytorch-complex/vitis_kernels
   MSNPU, // unused externally, but tested at
          // test/cpp_extensions/msnpu_extension.cpp
   XLA, // lives out of tree at https://github.com/pytorch/xla

@@ -177,7 +177,8 @@ enum class DispatchKey : uint8_t {
   // But this work is currently blocked since it adds an extra dispatch
   // for all ops and it's non-trivial overhead at model level(a few percents).
   // Thus our current approach takes advantage of the fact every kernel go
-  // through VariableType kernel first and pulls the `at::AutoDispatchBelowInplaceOrView` guard of functional ops
+  // through VariableType kernel first and pulls the
+  // `at::AutoDispatchBelowInplaceOrView` guard of functional ops
   // up to the `VariableType` kernel. Thus we only add the extra dispatch
   // to view/inplace ops to minimize its perf impact to real models.
   InplaceOrView,

@@ -213,7 +214,8 @@ enum class DispatchKey : uint8_t {
   AutogradXLA,
   AutogradXPU,
   AutogradMLC,
-  AutogradNestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
+  AutogradNestedTensor, // lives out of tree at
+                        // https://github.com/pytorch/nestedtensor
   // Here are some reserved pre-autograd keys for user-defined backends, see
   // Note [Private use DispatchKey]
   AutogradPrivateUse1,

@@ -224,7 +226,7 @@ enum class DispatchKey : uint8_t {
 
   // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
   // and inputs are saved for backward in the post-autocast type.
-  //AutocastCPU,
+  // AutocastCPU,
   AutocastCUDA,
 
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //

@@ -278,9 +280,10 @@ enum class DispatchKey : uint8_t {
 
   // See Note [Alias Dispatch Key : Autograd]
   Autograd,
-  CompositeImplicitAutograd, // registered at build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp
+  CompositeImplicitAutograd, // registered at
+                             // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp
   CompositeExplicitAutograd, // registered at
-  //   build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
+                             // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
 
   // Define an alias key to represent end of alias dispatch keys.
   // If you add new alias keys after Autograd, please also update it here.

@@ -316,15 +319,15 @@ enum class DispatchKey : uint8_t {
 // We provide two classes of private user tensor id: regular DispatchKeys
 // and Autograd DispatchKeys. DispatchKeys serve the role of ordinary "backend"
 // DispatchKeys; if you were adding support for a new type of accelerator, you
-// would use a backend DispatchKey, and ideally automatically reuse AutogradOther
-// definitions already defined in PyTorch. AutogradPrivateUse DispatchKeys serve
-// as "wrapper" DispatchKeys: they are only necessary for tensors that compose
-// multiple internal tensors, and for cases when the built-in autograd formulas
-// for operators are not appropriate.
+// would use a backend DispatchKey, and ideally automatically reuse
+// AutogradOther definitions already defined in PyTorch. AutogradPrivateUse
+// DispatchKeys serve as "wrapper" DispatchKeys: they are only necessary for
+// tensors that compose multiple internal tensors, and for cases when the
+// built-in autograd formulas for operators are not appropriate.
 
 static_assert(
-  static_cast<uint8_t>(DispatchKey::NumDispatchKeys) < 64,
-  "DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries");
+    static_cast<uint8_t>(DispatchKey::NumDispatchKeys) < 64,
+    "DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries");
 
 C10_API const char* toString(DispatchKey);
 C10_API std::ostream& operator<<(std::ostream&, DispatchKey);

@@ -345,10 +348,10 @@ inline bool isAliasDispatchKey(DispatchKey k) {
 } // namespace c10
 
 namespace torch {
-  // Expose the constant, but not the TYPE (DispatchKey is an implementation
-  // detail!)
-  using c10::kAutograd;
-}
+// Expose the constant, but not the TYPE (DispatchKey is an implementation
+// detail!)
+using c10::kAutograd;
+} // namespace torch
 
 // NB: You really shouldn't use this instance; this enum is guaranteed
 // to be pretty small so a regular array should be acceptable.

@@ -362,4 +365,4 @@ struct hash<c10::DispatchKey> {
     return static_cast<size_t>(x);
   }
 };
-}
+} // namespace std
@@ -3,9 +3,10 @@
 namespace c10 {
 
 // backend_dispatch_keyset should include all runtime backend keys.
-// Alias key DispatchKey::CompositeExplicitAutograd maps to backend_dispatch_keyset
-// NestedTensor has been explicitly removed due to incompatibility with some
-// kernels, such as structured kernels, that use the DefaultBackend key.
+// Alias key DispatchKey::CompositeExplicitAutograd maps to
+// backend_dispatch_keyset NestedTensor has been explicitly removed due to
+// incompatibility with some kernels, such as structured kernels, that use the
+// DefaultBackend key.
 constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
     DispatchKeySet({
         DispatchKey::CPU,

@@ -23,9 +24,11 @@ bool isBackendDispatchKey(DispatchKey t) {
   return t != DispatchKey::Undefined && backend_dispatch_keyset.has(t);
 }
 
-// math_dispatch_keyset contains all keys in backend_dispatch_keyset and autograd_dispatch_keyset
-// Alias key DispatchKey::CompositeImplicitAutograd maps to math_dispatch_keyset.
-constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset;
+// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
+// autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd
+// maps to math_dispatch_keyset.
+constexpr DispatchKeySet math_dispatch_keyset =
+    backend_dispatch_keyset | autograd_dispatch_keyset;
 
 DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
   TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);

@@ -41,8 +44,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
   }
 }
 
-// for a given autograd key, return the (guaranteed nonempty) set of associated backend keys.
-// for a non-autograd key, return the empty keyset.
+// for a given autograd key, return the (guaranteed nonempty) set of associated
+// backend keys. for a non-autograd key, return the empty keyset.
 DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
   switch (t) {
     case DispatchKey::AutogradCPU:

@@ -72,7 +75,7 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
 
 DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
   switch (t) {
-    //case DispatchKey::CPU:
+    // case DispatchKey::CPU:
     //  return DispatchKeySet(DispatchKey::AutocastCPU);
     case DispatchKey::CUDA:
       return DispatchKeySet(DispatchKey::AutocastCUDA);

@@ -82,8 +85,8 @@ DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
 }
 
 DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
-  return DispatchKeySet({
-      DispatchKey::InplaceOrView, getAutogradKeyFromBackend(t)});
+  return DispatchKeySet(
+      {DispatchKey::InplaceOrView, getAutogradKeyFromBackend(t)});
 }
 
 bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {

@@ -116,4 +119,4 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
   return os;
 }
 
-}
+} // namespace c10
@@ -1,9 +1,9 @@
 #pragma once
 
 #include <c10/core/DispatchKey.h>
-#include <c10/util/llvmMathExtras.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Metaprogramming.h>
+#include <c10/util/llvmMathExtras.h>
 #include <ostream>
 
 namespace c10 {

@@ -33,30 +33,29 @@ namespace c10 {
 //
 // An undefined tensor is one with an empty tensor type set.
 class DispatchKeySet final {
-public:
+ public:
   enum Full { FULL };
   enum FullAfter { FULL_AFTER };
   enum Raw { RAW };
 
   // NB: default constructor representation as zero is MANDATORY as
   // use of DispatchKeySet in TLS requires this.
-  constexpr DispatchKeySet()
-    : repr_(0) {}
+  constexpr DispatchKeySet() : repr_(0) {}
   constexpr DispatchKeySet(Full)
-    : repr_(std::numeric_limits<decltype(repr_)>::max()) {}
+      : repr_(std::numeric_limits<decltype(repr_)>::max()) {}
   constexpr DispatchKeySet(FullAfter, DispatchKey t)
-    // LSB after t are OK, but not t itself.
-    : repr_((1ULL << (static_cast<uint8_t>(t) - 1)) - 1) {}
+      // LSB after t are OK, but not t itself.
+      : repr_((1ULL << (static_cast<uint8_t>(t) - 1)) - 1) {}
   // Public version of DispatchKeySet(uint64_t) API; external users
   // must be explicit when they do this!
-  constexpr DispatchKeySet(Raw, uint64_t x)
-    : repr_(x) {}
+  constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}
   explicit constexpr DispatchKeySet(DispatchKey t)
-    : repr_(t == DispatchKey::Undefined
-              ? 0
-              : 1ULL << (static_cast<uint8_t>(t) - 1)) {}
+      : repr_(
+            t == DispatchKey::Undefined
+                ? 0
+                : 1ULL << (static_cast<uint8_t>(t) - 1)) {}
   explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
-    : repr_(0) {
+      : repr_(0) {
     for (auto k : ks) {
       repr_ |= DispatchKeySet(k).repr_;
     }

@@ -105,7 +104,9 @@ public:
   bool empty() const {
     return repr_ == 0;
   }
-  uint64_t raw_repr() { return repr_; }
+  uint64_t raw_repr() {
+    return repr_;
+  }
   // Return the type id in this set with the highest priority (i.e.,
   // is the largest in the DispatchKey enum). Intuitively, this
   // type id is the one that should handle dispatch (assuming there

@@ -119,18 +120,20 @@ public:
   }
 
   DispatchKey highestPriorityBackendTypeId() const {
-    return (*this & ((1ULL << static_cast<uint8_t>(DispatchKey::EndOfBackendKeys)) - 1))
-           .highestPriorityTypeId();
+    return (*this &
+            ((1ULL << static_cast<uint8_t>(DispatchKey::EndOfBackendKeys)) - 1))
+        .highestPriorityTypeId();
   }
-private:
+
+ private:
   constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
   uint64_t repr_ = 0;
 
-public:
+ public:
   // STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the
   // set. The iterator is only invalidated by the destruction of the underlying
-  // DispatchKeySet as the iterator stores a pointer to the raw representation of
-  // the DispatchKeySet.
+  // DispatchKeySet as the iterator stores a pointer to the raw representation
+  // of the DispatchKeySet.
   class iterator {
    public:
    using self_type = iterator;

@@ -138,13 +141,15 @@ public:
    using value_type = DispatchKey;
    using difference_type = ptrdiff_t;
 
-    explicit iterator(const uint64_t *data_ptr, uint8_t i=0) : data_ptr_(data_ptr), i_(i) {
+    explicit iterator(const uint64_t* data_ptr, uint8_t i = 0)
+        : data_ptr_(data_ptr), i_(i) {
      // Go to the first key in the set
      ++(*this);
    }
 
    self_type& operator++() {
-      TORCH_INTERNAL_ASSERT(i_ <= static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
+      TORCH_INTERNAL_ASSERT(
+          i_ <= static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
 
      // Create a masked version of the set representation to ignore previous
      // keys that we've iterated through.

@@ -153,7 +158,7 @@ public:
 
      // If there are no keys, set to end iterator value
      if (firstKeyIndex == std::numeric_limits<uint64_t>::max() ||
-           i_ == static_cast<uint8_t>(DispatchKey::NumDispatchKeys)) {
+          i_ == static_cast<uint8_t>(DispatchKey::NumDispatchKeys)) {
        i_ = static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
        return *this;
      }

@@ -163,29 +168,38 @@ public:
    }
 
    self_type operator++(int) {
-      self_type previous_iterator = *this;
+     self_type previous_iterator = *this;
      ++(*this);
      return previous_iterator;
    }
 
-    bool operator==(const self_type& rhs) const { return i_ == rhs.i_; }
-    bool operator!=(const self_type& rhs) const { return i_ != rhs.i_; }
-    DispatchKey operator*() const { return static_cast<DispatchKey> (i_); }
+    bool operator==(const self_type& rhs) const {
+      return i_ == rhs.i_;
+    }
+    bool operator!=(const self_type& rhs) const {
+      return i_ != rhs.i_;
+    }
+    DispatchKey operator*() const {
+      return static_cast<DispatchKey>(i_);
+    }
 
    private:
-    const uint64_t *data_ptr_;
+    const uint64_t* data_ptr_;
     uint8_t i_;
   };
 
  public:
   // Returns iterator to the first key in the set. If no keys are in the
   // set, then will return the end iterator.
-  iterator begin() const { return iterator(&repr_); }
+  iterator begin() const {
+    return iterator(&repr_);
+  }
 
   // We do not need to iterate beyond NumDispatchKeys so we will treat this as
   // the end iterator. NumDispatchKeys will always be strictly less than 64.
-  iterator end() const { return iterator(&repr_, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)); }
-
+  iterator end() const {
+    return iterator(&repr_, static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
+  }
 };
 
 C10_API std::string toString(DispatchKeySet);

@@ -208,7 +222,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
 });
 
 constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
-    //DispatchKey::AutocastCPU,
+    // DispatchKey::AutocastCPU,
     DispatchKey::AutocastCUDA,
 });
 

@@ -224,40 +238,36 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
 });
 
 constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView =
-  autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView);
+    autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView);
 
 // backend dispatch keys that map to DispatchKey::AutogradOther
 // NB: keys in this set also get associated with CompositeImplicitAutograd
-constexpr DispatchKeySet autogradother_backends = DispatchKeySet({
-  DispatchKey::HIP,
-  DispatchKey::FPGA,
-  DispatchKey::MSNPU,
-  DispatchKey::Vulkan,
-  DispatchKey::Metal,
-  DispatchKey::QuantizedCPU,
-  DispatchKey::QuantizedCUDA,
-  DispatchKey::CustomRNGKeyId,
-  DispatchKey::MkldnnCPU,
-  DispatchKey::SparseCPU,
-  DispatchKey::SparseCUDA,
-  DispatchKey::SparseHIP,
-  DispatchKey::SparseCsrCPU,
-  DispatchKey::SparseCsrCUDA,
-  DispatchKey::Meta
-});
+constexpr DispatchKeySet autogradother_backends = DispatchKeySet(
+    {DispatchKey::HIP,
+     DispatchKey::FPGA,
+     DispatchKey::MSNPU,
+     DispatchKey::Vulkan,
+     DispatchKey::Metal,
+     DispatchKey::QuantizedCPU,
+     DispatchKey::QuantizedCUDA,
+     DispatchKey::CustomRNGKeyId,
+     DispatchKey::MkldnnCPU,
+     DispatchKey::SparseCPU,
+     DispatchKey::SparseCUDA,
+     DispatchKey::SparseHIP,
+     DispatchKey::SparseCsrCPU,
+     DispatchKey::SparseCsrCUDA,
+     DispatchKey::Meta});
 
 // The set of dispatch keys that come after autograd
-// n.b. this relies on the fact that AutogradOther is currently the lowest Autograd key
-constexpr DispatchKeySet after_autograd_keyset = DispatchKeySet(
-  DispatchKeySet::FULL_AFTER,
-  c10::DispatchKey::AutogradOther
-);
+// n.b. this relies on the fact that AutogradOther is currently the lowest
+// Autograd key
+constexpr DispatchKeySet after_autograd_keyset =
+    DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);
 
 // The set of dispatch keys that come after InplaceOrView
-constexpr DispatchKeySet after_InplaceOrView_keyset = DispatchKeySet(
-  DispatchKeySet::FULL_AFTER,
-  c10::DispatchKey::InplaceOrView
-);
+constexpr DispatchKeySet after_InplaceOrView_keyset =
+    DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::InplaceOrView);
 
 // true if t is a backend dispatch key
 C10_API bool isBackendDispatchKey(DispatchKey t);

@@ -265,8 +275,8 @@ C10_API bool isBackendDispatchKey(DispatchKey t);
 // Resolve alias dispatch key to DispatchKeySet if applicable
 C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);
 
-// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key t,
-// DispatchKeySet is empty if t is not alias of DispatchKey::Autograd.
+// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key
+// t, DispatchKeySet is empty if t is not alias of DispatchKey::Autograd.
 C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
 
 // Returns a DispatchKeySet of autograd related keys mapped to backend.

@@ -291,27 +301,32 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
   // top of existing "backend" keys like CPU/CUDA, you need to add it
   // here. At the moment, autograd keys and InplaceOrView key need this
   // treatment;
-  return (s - autograd_dispatch_keyset_with_InplaceOrView - autocast_dispatch_keyset).highestPriorityTypeId();
+  return (s - autograd_dispatch_keyset_with_InplaceOrView -
+          autocast_dispatch_keyset)
+      .highestPriorityTypeId();
 }
 
-template<class T>
+template <class T>
 using is_not_DispatchKeySet = guts::negation<std::is_same<DispatchKeySet, T>>;
 
-// Given a function type, constructs a function_traits type that drops the first parameter
-// type if the first parameter is of type DispatchKeySet.
-// NB: DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid pushing unnecessary
-// arguments on the stack - see Note [ Plumbing Keys Through the Dispatcher] for details).
-// If at any point in the future we need to expose this type to JIT, revisit the usage of this type alias.
+// Given a function type, constructs a function_traits type that drops the first
+// parameter type if the first parameter is of type DispatchKeySet. NB:
+// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid
+// pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through
+// the Dispatcher] for details). If at any point in the future we need to expose
+// this type to JIT, revisit the usage of this type alias.
 template <class FuncType>
 using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t<
-  typename guts::infer_function_traits_t<FuncType>::return_type,
-  typename std::conditional_t<
-    std::is_same<
-      DispatchKeySet,
-      typename guts::typelist::head_with_default_t<void, typename guts::infer_function_traits_t<FuncType>::parameter_types>
-    >::value,
-    guts::typelist::drop_if_nonempty_t<typename guts::infer_function_traits_t<FuncType>::parameter_types, 1>,
-    typename guts::infer_function_traits_t<FuncType>::parameter_types
-  >
->;
-}
+    typename guts::infer_function_traits_t<FuncType>::return_type,
+    typename std::conditional_t<
+        std::is_same<
+            DispatchKeySet,
+            typename guts::typelist::head_with_default_t<
+                void,
+                typename guts::infer_function_traits_t<
+                    FuncType>::parameter_types>>::value,
+        guts::typelist::drop_if_nonempty_t<
+            typename guts::infer_function_traits_t<FuncType>::parameter_types,
+            1>,
+        typename guts::infer_function_traits_t<FuncType>::parameter_types>>;
+} // namespace c10
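A short sketch of the 64-bit bitset semantics defined above: each DispatchKey owns one bit of repr_, and iteration walks the set bits from lowest to highest enum value:

#include <c10/core/DispatchKeySet.h>

void keyset_demo() {
  c10::DispatchKeySet ks(
      {c10::DispatchKey::CPU, c10::DispatchKey::AutogradCPU});
  ks = ks.add(c10::DispatchKey::CUDA);          // set one more bit
  bool has_cpu = ks.has(c10::DispatchKey::CPU); // true
  for (c10::DispatchKey k : ks) {
    // visits CPU, CUDA, AutogradCPU in enum order; highestPriorityTypeId()
    // would pick AutogradCPU, the largest set bit.
  }
}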
@@ -41,18 +41,18 @@ struct Event final {
   // Constructors
   Event() = delete;
   Event(
-    const DeviceType _device_type,
-    const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
-    : impl_{_device_type, _flag} { }
+      const DeviceType _device_type,
+      const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
+      : impl_{_device_type, _flag} {}
 
   // Copy constructor and copy assignment operator (deleted)
   Event(const Event&) = delete;
   Event& operator=(const Event&) = delete;
 
   // Move constructor and move assignment operator
-  Event(Event&& other) : impl_{std::move(other.impl_)} { }
+  Event(Event&& other) : impl_{std::move(other.impl_)} {}
   Event& operator=(Event&& other) {
-    impl_.swap(std::move(other.impl_));
+    impl_.swap(std::move(other.impl_));
     return *this;
   }
 

@@ -60,54 +60,62 @@ struct Event final {
   ~Event() = default;
 
   // Getters
-  DeviceType device_type() const noexcept { return impl_.device_type(); }
-  DeviceIndex device_index() const noexcept { return impl_.device_index(); }
-  EventFlag flag() const noexcept { return impl_.flag(); }
-  bool was_marked_for_recording() const noexcept { return impl_.was_marked_for_recording(); }
+  DeviceType device_type() const noexcept {
+    return impl_.device_type();
+  }
+  DeviceIndex device_index() const noexcept {
+    return impl_.device_index();
+  }
+  EventFlag flag() const noexcept {
+    return impl_.flag();
+  }
+  bool was_marked_for_recording() const noexcept {
+    return impl_.was_marked_for_recording();
+  }
 
-  /**
-  * Calls record() if and only if record() has never been called for this event.
-  * Note: because Event is not thread-safe recordOnce() may call record()
-  * multiple times if called from multiple threads.
-  */
+  /**
+   * Calls record() if and only if record() has never been called for this
+   * event. Note: because Event is not thread-safe recordOnce() may call
+   * record() multiple times if called from multiple threads.
+   */
   void recordOnce(const Stream& stream) {
     impl_.recordOnce(stream);
   }
 
-  /**
-  * Increments the event's version and enqueues a job with this version
-  * in the stream's work queue. When the stream process that job
-  * it nofifies all streams waiting on / blocked by that version of the
-  * event to continue and marks that version as recorded.
-  * */
+  /**
+   * Increments the event's version and enqueues a job with this version
+   * in the stream's work queue. When the stream process that job
+   * it nofifies all streams waiting on / blocked by that version of the
+   * event to continue and marks that version as recorded.
+   * */
   void record(const Stream& stream) {
     impl_.record(stream);
   }
 
-  /**
-  * Does nothing if the event has not been scheduled to be recorded.
-  * If the event was previously enqueued to be recorded, a command
-  * to wait for the version of the event that exists at the time of this call
-  * is inserted in the stream's work queue.
-  * When the stream reaches this command it will stop processing
-  * additional commands until that version of the event is marked as recorded.
-  */
+  /**
+   * Does nothing if the event has not been scheduled to be recorded.
+   * If the event was previously enqueued to be recorded, a command
+   * to wait for the version of the event that exists at the time of this call
+   * is inserted in the stream's work queue.
+   * When the stream reaches this command it will stop processing
+   * additional commands until that version of the event is marked as recorded.
+   */
   void block(const Stream& stream) const {
     impl_.block(stream);
   }
 
-  /**
-  * Returns true if (and only if)
-  * (1) the event has never been scheduled to be recorded
-  * (2) the current version is marked as recorded.
-  * Returns false otherwise.
-  */
+  /**
+   * Returns true if (and only if)
+   * (1) the event has never been scheduled to be recorded
+   * (2) the current version is marked as recorded.
+   * Returns false otherwise.
+   */
   bool query() const {
     return impl_.query();
   }
 
-private:
+ private:
   impl::InlineEvent<impl::VirtualGuardImpl> impl_;
 };
 
-} // c10
+} // namespace c10
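Usage sketch for the Event API above; Stream construction is backend-specific, so the streams here are assumed to exist already:

void pipeline(const c10::Stream& producer, const c10::Stream& consumer) {
  c10::Event ev(producer.device_type());
  ev.record(producer); // enqueue a new version of the event on 'producer'
  ev.block(consumer);  // 'consumer' stalls until that version is recorded
  if (ev.query()) {
    // that version has completed
  }
}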
@@ -13,7 +13,7 @@ namespace c10 {
  * GeneratorImpl class implementation
  */
 GeneratorImpl::GeneratorImpl(Device device_in, DispatchKeySet key_set)
-  : device_{device_in}, key_set_(key_set) {}
+    : device_{device_in}, key_set_(key_set) {}
 
 /**
  * Clone this generator. Note that clone() is the only

@@ -40,14 +40,15 @@ namespace detail {
 * FIXME: use std::random_device with entropy information
 */
 #ifndef _WIN32
-static uint64_t readURandomLong()
-{
+static uint64_t readURandomLong() {
   int randDev = open("/dev/urandom", O_RDONLY);
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   uint64_t randValue;
   TORCH_CHECK(randDev >= 0, "Unable to open /dev/urandom");
   ssize_t readBytes = read(randDev, &randValue, sizeof(randValue));
-  TORCH_CHECK(readBytes >= (ssize_t) sizeof(randValue), "Unable to read from /dev/urandom");
+  TORCH_CHECK(
+      readBytes >= (ssize_t)sizeof(randValue),
+      "Unable to read from /dev/urandom");
   close(randDev);
   return randValue;
 }

@@ -58,11 +59,11 @@ static uint64_t readURandomLong()
 * /dev/urandom or the current time. For CUDA, gets random from
 * std::random_device and adds a transformation on it.
 *
- * FIXME: The behavior in this function is from legacy code (THRandom_seed/THCRandom_seed)
- * and is probably not the right thing to do, even though our tests pass.
- * Figure out if tests get perturbed
- * - when the same algorithm is used for all backends. Note that the current behavior is
- * different for CPU, CUDA and Windows CPU.
+ * FIXME: The behavior in this function is from legacy code
+ * (THRandom_seed/THCRandom_seed) and is probably not the right thing to do,
+ * even though our tests pass. Figure out if tests get perturbed
+ * - when the same algorithm is used for all backends. Note that the current
+ *   behavior is different for CPU, CUDA and Windows CPU.
 * - when using C++11 std objects, such as std::random_device
 * - when constructing a 64 bit seed properly, rather than static casting
 *   a 32 bit number to 64 bit.

@@ -71,11 +72,13 @@ uint64_t getNonDeterministicRandom(bool is_cuda) {
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   uint64_t s;
   if (!is_cuda) {
-#ifdef _WIN32
-    s = (uint64_t)std::chrono::high_resolution_clock::now().time_since_epoch().count();
-#else
-    s = readURandomLong();
-#endif
+#ifdef _WIN32
+    s = (uint64_t)std::chrono::high_resolution_clock::now()
+            .time_since_epoch()
+            .count();
+#else
+    s = readURandomLong();
+#endif
   } else {
     std::random_device rd;
     // limit to 53 bits to ensure unique representation in double
@@ -1,52 +1,54 @@
 #pragma once
 
 #include <stdint.h>
-#include <mutex>
-#include <deque>
 #include <atomic>
+#include <deque>
+#include <mutex>
 #include <typeinfo>
 #include <utility>
 
-#include <c10/util/Exception.h>
-#include <c10/util/C++17.h>
-#include <c10/util/intrusive_ptr.h>
 #include <c10/core/Device.h>
 #include <c10/core/DispatchKeySet.h>
-#include <c10/util/python_stub.h>
 #include <c10/core/TensorImpl.h>
+#include <c10/util/C++17.h>
+#include <c10/util/Exception.h>
+#include <c10/util/intrusive_ptr.h>
+#include <c10/util/python_stub.h>
 
 /**
  * Note [Generator]
  * ~~~~~~~~~~~~~~~~
- * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to
- * generate a seemingly random sequence of numbers, that may be later be used in creating
- * a random distribution. Such an engine almost always maintains a state and requires a
- * seed to start off the creation of random numbers. Often times, users have
- * found it beneficial to be able to explicitly create, retain, and destroy
- * PRNG states and also be able to have control over the seed value.
+ * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm
+ * to generate a seemingly random sequence of numbers, that may be later be used
+ * in creating a random distribution. Such an engine almost always maintains a
+ * state and requires a seed to start off the creation of random numbers. Often
+ * times, users have found it beneficial to be able to explicitly create,
+ * retain, and destroy PRNG states and also be able to have control over the
+ * seed value.
 *
- * A Generator in ATen gives users the ability to read, write and modify a PRNG engine.
- * For instance, it does so by letting users seed a PRNG engine, fork the state of the
- * engine, etc.
+ * A Generator in ATen gives users the ability to read, write and modify a PRNG
+ * engine. For instance, it does so by letting users seed a PRNG engine, fork
+ * the state of the engine, etc.
 *
 * By default, there is one generator per device, and a device's generator is
- * lazily created. A user can use the torch.Generator() api to create their own generator.
- * Currently torch.Generator() can only create a CPUGeneratorImpl.
+ * lazily created. A user can use the torch.Generator() api to create their own
+ * generator. Currently torch.Generator() can only create a CPUGeneratorImpl.
 */
 
 /**
 * Note [Acquire lock when using random generators]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- * Generator and its derived classes are NOT thread-safe. Please note that most of the
- * places where we have inserted locking for generators are historically based, and we
- * haven't actually checked that everything is truly thread safe (and it probably isn't).
- * Please use the public mutex_ when using any methods from these classes, except for the
- * read-only methods. You can learn about the usage by looking into the unittests
- * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard.
+ * Generator and its derived classes are NOT thread-safe. Please note that most
+ * of the places where we have inserted locking for generators are historically
+ * based, and we haven't actually checked that everything is truly thread safe
+ * (and it probably isn't). Please use the public mutex_ when using any methods
+ * from these classes, except for the read-only methods. You can learn about the
+ * usage by looking into the unittests (aten/src/ATen/cpu_generator_test.cpp)
+ * and other places where we have used lock_guard.
 *
- * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making
- * them non-thread safe and instead making the generator state splittable, to accommodate
- * forks into other threads).
+ * TODO: Look into changing the threading semantics of Generators in ATen (e.g.,
+ * making them non-thread safe and instead making the generator state
+ * splittable, to accommodate forks into other threads).
 */
 
 namespace c10 {

@@ -79,7 +81,9 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target {
   // See Note [Acquire lock when using random generators]
   std::mutex mutex_;
 
-  DispatchKeySet key_set() const { return key_set_; }
+  DispatchKeySet key_set() const {
+    return key_set_;
+  }
 
   inline void set_pyobj(PyObject* pyobj) noexcept {
     pyobj_ = pyobj;

@@ -89,12 +93,12 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target {
     return pyobj_;
   }
 
-  protected:
-    Device device_;
-    DispatchKeySet key_set_;
-    PyObject* pyobj_ = nullptr;
+ protected:
+  Device device_;
+  DispatchKeySet key_set_;
+  PyObject* pyobj_ = nullptr;
 
-    virtual GeneratorImpl* clone_impl() const = 0;
+  virtual GeneratorImpl* clone_impl() const = 0;
 };
 
 namespace detail {
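Per Note [Acquire lock when using random generators] above, callers are expected to hold the public mutex_ around any mutating call. A sketch; 'reseed' and its argument are illustrative, not part of the diff:

#include <mutex>

void reseed(c10::GeneratorImpl& gen) {
  std::lock_guard<std::mutex> lock(gen.mutex_);
  gen.set_current_seed(42); // any non-const method goes under the lock
}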
@@ -27,4 +27,4 @@ struct TORCH_API NoGradGuard : public AutoGradMode {
   NoGradGuard() : AutoGradMode(/*enabled=*/false) {}
 };
 
-}
+} // namespace c10
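NoGradGuard above is just AutoGradMode(false) with RAII scoping; a usage sketch:

void no_grad_region() {
  c10::NoGradGuard no_grad; // gradient recording disabled in this scope
  // ... run ops without autograd tracking ...
} // previous grad mode restored by ~AutoGradMode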
@@ -6,7 +6,8 @@ namespace c10 {
 thread_local bool InferenceMode_enabled = false;
 
 // Invariant:
-// is_enabled() == !c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView);
+// is_enabled() ==
+// !c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView);
 // InferenceMode::is_enabled() is in perf critical path (TensorImpl constructor)
 // so it worths a separate TLS to skip the DispatchKeySet check.
 bool InferenceMode::is_enabled() {
@ -1,8 +1,8 @@
#pragma once

#include <c10/core/GradMode.h>
#include <c10/macros/Macros.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/macros/Macros.h>

namespace c10 {

@ -10,13 +10,14 @@ namespace c10 {
// construction, and sets it back to the original value upon destruction.
struct TORCH_API InferenceMode {
// Note [Expected TLS state in InferenceMode]:
// InferenceMode: InplaceOrView not in raw_local_dispatch_key_set.included(),
// InferenceMode: InplaceOrView not in
// raw_local_dispatch_key_set.included(),
// Autograd in raw_local_dispatch_key_set.excluded()
// GradMode is disabled.
// NormalMode: InplaceOrView in raw_local_dispatch_key_set.included(),
// Autograd not in raw_local_dispatch_key_set.excluded()
// GradMode is enabled by default unless toggled manually through
// other APIs, e.g. NoGradGuard.
// GradMode is enabled by default unless toggled manually
// through other APIs, e.g. NoGradGuard.
//
// Invariant:
// - InplaceOrView is never in the excluded set

@ -25,8 +26,8 @@ struct TORCH_API InferenceMode {
//
// 1. Why do we put InplaceOrView in included set outside InferenceMode?
//
// Inplace update to inference tensor outside InferenceMode is not allowed.
// See Note [Inplace update inference tensor] for more details.
// Inplace update to inference tensor outside InferenceMode is not
// allowed. See Note [Inplace update inference tensor] for more details.
// Without going through InplaceOrView kernel, we cannot throw error
// for `inference_tensor.add_(1)` case.
//

@ -48,14 +49,17 @@ struct TORCH_API InferenceMode {
// version of NoGradGuard. All runtime checks using GradMode::is_enabled()
// are applicable to InferenceMode as well, e.g.
// `tensorTypeInCurrentExecutionContext` in interpreter.cpp.
InferenceMode(bool enabled=true): prev_mode(InferenceMode::is_enabled()),
prev_keyset(c10::impl::tls_local_dispatch_key_set()),
grad_mode(at::AutoGradMode(!enabled)) {
InferenceMode(bool enabled = true)
: prev_mode(InferenceMode::is_enabled()),
prev_keyset(c10::impl::tls_local_dispatch_key_set()),
grad_mode(at::AutoGradMode(!enabled)) {
set_enabled(enabled);
DispatchKeySet included = enabled ? prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView)
: prev_keyset.included_.add(c10::DispatchKey::InplaceOrView);
DispatchKeySet excluded = enabled ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
: (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
DispatchKeySet included = enabled
? prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView)
: prev_keyset.included_.add(c10::DispatchKey::InplaceOrView);
DispatchKeySet excluded = enabled
? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
: (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
c10::impl::PODLocalDispatchKeySet cur_keyset;
cur_keyset.set_included(included);
cur_keyset.set_excluded(excluded);
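// Illustration (not part of this commit): typical use of the RAII guard
// whose constructor is shown above; TLS state is restored on destruction:
//
//   {
//     c10::InferenceMode guard; // enabled defaults to true
//     // work here runs with InplaceOrView removed, Autograd excluded,
//     // and GradMode off
//   } // previous dispatch keyset and grad mode restored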
@ -71,9 +75,9 @@ struct TORCH_API InferenceMode {
// ThreadLocalState.cpp.
static void set_enabled(bool enabled);

private:
bool prev_mode;
c10::impl::LocalDispatchKeySet prev_keyset;
at::AutoGradMode grad_mode;
private:
bool prev_mode;
c10::impl::LocalDispatchKeySet prev_keyset;
at::AutoGradMode grad_mode;
};
} // namespace c10
@ -1,8 +1,8 @@
#pragma once

#include <c10/core/Backend.h>
#include <c10/util/Exception.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#include <iostream>

@ -17,14 +17,20 @@
// should be in channels_last format
//
// Contiguous:
// Regardless of input tensors format, the output should be contiguous Tensor.
// Regardless of input tensors format, the output should be contiguous
// Tensor.
//
// ChannelsLast:
// Regardless of input tensors format, the output should be in channels_last format.
// Regardless of input tensors format, the output should be in channels_last
// format.

namespace c10 {
enum class MemoryFormat : int8_t { Contiguous, Preserve, ChannelsLast, ChannelsLast3d };
enum class MemoryFormat : int8_t {
Contiguous,
Preserve,
ChannelsLast,
ChannelsLast3d
};

// If you are seeing this, it means that this call site was not checked if
// the memory format could be preserved, and it was switched to old default

@ -52,7 +58,8 @@ inline std::ostream& operator<<(
}
}

// Note: Hardcoded the channel last stride indices here to get better performance
// Note: Hardcoded the channel last stride indices here to get better
// performance
inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
std::vector<int64_t> strides(sizes.size());
switch (sizes.size()) {

@ -68,7 +75,8 @@ inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
strides[1] = strides[2] * sizes[2];
return strides;
default:
TORCH_INTERNAL_ASSERT(false, "ChannelsLast2d doesn't support size ", sizes.size());
TORCH_INTERNAL_ASSERT(
false, "ChannelsLast2d doesn't support size ", sizes.size());
}
}
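// Illustration (not part of this commit): a worked example of the 2d helper
// above for a contiguous NCHW tensor with sizes {N=2, C=3, H=4, W=5}:
//
//   auto s = c10::get_channels_last_strides_2d({2, 3, 4, 5});
//   // s == {60, 1, 15, 3}: strides[1] = 1, strides[3] = C = 3,
//   // strides[2] = C * W = 15, strides[0] = C * W * H = 60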
@ -89,7 +97,8 @@ inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
strides[1] = strides[2] * sizes[2];
return strides;
default:
TORCH_INTERNAL_ASSERT(false, "ChannelsLast3d doesn't support size ", sizes.size());
TORCH_INTERNAL_ASSERT(
false, "ChannelsLast3d doesn't support size ", sizes.size());
}
}

@ -100,12 +109,16 @@ inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
// will be a constant array and we can access it using constant index number,
// the compiler will fully unroll the loop on strides indices to gain a better
// performance.
// 2. No error check in helper function, caller ensures the correctness of the input
// 3. All helper functions have similar comments, only 1st helper function is commented here.
inline bool is_channels_last_strides_2d_s4(const IntArrayRef sizes, const IntArrayRef strides) {
// 2. No error check in helper function, caller ensures the correctness of the
// input
// 3. All helper functions have similar comments, only 1st helper function is
// commented here.
inline bool is_channels_last_strides_2d_s4(
const IntArrayRef sizes,
const IntArrayRef strides) {
int64_t min = 0;
// special case for trivial C dimension. default to NCHW
if (strides[1]==0) {
if (strides[1] == 0) {
return false;
}
// loop strides indices

@ -121,8 +134,9 @@ inline bool is_channels_last_strides_2d_s4(const IntArrayRef sizes, const IntArr
// N111 tensor with identical strides for size 1 dimension;
// Two cases could lead us here:
// a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
// b. N11W contiguous Tensor sliced on the W-dimension. ([N,1,1,1]@[W,W,W,W])
if (d==0 && min==strides[1]) {
// b. N11W contiguous Tensor sliced on the W-dimension.
// ([N,1,1,1]@[W,W,W,W])
if (d == 0 && min == strides[1]) {
return false;
}
// This is necessary to:

@ -140,7 +154,9 @@ inline bool is_channels_last_strides_2d_s4(const IntArrayRef sizes, const IntArr
return true;
}

inline bool is_channels_last_strides_3d_s5(const IntArrayRef sizes, const IntArrayRef strides) {
inline bool is_channels_last_strides_3d_s5(
const IntArrayRef sizes,
const IntArrayRef strides) {
int64_t min = 0;
if (strides[1] == 0) {
return false;
@ -209,10 +225,13 @@ inline bool is_channels_last_strides_3d_s5(const IntArrayRef sizes, const IntArr
// issues in our tests.
//
// We use Channels Last 2d as an example above.
// This is a general problem for all the is_channels_last_strides_xd implementation.
// Please check the helper functions (is_channels_last_strides_*d_s*) for more details.
// This is a general problem for all the is_channels_last_strides_xd
// implementation. Please check the helper functions
// (is_channels_last_strides_*d_s*) for more details.

inline bool is_channels_last_strides_2d(const IntArrayRef sizes, const IntArrayRef strides) {
inline bool is_channels_last_strides_2d(
const IntArrayRef sizes,
const IntArrayRef strides) {
switch (sizes.size()) {
case 4:
return is_channels_last_strides_2d_s4(sizes, strides);

@ -224,7 +243,9 @@ inline bool is_channels_last_strides_2d(const IntArrayRef sizes, const IntArrayR
}
}

inline bool is_channels_last_strides_3d(const IntArrayRef sizes, const IntArrayRef strides) {
inline bool is_channels_last_strides_3d(
const IntArrayRef sizes,
const IntArrayRef strides) {
switch (sizes.size()) {
case 5:
return is_channels_last_strides_3d_s5(sizes, strides);
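// Illustration (not part of this commit): the stride generator and the
// classifiers above round-trip on a channels-last candidate:
//
//   std::vector<int64_t> sizes = {2, 3, 4, 5};
//   auto strides = c10::get_channels_last_strides_2d(sizes);
//   bool cl = c10::is_channels_last_strides_2d(sizes, strides); // true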
@ -31,9 +31,7 @@ inline std::string toString(QEngine qengine) {
return "QNNPACK";
default:
TORCH_CHECK(
false,
"Unrecognized Quantized Engine: ",
static_cast<int>(qengine));
false, "Unrecognized Quantized Engine: ", static_cast<int>(qengine));
}
}

@ -24,12 +24,13 @@ constexpr auto kPerTensorAffine = QScheme::PER_TENSOR_AFFINE;
constexpr auto kPerChannelAffine = QScheme::PER_CHANNEL_AFFINE;
constexpr auto kPerTensorSymmetric = QScheme::PER_TENSOR_SYMMETRIC;
constexpr auto kPerChannelSymmetric = QScheme::PER_CHANNEL_SYMMETRIC;
constexpr auto kPerChannelAffineFloatQParams = QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS;
constexpr auto kPerChannelAffineFloatQParams =
QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS;
constexpr int COMPILE_TIME_NUM_QSCHEMES =
static_cast<int>(QScheme::COMPILE_TIME_NUM_QSCHEMES);
static_cast<int>(QScheme::COMPILE_TIME_NUM_QSCHEMES);

inline std::string toString(QScheme qscheme) {
switch(qscheme) {
switch (qscheme) {
case kPerTensorAffine:
return "per_tensor_affine";
case kPerChannelAffine:
@ -3,7 +3,9 @@
namespace c10 {

Scalar Scalar::operator-() const {
TORCH_CHECK(!isBoolean(), "torch boolean negative, the `-` operator, is not supported.");
TORCH_CHECK(
!isBoolean(),
"torch boolean negative, the `-` operator, is not supported.");
if (isFloatingPoint()) {
return Scalar(-v.d);
} else if (isComplex()) {

@ -31,4 +33,4 @@ Scalar Scalar::log() const {
}
}

} // namespace c10
} // namespace c10
@ -4,8 +4,8 @@
#include <stdint.h>
#include <stdexcept>
#include <string>
#include <utility>
#include <type_traits>
#include <utility>

#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>

@ -26,8 +26,8 @@ class C10_API Scalar {
public:
Scalar() : Scalar(int64_t(0)) {}

#define DEFINE_IMPLICIT_CTOR(type, name) \
Scalar(type vv) : Scalar(vv, true) { }
#define DEFINE_IMPLICIT_CTOR(type, name) \
Scalar(type vv) : Scalar(vv, true) {}

AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, DEFINE_IMPLICIT_CTOR)
AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)

@ -45,18 +45,18 @@ class C10_API Scalar {
v.i = convert<int64_t, bool>(vv);
}

#define DEFINE_ACCESSOR(type, name) \
type to##name() const { \
if (Tag::HAS_d == tag) { \
return checked_convert<type, double>(v.d, #type); \
} else if (Tag::HAS_z == tag) { \
return checked_convert<type, c10::complex<double>>( \
v.z, #type); \
} if (Tag::HAS_b == tag) { \
return checked_convert<type, bool>(v.i, #type); \
} else { \
return checked_convert<type, int64_t>(v.i, #type); \
} \
#define DEFINE_ACCESSOR(type, name) \
type to##name() const { \
if (Tag::HAS_d == tag) { \
return checked_convert<type, double>(v.d, #type); \
} else if (Tag::HAS_z == tag) { \
return checked_convert<type, c10::complex<double>>(v.z, #type); \
} \
if (Tag::HAS_b == tag) { \
return checked_convert<type, bool>(v.i, #type); \
} else { \
return checked_convert<type, int64_t>(v.i, #type); \
} \
}

// TODO: Support ComplexHalf accessor

@ -71,7 +71,8 @@ class C10_API Scalar {
return Tag::HAS_d == tag;
}

C10_DEPRECATED_MESSAGE("isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")
C10_DEPRECATED_MESSAGE(
"isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")
bool isIntegral() const {
return Tag::HAS_i == tag;
}
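// Illustration (not part of this commit): Scalar is a tagged union over
// int64_t, double, bool, and c10::complex<double>; the implicit constructors
// pick a tag and the accessors convert back with range checking:
//
//   c10::Scalar a = 3;          // integral tag
//   c10::Scalar b = 2.5;        // floating-point tag
//   double x = a.to<double>();  // 3.0, via checked_convert
//   bool i = a.isIntegral(/*includeBool=*/false); // true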
@ -90,7 +91,9 @@ class C10_API Scalar {
Scalar conj() const;
Scalar log() const;

template<typename T, typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
template <
typename T,
typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
bool equal(T num) const {
if (isComplex()) {
auto val = v.z;

@ -105,7 +108,9 @@ class C10_API Scalar {
}
}

template<typename T, typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
template <
typename T,
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
bool equal(T num) const {
if (isComplex()) {
return v.z == num;

@ -142,26 +147,30 @@ class C10_API Scalar {
}

private:
template<typename T,
typename std::enable_if<std::is_integral<T>::value && ! std::is_same<T, bool>::value, bool>::type* =
nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_i) {
v.i = convert<decltype(v.i), T>(vv);
}
template <
typename T,
typename std::enable_if<
std::is_integral<T>::value && !std::is_same<T, bool>::value,
bool>::type* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_i) {
v.i = convert<decltype(v.i), T>(vv);
}

template<typename T,
typename std::enable_if<!std::is_integral<T>::value && !c10::is_complex<T>::value, bool>::type* =
nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_d) {
v.d = convert<decltype(v.d), T>(vv);
}
template <
typename T,
typename std::enable_if<
!std::is_integral<T>::value && !c10::is_complex<T>::value,
bool>::type* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_d) {
v.d = convert<decltype(v.d), T>(vv);
}

template<typename T,
typename std::enable_if<c10::is_complex<T>::value, bool>::type* =
nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_z) {
v.z = convert<decltype(v.z), T>(vv);
}
template <
typename T,
typename std::enable_if<c10::is_complex<T>::value, bool>::type* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_z) {
v.z = convert<decltype(v.z), T>(vv);
}

// We can't set v in the initializer list using the
// syntax v{ .member = ... } because it doesn't work on MSVC

@ -172,7 +181,7 @@ class C10_API Scalar {
double d;
int64_t i;
c10::complex<double> z;
v_t(){} // default constructor
v_t() {} // default constructor
} v;
};
@ -182,10 +191,10 @@ inline T Scalar::to() const {
throw std::runtime_error("to() cast to unexpected type.");
}

#define DEFINE_TO(T, name) \
template <> \
inline T Scalar::to<T>() const { \
return to##name(); \
#define DEFINE_TO(T, name) \
template <> \
inline T Scalar::to<T>() const { \
return to##name(); \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO)
#undef DEFINE_TO
@ -1,14 +1,14 @@
#pragma once

#include <c10/util/ArrayRef.h>
#include <c10/util/complex.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/Optional.h>
#include <c10/util/complex.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint8.h>
#include <c10/util/BFloat16.h>
#include <c10/util/quint4x2.h>
#include <c10/util/Optional.h>
#include <c10/util/quint8.h>

#include <complex>
#include <cstdint>

@ -44,7 +44,6 @@ namespace c10 {
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */

// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...

@ -62,7 +61,6 @@ namespace c10 {
_(bool, Bool) \
_(at::BFloat16, BFloat16)

enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM)

@ -71,7 +69,8 @@ enum class ScalarType : int8_t {
NumOptions
};

constexpr uint16_t NumScalarTypes = static_cast<uint16_t>(ScalarType::NumOptions);
constexpr uint16_t NumScalarTypes =
static_cast<uint16_t>(ScalarType::NumOptions);

namespace impl {

@ -80,20 +79,20 @@ namespace impl {
template <c10::ScalarType N>
struct ScalarTypeToCPPType;

#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
template<> \
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
using type = cpp_type; \
\
/* This is a workaround for the CUDA bug which prevents */ \
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
/* ambiguous reference which can't to be resolved. For some reason it */ \
/* cant pick between at::detail and at::cuda::detail. */ \
/* For repro example, please see: */ \
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
/* TODO: remove once the bug is fixed. */ \
static type t; \
};
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
template <> \
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
using type = cpp_type; \
\
/* This is a workaround for the CUDA bug which prevents */ \
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
/* ambiguous reference which can't to be resolved. For some reason it */ \
/* cant pick between at::detail and at::cuda::detail. */ \
/* For repro example, please see: */ \
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
/* TODO: remove once the bug is fixed. */ \
static type t; \
};

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)

@ -104,12 +103,12 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
template <typename T>
struct CppTypeToScalarType;

#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template<> \
struct CppTypeToScalarType<cpp_type>: \
std::integral_constant<c10::ScalarType, \
c10::ScalarType::scalar_type> \
{};
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template <> \
struct CppTypeToScalarType<cpp_type> \
: std:: \
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
};

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
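// Illustration (not part of this commit): the two metafunctions above are
// compile-time inverses of each other:
//
//   static_assert(
//       c10::impl::CppTypeToScalarType<float>::value ==
//           c10::ScalarType::Float,
//       "");
//   using F = c10::impl::ScalarTypeToCPPType<c10::ScalarType::Float>::type;
//   static_assert(std::is_same<F, float>::value, "");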
@ -131,47 +130,59 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
_(float, Float) \
_(double, Double)

#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE>::t), SCALARTYPE)
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE>::t), \
SCALARTYPE)

#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2)

#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE3>::t), SCALARTYPE3)
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype( \
::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3)

#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
_(c10::quint8, QUInt8) \
_(c10::qint32, QInt32) \
#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
_(c10::quint8, QUInt8) \
_(c10::qint32, QInt32) \
_(c10::quint4x2, QUInt4x2)

#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble)

#define DEFINE_CONSTANT(_, name) \
@ -206,7 +217,8 @@ static inline size_t elementSize(ScalarType t) {
#undef CASE_ELEMENTSIZE_CASE
}

C10_DEPRECATED_MESSAGE("isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")
C10_DEPRECATED_MESSAGE(
"isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")
static inline bool isIntegralType(ScalarType t) {
return (
t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||

@ -214,9 +226,9 @@ static inline bool isIntegralType(ScalarType t) {
}

static inline bool isIntegralType(ScalarType t, bool includeBool) {
bool isIntegral = (
t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
t == ScalarType::Long || t == ScalarType::Short);
bool isIntegral =
(t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
t == ScalarType::Long || t == ScalarType::Short);

return includeBool ? isIntegral || (t == ScalarType::Bool) : isIntegral;
}

@ -235,7 +247,8 @@ static inline bool isComplexType(ScalarType t) {

static inline bool isQIntType(ScalarType t) {
// Don't forget to extend this when adding new QInt types
return t == ScalarType:: QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32 || t == ScalarType::QUInt4x2;
return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
t == ScalarType::QInt32 || t == ScalarType::QUInt4x2;
}

static inline ScalarType toQIntType(ScalarType t) {

@ -268,20 +281,20 @@ static inline ScalarType toUnderlying(ScalarType t) {

static inline bool isSignedType(ScalarType t) {
TORCH_CHECK(!isQIntType(t), "isSignedType not supported for quantized types");
#define CASE_SIGNED(ctype, name) \
case ScalarType::name: \
return std::numeric_limits<ctype>::is_signed;
#define CASE_SIGNED(ctype, name) \
case ScalarType::name: \
return std::numeric_limits<ctype>::is_signed;

switch (t) {
case ScalarType::ComplexHalf:
case ScalarType::ComplexFloat:
case ScalarType::ComplexDouble:
return true;
AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
default:
TORCH_CHECK(false, "Unknown ScalarType");
}
#undef CASE_SIGNED
#undef CASE_SIGNED
}

static inline bool isUnderlying(ScalarType type, ScalarType qtype) {
@ -323,7 +336,8 @@ static inline ScalarType toComplexType(ScalarType t) {
// see tensor_attributes.rst for detailed explanation and examples
// of casting rules.
static inline bool canCast(const ScalarType from, const ScalarType to) {
// We disallow complex -> non complex, e.g., float_tensor *= complex is disallowed.
// We disallow complex -> non complex, e.g., float_tensor *= complex is
// disallowed.
if (isComplexType(from) && !isComplexType(to)) {
return false;
}

@ -333,13 +347,14 @@ static inline bool canCast(const ScalarType from, const ScalarType to) {
}

// Treat bool as a distinct "category," to be consistent with type promotion
// rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same category
// as `bool_tensor`, we would not promote.
// Differing categories implies `bool_tensor += 5` is disallowed.
// rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same
// category as `bool_tensor`, we would not promote. Differing categories
// implies `bool_tensor += 5` is disallowed.
//
// NB: numpy distinguishes "unsigned" as a category to get the desired
// `bool_tensor + 5 -> int64_tensor` behavior. We don't, because:
// * We don't want the performance hit of checking the runtime sign of Scalars.
// * We don't want the performance hit of checking the runtime sign of
// Scalars.
// * `uint8_tensor + 5 -> int64_tensor` would be undesirable.
if (from != ScalarType::Bool && to == ScalarType::Bool) {
return false;
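// Illustration (not part of this commit): the two canCast rules above in
// action:
//
//   c10::canCast(c10::ScalarType::ComplexFloat, c10::ScalarType::Float);
//   // false: complex -> non-complex is disallowed
//   c10::canCast(c10::ScalarType::Long, c10::ScalarType::Bool); // false
//   c10::canCast(c10::ScalarType::Bool, c10::ScalarType::Long); // true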
@ -373,7 +388,8 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
}

if (isQIntType(a) || isQIntType(b)) {
TORCH_CHECK(false,
TORCH_CHECK(
false,
"promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: ",
toString(a),
" ",

@ -385,23 +401,23 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
// correct values for the type promotions in complex type cases.
static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, bf},
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, bf},
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, bf},
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, bf},
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, bf},
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, f4},
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, f4},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, f8},
/* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, c2, c4, c8, ud, ud, ud, ud, ud},
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, b1, ud, ud, ud, bf},
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, ud, c4, c8, bf, ud, ud, ud, bf},
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, bf},
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, bf},
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, bf},
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, bf},
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, bf},
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, f4},
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, f4},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, f8},
/* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, c2, c4, c8, ud, ud, ud, ud, ud},
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, b1, ud, ud, ud, bf},
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, ud, c4, c8, bf, ud, ud, ud, bf},
};
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}
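// Illustration (not part of this commit): reading the lookup table above,
// the row is the first argument and the column is the second:
//
//   c10::promoteTypes(c10::ScalarType::Bool, c10::ScalarType::Long);
//   // row b1, column i8 -> ScalarType::Long
//   c10::promoteTypes(c10::ScalarType::Half, c10::ScalarType::BFloat16);
//   // row f2, column bf -> f4, i.e. ScalarType::Float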
@ -26,7 +26,8 @@ static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
/**
* typeMetaToScalarType(), lifted to optional
*/
static inline optional<at::ScalarType> optTypeMetaToScalarType(optional<caffe2::TypeMeta> type_meta) {
static inline optional<at::ScalarType> optTypeMetaToScalarType(
optional<caffe2::TypeMeta> type_meta) {
if (!type_meta.has_value()) {
return c10::nullopt;
}
@ -1,5 +1,3 @@
#include <c10/core/Storage.h>

namespace c10 {

} // namespace c10
namespace c10 {} // namespace c10

@ -9,7 +9,8 @@ struct C10_API Storage {
struct use_byte_size_t {};

Storage() {}
Storage(c10::intrusive_ptr<StorageImpl> ptr) : storage_impl_(std::move(ptr)) {}
Storage(c10::intrusive_ptr<StorageImpl> ptr)
: storage_impl_(std::move(ptr)) {}

// Allocates memory buffer using given allocator and creates a storage with it
Storage(

@ -53,10 +54,14 @@ struct C10_API Storage {
}

template <typename T>
T* data() const { return storage_impl_->data<T>(); }
T* data() const {
return storage_impl_->data<T>();
}

template <typename T>
T* unsafe_data() const { return storage_impl_->unsafe_data<T>(); }
T* unsafe_data() const {
return storage_impl_->unsafe_data<T>();
}

// TODO: remove later
void set_nbytes(size_t size_bytes) const {

@ -134,7 +139,8 @@ struct C10_API Storage {
size_t capacity,
DeleterFnPtr d = nullptr) {
if (!storage_impl_.unique()) {
TORCH_CHECK(false,
TORCH_CHECK(
false,
"UniqueStorageShareExternalPointer can only be called when use_count == 1");
}
storage_impl_->UniqueStorageShareExternalPointer(src, capacity, d);

@ -144,7 +150,8 @@ struct C10_API Storage {
at::DataPtr&& data_ptr,
size_t capacity) {
if (!storage_impl_.unique()) {
TORCH_CHECK(false,
TORCH_CHECK(
false,
"UniqueStorageShareExternalPointer can only be called when use_count == 1");
}
storage_impl_->UniqueStorageShareExternalPointer(
@ -20,13 +20,13 @@ namespace c10 {
// but a lot of things won't work correctly, including:
//
// - An ordinary deleter on such a storage is wrong, because normal deleters
// assume unique ownership, but if you have two storages at the same data, that
// implies there is some sort of shared ownership. So your deleter would have to
// actually be internally doing some sort of refcount thing
// assume unique ownership, but if you have two storages at the same data,
// that implies there is some sort of shared ownership. So your deleter would
// have to actually be internally doing some sort of refcount thing
// - Deepcopy in Python side relies on storage equality and not data pointer
// equality; so if there are two separate storages pointing to the same data,
// the data will actually get duplicated in that case (one data ptr before, two
// data ptrs after)
// the data will actually get duplicated in that case (one data ptr before,
// two data ptrs after)
// - Version counts won't work correctly, because we do all VC tracking at the
// level of storages (unless you explicitly disconnect the VC with detach);
// mutation because data pointers are the same are totally untracked
@ -55,10 +55,11 @@ using StreamId = int32_t;
* wrapper classes which provide this functionality, e.g., CUDAStream.
*/
class Stream final {
private:
private:
Device device_;
StreamId id_;
public:

public:
enum Unsafe { UNSAFE };
enum Default { DEFAULT };

@ -69,16 +70,13 @@ public:
/// we don't require backends to give any guarantees about non-zero
/// StreamIds; they are welcome to allocate in whatever way they like.
explicit Stream(Unsafe, Device device, StreamId id)
: device_(device)
, id_(id) {}
: device_(device), id_(id) {}

/// Construct the default stream of a Device. The default stream is
/// NOT the same as the current stream; default stream is a fixed stream
/// that never changes, whereas the current stream may be changed by
/// StreamGuard.
explicit Stream(Default, Device device)
: device_(device)
, id_(0) {}
explicit Stream(Default, Device device) : device_(device), id_(0) {}

bool operator==(const Stream& other) const noexcept {
return this->device_ == other.device_ && this->id_ == other.id_;

@ -87,10 +85,18 @@ public:
return !(*this == other);
}

Device device() const noexcept { return device_; }
DeviceType device_type() const noexcept { return device_.type(); }
DeviceIndex device_index() const noexcept { return device_.index(); }
StreamId id() const noexcept { return id_; }
Device device() const noexcept {
return device_;
}
DeviceType device_type() const noexcept {
return device_.type();
}
DeviceIndex device_index() const noexcept {
return device_.index();
}
StreamId id() const noexcept {
return id_;
}

// Enqueues a wait instruction in the stream's work queue.
// This instruction is a no-op unless the event is marked

@ -116,10 +122,10 @@ public:
static_assert(sizeof(StreamId) == 4, "DeviceIndex is not 32-bit");
// Concat these together into a 64-bit integer
// See Note [Hazard when concatenating signed integers]
uint64_t bits =
static_cast<uint64_t>(static_cast<uint8_t>(device_type())) << 48
| static_cast<uint64_t>(static_cast<uint8_t>(device_index())) << 32
| static_cast<uint64_t>(static_cast<uint32_t>(id()));
uint64_t bits = static_cast<uint64_t>(static_cast<uint8_t>(device_type()))
<< 48 |
static_cast<uint64_t>(static_cast<uint8_t>(device_index())) << 32 |
static_cast<uint64_t>(static_cast<uint32_t>(id()));
return bits;
}
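// Illustration (not part of this commit): inverting the packing above with a
// hypothetical helper; the field widths follow the shifts shown:
//
//   void unpack_bits(uint64_t bits) { // hypothetical name
//     auto type_byte = static_cast<uint8_t>((bits >> 48) & 0xff);
//     auto index_byte = static_cast<uint8_t>((bits >> 32) & 0xff);
//     auto stream_id = static_cast<uint32_t>(bits & 0xffffffff);
//     (void)type_byte; (void)index_byte; (void)stream_id;
//   }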
@ -145,10 +151,10 @@ C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s);
} // namespace c10

namespace std {
template <>
struct hash<c10::Stream> {
size_t operator()(c10::Stream s) const noexcept {
return std::hash<uint64_t>{}(s.pack());
}
};
template <>
struct hash<c10::Stream> {
size_t operator()(c10::Stream s) const noexcept {
return std::hash<uint64_t>{}(s.pack());
}
};
} // namespace std
@ -47,7 +47,9 @@ struct StreamGuard {
/// WARNING: reset_stream does NOT preserve previously set streams on
/// different devices. If you need to set streams on multiple devices
/// on , use MultiStreamGuard instead.
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}

/// Returns the stream that was set at the time the guard was constructed.
Stream original_stream() const {

@ -62,13 +64,17 @@ struct StreamGuard {

/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device/reset_device/set_index.
Device current_device() const { return guard_.current_device(); }
Device current_device() const {
return guard_.current_device();
}

/// Returns the device that was set at the most recent reset_stream(),
/// or otherwise the device at construction time.
Device original_device() const { return guard_.original_device(); }
Device original_device() const {
return guard_.original_device();
}

private:
private:
c10::impl::InlineStreamGuard<impl::VirtualGuardImpl> guard_;
};

@ -88,7 +94,8 @@ struct OptionalStreamGuard {
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream,
/// if the passed stream is not nullopt.
explicit OptionalStreamGuard(optional<Stream> stream_opt) : guard_(stream_opt) {}
explicit OptionalStreamGuard(optional<Stream> stream_opt)
: guard_(stream_opt) {}

/// Copy is disallowed
OptionalStreamGuard(const OptionalStreamGuard&) = delete;

@ -105,21 +112,30 @@ struct OptionalStreamGuard {
/// set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
/// Initializes the guard if it was not previously initialized.
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}

/// Returns the stream that was set at the time the guard was most recently
/// initialized, or nullopt if the guard is uninitialized.
optional<Stream> original_stream() const { return guard_.original_stream(); }
optional<Stream> original_stream() const {
return guard_.original_stream();
}

/// Returns the most recent stream that was set using this stream guard,
/// either from construction, or via reset_stream, if the guard is initialized,
/// or nullopt if the guard is uninitialized.
optional<Stream> current_stream() const { return guard_.current_stream(); }
/// either from construction, or via reset_stream, if the guard is
/// initialized, or nullopt if the guard is uninitialized.
optional<Stream> current_stream() const {
return guard_.current_stream();
}

/// Restore the original device and stream, resetting this guard to uninitialized state.
void reset() { guard_.reset(); }
/// Restore the original device and stream, resetting this guard to
/// uninitialized state.
void reset() {
guard_.reset();
}

private:
private:
c10::impl::InlineOptionalStreamGuard<impl::VirtualGuardImpl> guard_;
};

@ -142,7 +158,7 @@ struct MultiStreamGuard {
// See Note [Move assignment for RAII guards is tricky]
MultiStreamGuard& operator=(MultiStreamGuard&& other) = delete;

private:
private:
c10::impl::InlineMultiStreamGuard<impl::VirtualGuardImpl> guard_;
};
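// Illustration (not part of this commit): the guards above follow the usual
// RAII pattern; a sketch of scoped use:
//
//   void enqueue(c10::Stream s) {
//     c10::StreamGuard guard(s); // sets device and current stream to s's
//     // ... queue work on s ...
//   } // original device and stream restored here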
@ -1,10 +1,10 @@
#include <c10/core/TensorImpl.h>

#include <c10/core/Backend.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Optional.h>
#include <c10/core/InferenceMode.h>

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
C10_DEFINE_bool(

@ -21,7 +21,7 @@ C10_DEFINE_int64(

namespace c10 {

const char * const TensorImpl::err_msg_tensor_metadata_change_not_allowed =
const char* const TensorImpl::err_msg_tensor_metadata_change_not_allowed =
"is not allowed on a Tensor created from .data or .detach().\n"
"If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)\n"
"without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.\n"

@ -32,7 +32,8 @@ const char * const TensorImpl::err_msg_tensor_metadata_change_not_allowed =
" x.set_(y)";

at::Tensor& TensorImpl::mutable_grad() {
if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
if (!autograd_meta_)
autograd_meta_ = impl::GetAutogradMetaFactory()->make();
return autograd_meta_->mutable_grad();
}
@ -43,18 +44,26 @@ const at::Tensor& TensorImpl::grad() const {
// is not so easy to fix right now because the mutable counterpart of
// this function must keep working so that "x.grad() = ..." keeps working
// (part of public API).
if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();
if (!autograd_meta_)
return impl::GetAutogradMetaFactory()->undefined_tensor();
return autograd_meta_->grad();
}

const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self) const {
const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self)
const {
// See TensorImpl::grad() above for explanation about the line below
if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();
if (!autograd_meta_)
return impl::GetAutogradMetaFactory()->undefined_tensor();
return autograd_meta_->fw_grad(level, self);
}

void TensorImpl::_set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) {
if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
void TensorImpl::_set_fw_grad(
const at::Tensor& new_grad,
const at::Tensor& self,
uint64_t level,
bool is_inplace_op) {
if (!autograd_meta_)
autograd_meta_ = impl::GetAutogradMetaFactory()->make();
autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}
@ -63,7 +72,11 @@ TensorImpl::TensorImpl(
DispatchKeySet key_set,
const caffe2::TypeMeta data_type)
// Use std::forward to suppress static analyzer false positive.
: TensorImpl(std::forward<Storage>(storage), key_set, data_type, storage.device()) {}
: TensorImpl(
std::forward<Storage>(storage),
key_set,
data_type,
storage.device()) {}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(

@ -84,23 +97,29 @@ TensorImpl::TensorImpl(
}
}

TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional<c10::Device> device_opt)
TensorImpl::TensorImpl(
DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
c10::optional<c10::Device> device_opt)
// NOLINTNEXTLINE(performance-move-const-arg)
: TensorImpl({}, key_set, data_type, std::move(device_opt)) {}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta data_type,
c10::optional<c10::Device> device_opt)
TensorImpl::TensorImpl(
Storage&& storage,
DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
c10::optional<c10::Device> device_opt)
: storage_(std::move(storage)),
storage_offset_(0),
numel_(0),
data_type_(data_type),
device_opt_(device_opt) {

init_bitfields();

if (!key_set.empty()) {
TORCH_INTERNAL_ASSERT(data_type == ScalarType::Undefined || device_opt_.has_value());
TORCH_INTERNAL_ASSERT(
data_type == ScalarType::Undefined || device_opt_.has_value());
// UndefinedTensorImpl is a singleton, so we skip logging it
C10_LOG_API_USAGE_ONCE("tensor.create");
}
@ -115,12 +134,13 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::

// Inference tensor doesn't have autograd related keys.
if (inference_mode) {
// See Note [Expected TLS state in InferenceMode] for why we exclude Autograd & InplaceOrView keys.
// Normally key_set only contains backend keys but we do the subtraction
// here to make sure.
// See Note [Expected TLS state in InferenceMode] for why we exclude
// Autograd & InplaceOrView keys. Normally key_set only contains backend
// keys but we do the subtraction here to make sure.
key_set_ = key_set - c10::autograd_dispatch_keyset_with_InplaceOrView;
} else {
// TODO: Ideally we only add AutogradBackend key when the tensor requires grad.
// TODO: Ideally we only add AutogradBackend key when the tensor requires
// grad.
// See Note [Dream: skip VariableType kernel when requires_grad=false]
key_set_ = key_set | getAutogradRelatedKeySetFromBackend(k);
}

@ -130,8 +150,8 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
version_counter_ = VariableVersion(/*version=*/0);
}

// we would also like to check that non-cpu devices have an index, but some Caffe2 operators create
// Storages with default devices.
// we would also like to check that non-cpu devices have an index, but some
// Caffe2 operators create Storages with default devices.
}

#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@ -151,15 +171,14 @@ void TensorImpl::HandleResize() {
if (reserved_) {
// If tensor is reserved then don't claim its memory unless nbytes()
// is smaller than new size
reset_tensor = storage_.nbytes() <
(storage_offset_ + numel_) * data_type_.itemsize();
reset_tensor =
storage_.nbytes() < (storage_offset_ + numel_) * data_type_.itemsize();
} else {
reset_tensor = storage_.nbytes() <
(storage_offset_ + numel_) * data_type_.itemsize() ||
!FLAGS_caffe2_keep_on_shrink ||
storage_.nbytes() -
(storage_offset_ + numel_) * data_type_.itemsize() >
static_cast<size_t>(FLAGS_caffe2_max_keep_on_shrink_memory);
(storage_offset_ + numel_) * data_type_.itemsize() ||
!FLAGS_caffe2_keep_on_shrink ||
storage_.nbytes() - (storage_offset_ + numel_) * data_type_.itemsize() >
static_cast<size_t>(FLAGS_caffe2_max_keep_on_shrink_memory);
}

if (reset_tensor && storage_initialized()) {
@ -190,20 +209,19 @@ bool TensorImpl::compute_channels_last_contiguous_2d() const {
// Please don't combine these code, constant array is used here to let
// compiler fully unroll the loop to get better performance
switch (sizes_and_strides_.size()) {
case 4:
{
int64_t expected = 1;
for (auto& d : {1, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
}
expected *= size_d;
case 4: {
int64_t expected = 1;
for (auto& d : {1, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
}
expected *= size_d;
}
return true;
}
return true;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 3:
// TODO dim == 3 case will be enabled once it is fully tested

@ -218,20 +236,19 @@ bool TensorImpl::compute_channels_last_contiguous_3d() const {
// compiler fully unroll the loop to get better performance
switch (sizes_and_strides_.size()) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
case 5:
{
int64_t expected = 1;
for (auto& d : {1, 4, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
}
expected *= size_d;
case 5: {
int64_t expected = 1;
for (auto& d : {1, 4, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
}
expected *= size_d;
}
return true;
}
return true;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 4:
// TODO dim == 4 case will be enabled once it is fully tested
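// Illustration (not part of this commit): a worked example of the 4-d case
// above. For sizes {2, 3, 4, 5} with strides {60, 1, 15, 3}, visiting dims in
// the order {1, 3, 2, 0} accumulates expected = 1 -> 3 -> 15 -> 60, matching
// each stride in turn, so the tensor is reported channels-last contiguous.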
@ -242,34 +259,38 @@ bool TensorImpl::compute_channels_last_contiguous_3d() const {
}

bool TensorImpl::compute_strides_like_channels_last_2d() const {
return is_channels_last_strides_2d(TensorImpl::sizes(), TensorImpl::strides());
return is_channels_last_strides_2d(
TensorImpl::sizes(), TensorImpl::strides());
}

bool TensorImpl::compute_strides_like_channels_last_3d() const {
return is_channels_last_strides_3d(TensorImpl::sizes(), TensorImpl::strides());
return is_channels_last_strides_3d(
TensorImpl::sizes(), TensorImpl::strides());
}

bool TensorImpl::compute_non_overlapping_and_dense() const {
if (dim() == 1) {
return sizes_and_strides_.size_at_unchecked(0) < 2 || sizes_and_strides_.stride_at_unchecked(0) == 1;
return sizes_and_strides_.size_at_unchecked(0) < 2 ||
sizes_and_strides_.stride_at_unchecked(0) == 1;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
SmallVector<int64_t,5> perm;
SmallVector<int64_t, 5> perm;
perm.resize(dim());
for (int64_t i = 0; i < dim(); i ++) {
for (int64_t i = 0; i < dim(); i++) {
perm[i] = i;
}
// Sort by strides, leaving 0 and 1 sized dims at the end of the array
std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
if (sizes_and_strides_.size_at_unchecked(a) < 2) {
return false;
} else if (sizes_and_strides_.size_at_unchecked(b) < 2) {
return true;
}
return sizes_and_strides_.stride_at_unchecked(a) < sizes_and_strides_.stride_at_unchecked(b);
if (sizes_and_strides_.size_at_unchecked(a) < 2) {
return false;
} else if (sizes_and_strides_.size_at_unchecked(b) < 2) {
return true;
}
return sizes_and_strides_.stride_at_unchecked(a) <
sizes_and_strides_.stride_at_unchecked(b);
});
auto require_stride = 1;
for (int64_t i = 0; i < dim(); i ++) {
for (int64_t i = 0; i < dim(); i++) {
const auto size_perm_i = sizes_and_strides_.size_at_unchecked(perm[i]);
if (size_perm_i < 2) {
return true;
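// Illustration (not part of this commit): a worked example of the dense
// check above. For sizes {4, 2, 3} with strides {6, 1, 2}, sorting dims by
// stride gives {1, 2, 0}; require_stride then steps 1 -> 2 -> 6, matching
// each stride, so compute_non_overlapping_and_dense() returns true.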
@ -312,16 +333,23 @@ bool TensorImpl::has_storage() const {
#endif

void TensorImpl::throw_storage_access_error() const {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot access storage of ", tensorimpl_type_name());
TORCH_CHECK_NOT_IMPLEMENTED(
false, "Cannot access storage of ", tensorimpl_type_name());
}

bool TensorImpl::is_contiguous_nondefault_policy_impl(at::MemoryFormat memory_format) const {
if (has_contiguity_ == static_cast<uint8_t>(HasContiguityPolicy::ContiguityNotSupported)) {
bool TensorImpl::is_contiguous_nondefault_policy_impl(
at::MemoryFormat memory_format) const {
if (has_contiguity_ ==
static_cast<uint8_t>(HasContiguityPolicy::ContiguityNotSupported)) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "Tensors of type ", tensorimpl_type_name(),
false,
"Tensors of type ",
tensorimpl_type_name(),
" do not have is_contiguous");
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(has_contiguity_ == static_cast<uint8_t>(HasContiguityPolicy::CustomBehavior));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
has_contiguity_ ==
static_cast<uint8_t>(HasContiguityPolicy::CustomBehavior));
return is_contiguous_custom(memory_format);
}
}
@ -343,10 +371,11 @@ at::DataPtr PlacementDeleteContext::makeDataPtr(
size_t size,
at::Device device) {
auto* ptr = data_ptr.get();
return {ptr,
new PlacementDeleteContext(std::move(data_ptr), placement_dtor, size),
&deletePlacementDeleteContext,
device};
return {
ptr,
new PlacementDeleteContext(std::move(data_ptr), placement_dtor, size),
&deletePlacementDeleteContext,
device};
}

// NOLINTNEXTLINE(modernize-use-equals-default)
@ -359,10 +388,14 @@ AutogradMetaInterface::~AutogradMetaInterface() {}
// used in C++ frontend. Forbidding it inside InferenceMode will force users
// to delete these setter code in their code which is not ideal.
void TensorImpl::set_requires_grad(bool requires_grad) {
TORCH_CHECK(!(requires_grad && is_inference_tensor() && !c10::InferenceMode::is_enabled()),
"Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
if (!requires_grad && !autograd_meta_) return;
if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
TORCH_CHECK(
!(requires_grad && is_inference_tensor() &&
!c10::InferenceMode::is_enabled()),
"Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
if (!requires_grad && !autograd_meta_)
return;
if (!autograd_meta_)
autograd_meta_ = impl::GetAutogradMetaFactory()->make();
// NB: In principle, setting requires_grad to false could result in
// the AutogradMeta becoming equal to a default constructed state,
// in which case we could apply the nullptr AutogradMeta optimization

@ -376,11 +409,13 @@ void TensorImpl::set_requires_grad(bool requires_grad) {
}

bool TensorImpl::requires_grad() const {
if (!autograd_meta_) return false;
if (!autograd_meta_)
return false;
return autograd_meta_->requires_grad();
}

void TensorImpl::set_autograd_meta(std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
void TensorImpl::set_autograd_meta(
std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
// NB: autograd_meta may be null! That just means it's the default
// constructor
autograd_meta_ = std::move(autograd_meta);
@ -396,7 +431,9 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<TensorImpl>(
      // No need to populate Storage; copy_tensor_metadata will do it for us.
      key_set_, data_type_, device_opt_);
      key_set_,
      data_type_,
      device_opt_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),

@ -412,7 +449,9 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<TensorImpl>(
      // No need to populate Storage; copy_tensor_metadata will do it for us.
      key_set_, data_type_, device_opt_);
      key_set_,
      data_type_,
      device_opt_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),

@ -435,15 +474,19 @@ void TensorImpl::copy_tensor_metadata_except_version_counter(
  dest_impl->key_set_ = src_impl->key_set_;
  dest_impl->is_contiguous_ = src_impl->is_contiguous_;
  dest_impl->has_contiguity_ = src_impl->has_contiguity_;
  dest_impl->is_channels_last_contiguous_ = src_impl->is_channels_last_contiguous_;
  dest_impl->is_channels_last_3d_contiguous_ = src_impl->is_channels_last_3d_contiguous_;
  dest_impl->is_channels_last_contiguous_ =
      src_impl->is_channels_last_contiguous_;
  dest_impl->is_channels_last_3d_contiguous_ =
      src_impl->is_channels_last_3d_contiguous_;
  dest_impl->is_channels_last_ = src_impl->is_channels_last_;
  dest_impl->is_channels_last_3d_ = src_impl->is_channels_last_3d_;
  dest_impl->is_non_overlapping_and_dense_ = src_impl->is_non_overlapping_and_dense_;
  dest_impl->is_non_overlapping_and_dense_ =
      src_impl->is_non_overlapping_and_dense_;
  dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
  dest_impl->reserved_ = src_impl->reserved_;
  dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
  dest_impl->storage_access_should_throw_ = src_impl->storage_access_should_throw_;
  dest_impl->storage_access_should_throw_ =
      src_impl->storage_access_should_throw_;
  if (src_impl->named_tensor_meta_ != nullptr) {
    dest_impl->named_tensor_meta_ = src_impl->named_tensor_meta_->clone();
  }

@ -454,9 +497,11 @@ void TensorImpl::copy_tensor_metadata(
    TensorImpl* dest_impl,
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) {
  copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change);
  copy_tensor_metadata_except_version_counter(
      src_impl, dest_impl, allow_tensor_metadata_change);
  // TODO: In the ideal end state, it's okay to set disabled version_counter
  // on inference tensor since it's a no-op. This requires refactor on call sites.
  // on inference tensor since it's a no-op. This requires refactor on call
  // sites.
  if (!dest_impl->is_inference_tensor()) {
    dest_impl->set_version_counter(version_counter);
  }

@ -467,7 +512,8 @@ void TensorImpl::copy_tensor_metadata(
    TensorImpl* dest_impl,
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) {
  copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change);
  copy_tensor_metadata_except_version_counter(
      src_impl, dest_impl, allow_tensor_metadata_change);
  if (!dest_impl->is_inference_tensor()) {
    dest_impl->set_version_counter(std::move(version_counter));
  }

@ -478,13 +524,15 @@ namespace impl {
namespace {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
AutogradMetaFactory* meta_factory = nullptr;
}
} // namespace

void SetAutogradMetaFactory(AutogradMetaFactory* factory) {
  meta_factory = factory;
}
AutogradMetaFactory* GetAutogradMetaFactory() {
  TORCH_CHECK(meta_factory, "Support for autograd has not been loaded; have you linked against libtorch.so?")
  TORCH_CHECK(
      meta_factory,
      "Support for autograd has not been loaded; have you linked against libtorch.so?")
  return meta_factory;
}
File diff suppressed because it is too large
@ -17,18 +17,20 @@ namespace c10 {
// internal state and what its getters will return.

std::ostream& operator<<(std::ostream& stream, const TensorOptions& options) {

  auto print = [&](const char *label, auto prop, bool has_prop) {
  auto print = [&](const char* label, auto prop, bool has_prop) {
    stream << label << std::boolalpha << prop << (has_prop ? "" : " (default)");
  };

  print("TensorOptions(dtype=", options.dtype(), options.has_dtype());
  print(", device=", options.device(), options.has_device());
  print(", layout=", options.layout(), options.has_layout());
  print(", requires_grad=", options.requires_grad(), options.has_requires_grad());
  print(", pinned_memory=", options.pinned_memory(), options.has_pinned_memory());
  print(
      ", requires_grad=", options.requires_grad(), options.has_requires_grad());
  print(
      ", pinned_memory=", options.pinned_memory(), options.has_pinned_memory());

  // note: default-supplying memory_format() getter not provided; no canonical default
  // note: default-supplying memory_format() getter not provided; no canonical
  // default
  stream << ", memory_format=";
  if (options.has_memory_format()) {
    stream << *options.memory_format_opt();
@ -1,17 +1,17 @@
#pragma once

#include <c10/core/DefaultDtype.h>
#include <c10/core/Backend.h>
#include <c10/core/DefaultDtype.h>
#include <c10/core/Device.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Device.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/DispatchKeySet.h>

#include <c10/util/Optional.h>
#include <c10/util/C++17.h>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Optional.h>

#include <cstddef>
#include <iosfwd>

@ -19,14 +19,18 @@

namespace c10 {

DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device);
DispatchKey computeDispatchKey(
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device);

inline ScalarType dtype_or_default(c10::optional<ScalarType> dtype) {
  return value_or_else(dtype, [] {return get_default_dtype_as_scalartype();});
  return value_or_else(dtype, [] { return get_default_dtype_as_scalartype(); });
}

inline caffe2::TypeMeta dtype_or_default(c10::optional<caffe2::TypeMeta> dtype) {
  return value_or_else(dtype, [] {return get_default_dtype();});
inline caffe2::TypeMeta dtype_or_default(
    c10::optional<caffe2::TypeMeta> dtype) {
  return value_or_else(dtype, [] { return get_default_dtype(); });
}

inline Layout layout_or_default(c10::optional<Layout> layout) {

@ -34,7 +38,7 @@ inline Layout layout_or_default(c10::optional<Layout> layout) {
}

inline Device device_or_default(c10::optional<Device> device) {
  return value_or_else(device, [] {return Device(kCPU);});
  return value_or_else(device, [] { return Device(kCPU); });
}

inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {

@ -65,7 +69,8 @@ inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
/// at::dtype(at::kInt)
///
/// Additionally, anywhere a TensorOptions is expected, you can directly
/// pass at::kCUDA / at::kInt, and it will implicitly convert to a TensorOptions.
/// pass at::kCUDA / at::kInt, and it will implicitly convert to a
/// TensorOptions.
///
/// Here are some recommended ways to create a 2x2 tensor of zeros
/// with certain properties. These all *implicitly* make use of

@ -108,7 +113,8 @@ inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
/// }
///
/// template <typename... Args,
///          typename = std::enable_if_t<std::is_constructible<Device, Args&&...>::value>>
///          typename = std::enable_if_t<std::is_constructible<Device,
///          Args&&...>::value>>
/// /* implicit */ TensorOptions(Args&&... args)
///    : TensorOptions(Device(std::forward<Args>(args)...)) {}
///

@ -121,20 +127,21 @@ inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
/// To get around this, we templatize the `Device` constructor. Since overload
/// resolution is done before template resolution, our problem is solved.

DispatchKey computeDispatchKey(optional<ScalarType> dtype, optional<Layout> layout, optional<Device> device);

DispatchKey computeDispatchKey(
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device);

struct C10_API TensorOptions {
  TensorOptions()
    : requires_grad_(false)
    , pinned_memory_(false)
    , has_device_(false)
    , has_dtype_(false)
    , has_layout_(false)
    , has_requires_grad_(false)
    , has_pinned_memory_(false)
    , has_memory_format_(false)
    {}
      : requires_grad_(false),
        pinned_memory_(false),
        has_device_(false),
        has_dtype_(false),
        has_layout_(false),
        has_requires_grad_(false),
        has_pinned_memory_(false),
        has_memory_format_(false) {}

  /// Constructs a `TensorOptions` object with the given layout.
  /* implicit */ TensorOptions(Layout layout) : TensorOptions() {
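A minimal sketch of the usage described by the doc comment earlier in this
file's diff (assumes the ATen frontend; `at::zeros` is not part of this diff):

#include <ATen/ATen.h>

void tensor_options_sketch() {
  // Fluent construction; any field left unset falls back to its default.
  auto opts =
      at::TensorOptions().dtype(at::kFloat).layout(at::kStrided).requires_grad(true);
  auto t = at::zeros({2, 2}, opts);
  // Implicit conversion: a bare dtype (or device) works wherever a
  // TensorOptions is expected.
  auto u = at::zeros({2, 2}, at::kInt);
}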
@ -143,8 +150,9 @@ struct C10_API TensorOptions {

  /// Constructs a `TensorOptions` object with the given device.
  /// See NOTE [ TensorOptions Constructors ] on why this is templatized.
  template<typename T,
           typename = std::enable_if_t<std::is_same<std::decay_t<T>, Device>::value>>
  template <
      typename T,
      typename = std::enable_if_t<std::is_same<std::decay_t<T>, Device>::value>>
  /* implicit */ TensorOptions(T&& device) : TensorOptions() {
    this->set_device(std::forward<T>(device));
  }

@ -157,10 +165,12 @@ struct C10_API TensorOptions {
  /// NB: Ideally we only allow implicit constructors here. But there is no easy
  /// way to detect them. So we have this one that allows explicit
  /// constructors too.
  template <typename... Args,
            typename = std::enable_if_t<std::is_constructible<Device, Args&&...>::value>>
  /* implicit */ TensorOptions(Args&&... args)
      : TensorOptions(Device(std::forward<Args>(args)...)) {}
  template <
      typename... Args,
      typename =
          std::enable_if_t<std::is_constructible<Device, Args&&...>::value>>
  /* implicit */ TensorOptions(Args&&... args)
      : TensorOptions(Device(std::forward<Args>(args)...)) {}

  /// Constructs a `TensorOptions` object with the given dtype.
  /* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() {

@ -179,7 +189,8 @@ struct C10_API TensorOptions {

  /// Return a copy of `TensorOptions` with `device` set to the given one, or
  /// cleared if `device` is `nullopt`.
  C10_NODISCARD TensorOptions device(c10::optional<Device> device) const noexcept {
  C10_NODISCARD TensorOptions
  device(c10::optional<Device> device) const noexcept {
    TensorOptions r = *this;
    r.set_device(device);
    return r;

@ -188,9 +199,10 @@ struct C10_API TensorOptions {
  /// Return a copy of `TensorOptions` with `device` set to the given one.
  /// (This overload ensures that variadic template c10::optional constructor
  /// for Device work correctly.)
  template<typename ... Args>
  template <typename... Args>
  C10_NODISCARD TensorOptions device(Args&&... args) const noexcept {
    return device(c10::optional<Device>(c10::in_place, std::forward<Args>(args)...));
    return device(
        c10::optional<Device>(c10::in_place, std::forward<Args>(args)...));
  }

  /// Return a copy of `TensorOptions`, but with device set to CUDA, and the

@ -198,19 +210,22 @@ struct C10_API TensorOptions {
  ///
  /// TODO: This function encourages bad behavior (assuming CUDA is
  /// the only device that matters). Get rid of it / rename it.
  C10_NODISCARD TensorOptions device_index(int16_t device_index) const noexcept {
  C10_NODISCARD TensorOptions
  device_index(int16_t device_index) const noexcept {
    return device(Device::Type::CUDA, device_index);
  }

  /// Return a copy of `TensorOptions` with `dtype` set to the given one.
  C10_NODISCARD TensorOptions dtype(c10::optional<caffe2::TypeMeta> dtype) const noexcept {
  C10_NODISCARD TensorOptions
  dtype(c10::optional<caffe2::TypeMeta> dtype) const noexcept {
    TensorOptions r = *this;
    r.set_dtype(dtype);
    return r;
  }

  // legacy function to support ScalarType
  C10_NODISCARD TensorOptions dtype(c10::optional<ScalarType> dtype) const noexcept {
  C10_NODISCARD TensorOptions
  dtype(c10::optional<ScalarType> dtype) const noexcept {
    TensorOptions r = *this;
    r.set_dtype(dtype);
    return r;

@ -225,28 +240,32 @@ struct C10_API TensorOptions {
  }

  /// Sets the layout of the `TensorOptions`.
  C10_NODISCARD TensorOptions layout(c10::optional<Layout> layout) const noexcept {
  C10_NODISCARD TensorOptions
  layout(c10::optional<Layout> layout) const noexcept {
    TensorOptions r = *this;
    r.set_layout(layout);
    return r;
  }

  /// Sets the `requires_grad` property of the `TensorOptions`.
  C10_NODISCARD TensorOptions requires_grad(c10::optional<bool> requires_grad) const noexcept {
  C10_NODISCARD TensorOptions
  requires_grad(c10::optional<bool> requires_grad) const noexcept {
    TensorOptions r = *this;
    r.set_requires_grad(requires_grad);
    return r;
  }

  /// Sets the `pinned_memory` property on the `TensorOptions`.
  C10_NODISCARD TensorOptions pinned_memory(c10::optional<bool> pinned_memory) const noexcept {
  C10_NODISCARD TensorOptions
  pinned_memory(c10::optional<bool> pinned_memory) const noexcept {
    TensorOptions r = *this;
    r.set_pinned_memory(pinned_memory);
    return r;
  }

  /// Sets the `memory_format` property on `TensorOptions`.
  C10_NODISCARD TensorOptions memory_format(c10::optional<MemoryFormat> memory_format) const noexcept {
  C10_NODISCARD TensorOptions
  memory_format(c10::optional<MemoryFormat> memory_format) const noexcept {
    TensorOptions r = *this;
    r.set_memory_format(memory_format);
    return r;
@ -343,13 +362,15 @@ struct C10_API TensorOptions {

  // For compatibility with legacy tensor.type() comparisons
  bool type_equal(const TensorOptions& other) const {
    return computeDispatchKey() == other.computeDispatchKey() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
    return computeDispatchKey() == other.computeDispatchKey() &&
        typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
  }

  /// Returns the `pinned_memory` property of the `TensorOptions`, or
  /// `c10::nullopt` if `pinned_memory` is not specified.
  c10::optional<bool> pinned_memory_opt() const noexcept {
    return has_pinned_memory_ ? c10::make_optional(pinned_memory_) : c10::nullopt;
    return has_pinned_memory_ ? c10::make_optional(pinned_memory_)
                              : c10::nullopt;
  }

  /// Returns whether the `memory_layout` is specified

@ -363,7 +384,8 @@ struct C10_API TensorOptions {
  /// Returns the `memory_layout` property of `TensorOptions, or
  /// `c10::nullopt` if `memory_format` is not specified.
  c10::optional<MemoryFormat> memory_format_opt() const noexcept {
    return has_memory_format_ ? c10::make_optional(memory_format_) : c10::nullopt;
    return has_memory_format_ ? c10::make_optional(memory_format_)
                              : c10::nullopt;
  }

  // Resolves the ATen backend specified by the current construction axes.
@ -384,18 +406,25 @@ struct C10_API TensorOptions {
  ///
  TensorOptions merge_in(TensorOptions options) const noexcept {
    TensorOptions merged = *this;
    if (options.has_device()) merged.set_device(options.device_opt());
    if (options.has_dtype()) merged.set_dtype(options.dtype_opt());
    if (options.has_layout()) merged.set_layout(options.layout_opt());
    if (options.has_device())
      merged.set_device(options.device_opt());
    if (options.has_dtype())
      merged.set_dtype(options.dtype_opt());
    if (options.has_layout())
      merged.set_layout(options.layout_opt());
    // NB: requires grad is right biased; not a logical AND/OR!
    if (options.has_requires_grad()) merged.set_requires_grad(options.requires_grad_opt());
    if (options.has_pinned_memory()) merged.set_pinned_memory(options.pinned_memory_opt());
    if (options.has_memory_format()) merged.set_memory_format(options.memory_format_opt());
    if (options.has_requires_grad())
      merged.set_requires_grad(options.requires_grad_opt());
    if (options.has_pinned_memory())
      merged.set_pinned_memory(options.pinned_memory_opt());
    if (options.has_memory_format())
      merged.set_memory_format(options.memory_format_opt());
    return merged;
  }

  // TODO remove after TensorOptions rationalization
  TensorOptions merge_memory_format(c10::optional<MemoryFormat> optional_memory_format) const noexcept {
  TensorOptions merge_memory_format(
      c10::optional<MemoryFormat> optional_memory_format) const noexcept {
    TensorOptions merged = *this;
    if (optional_memory_format.has_value()) {
      merged.set_memory_format(*optional_memory_format);

@ -408,11 +437,11 @@ struct C10_API TensorOptions {
  // the most part, this just means that this function never returns an
  // Autograd key)
  DispatchKey computeDispatchKey() const {
    return c10::computeDispatchKey(optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
    return c10::computeDispatchKey(
        optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
  }

 private:

  // These methods are currently private because I'm not sure if it's wise
  // to actually publish them. They are methods because I need them in
  // the constructor and the functional API implementation.

@ -515,13 +544,12 @@ struct C10_API TensorOptions {
  // Bitmask required here to get this to fit inside 32 bits (or even 64 bits,
  // for that matter)

  bool requires_grad_     : 1;
  bool pinned_memory_     : 1;
  bool requires_grad_ : 1;
  bool pinned_memory_ : 1;

  bool has_device_        : 1;
  bool has_dtype_         : 1;
  bool has_layout_        : 1;
  bool has_device_ : 1;
  bool has_dtype_ : 1;
  bool has_layout_ : 1;
  bool has_requires_grad_ : 1;
  bool has_pinned_memory_ : 1;
  bool has_memory_format_ : 1;

@ -530,8 +558,9 @@ struct C10_API TensorOptions {
// We should aspire to fit in one machine-size word; but a size greater than two
// words is too much. (We are doing terribly on 32-bit archs, where we require
// three machine size words to store tensor options. Eek!)
static_assert( sizeof(TensorOptions) <= sizeof(int64_t) * 2,
               "TensorOptions must fit in 128-bits" );
static_assert(
    sizeof(TensorOptions) <= sizeof(int64_t) * 2,
    "TensorOptions must fit in 128-bits");

/// Convenience function that returns a `TensorOptions` object with the `dtype`
/// set to the given one.
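A small sketch of the right-biased merge_in semantics noted earlier in this
file's diff (illustrative; only uses API visible above):

#include <c10/core/TensorOptions.h>

void merge_in_sketch() {
  auto base =
      c10::TensorOptions().dtype(c10::ScalarType::Float).requires_grad(true);
  auto overrides = c10::TensorOptions().requires_grad(false);
  // Only the fields `overrides` actually sets win; the rest come from `base`:
  auto merged = base.merge_in(overrides);
  // merged: dtype == Float (from base), requires_grad == false (right-biased).
}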
@ -591,88 +620,106 @@ inline std::string toString(const TensorOptions options) {

// This is intended to be a centralized location by which we can determine
// what an appropriate DispatchKey for a tensor is.
inline DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device) {
inline DispatchKey computeDispatchKey(
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device) {
  const auto layout_ = layout_or_default(layout);
  const auto device_ = device_or_default(device);
  switch (layout_) {
    case Layout::Strided: {
      const auto dtype_ = dtype_or_default(dtype);
      switch (device_.type()) {
        case DeviceType::CPU: {
          if (isQIntType(dtype_)) {
            return DispatchKey::QuantizedCPU;
          }
          return DispatchKey::CPU;
    case Layout::Strided: {
      const auto dtype_ = dtype_or_default(dtype);
      switch (device_.type()) {
        case DeviceType::CPU: {
          if (isQIntType(dtype_)) {
            return DispatchKey::QuantizedCPU;
          }
        case DeviceType::CUDA: {
          if (isQIntType(dtype_)) {
            return DispatchKey::QuantizedCUDA;
          }
          return DispatchKey::CUDA;
        }
        case DeviceType::XPU: {
          if (isQIntType(dtype_)) {
            return DispatchKey::QuantizedXPU;
          }
          return DispatchKey::XPU;
        }
        case DeviceType::MKLDNN:
        case DeviceType::OPENGL:
        case DeviceType::OPENCL:
        case DeviceType::IDEEP:
          TORCH_INTERNAL_ASSERT(0, "This is a grandfathered Caffe2 device type ", device_.type(), ", it shouldn't ever convert to a DispatchKey. File a bug describing what you were doing if you think this is in error.");
        case DeviceType::HIP:
          return DispatchKey::HIP;
        case DeviceType::FPGA:
          return DispatchKey::FPGA;
        case DeviceType::MSNPU:
          return DispatchKey::MSNPU;
        case DeviceType::XLA:
          return DispatchKey::XLA;
        case DeviceType::MLC:
          return DispatchKey::MLC;
        case DeviceType::Vulkan:
          return DispatchKey::Vulkan;
        case DeviceType::Metal:
          return DispatchKey::Metal;
        case DeviceType::Meta:
          return DispatchKey::Meta;
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for dense layout: ", device_.type());
          return DispatchKey::CPU;
        }
        case DeviceType::CUDA: {
          if (isQIntType(dtype_)) {
            return DispatchKey::QuantizedCUDA;
          }
          return DispatchKey::CUDA;
        }
        case DeviceType::XPU: {
          if (isQIntType(dtype_)) {
            return DispatchKey::QuantizedXPU;
          }
          return DispatchKey::XPU;
        }
        case DeviceType::MKLDNN:
        case DeviceType::OPENGL:
        case DeviceType::OPENCL:
        case DeviceType::IDEEP:
          TORCH_INTERNAL_ASSERT(
              0,
              "This is a grandfathered Caffe2 device type ",
              device_.type(),
              ", it shouldn't ever convert to a DispatchKey. File a bug describing what you were doing if you think this is in error.");
        case DeviceType::HIP:
          return DispatchKey::HIP;
        case DeviceType::FPGA:
          return DispatchKey::FPGA;
        case DeviceType::MSNPU:
          return DispatchKey::MSNPU;
        case DeviceType::XLA:
          return DispatchKey::XLA;
        case DeviceType::MLC:
          return DispatchKey::MLC;
        case DeviceType::Vulkan:
          return DispatchKey::Vulkan;
        case DeviceType::Metal:
          return DispatchKey::Metal;
        case DeviceType::Meta:
          return DispatchKey::Meta;
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for dense layout: ",
              device_.type());
      }
    case Layout::Sparse:
      switch (device_.type()) {
        case DeviceType::CPU:
          return DispatchKey::SparseCPU;
        case DeviceType::CUDA:
          return DispatchKey::SparseCUDA;
        case DeviceType::HIP:
          return DispatchKey::SparseHIP;
        case DeviceType::XPU:
          return DispatchKey::SparseXPU;
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for sparse layout: ", device_.type());
      }
    case Layout::Mkldnn:
      switch (device_.type()) {
        case DeviceType::CPU:
          return DispatchKey::MkldnnCPU;
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for mkldnn layout: ", device_.type());
      }
    case Layout::SparseCsr:
      switch(device_.type()) {
        case DeviceType::CPU:
          return DispatchKey::SparseCsrCPU;
        case DeviceType::CUDA:
          return DispatchKey::SparseCsrCUDA;
        default:
          AT_ERROR("Unsupported device type for sparse CSR layout: ", device_.type());
      }
    default:
      TORCH_CHECK(false, "Unsupported layout: ", layout_);
  }
    case Layout::Sparse:
      switch (device_.type()) {
        case DeviceType::CPU:
          return DispatchKey::SparseCPU;
        case DeviceType::CUDA:
          return DispatchKey::SparseCUDA;
        case DeviceType::HIP:
          return DispatchKey::SparseHIP;
        case DeviceType::XPU:
          return DispatchKey::SparseXPU;
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for sparse layout: ",
              device_.type());
      }
    case Layout::Mkldnn:
      switch (device_.type()) {
        case DeviceType::CPU:
          return DispatchKey::MkldnnCPU;
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for mkldnn layout: ",
              device_.type());
      }
    case Layout::SparseCsr:
      switch (device_.type()) {
        case DeviceType::CPU:
          return DispatchKey::SparseCsrCPU;
        case DeviceType::CUDA:
          return DispatchKey::SparseCsrCUDA;
        default:
          AT_ERROR(
              "Unsupported device type for sparse CSR layout: ",
              device_.type());
      }
    default:
      TORCH_CHECK(false, "Unsupported layout: ", layout_);
  }
}

inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
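A quick sketch of the mapping implemented above (the values follow directly
from the switch):

#include <c10/core/TensorOptions.h>

void dispatch_key_sketch() {
  using namespace c10;
  auto k = computeDispatchKey(
      ScalarType::Float, Layout::Strided, Device(DeviceType::CPU));
  // k == DispatchKey::CPU; with ScalarType::QInt8 on the same device the
  // result would be DispatchKey::QuantizedCPU instead.
}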
@ -692,7 +739,7 @@ inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
}

inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
  switch(dispatch_key) {
  switch (dispatch_key) {
    // stuff that's real
    case DispatchKey::CPU:
    case DispatchKey::SparseCPU:

@ -730,14 +777,18 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
    case DispatchKey::MSNPU:
      return DeviceType::MSNPU;
    default:
      TORCH_CHECK(false, "DispatchKey ", dispatch_key, " doesn't correspond to a device");
      TORCH_CHECK(
          false,
          "DispatchKey ",
          dispatch_key,
          " doesn't correspond to a device");
  }
}

inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) {
  return TensorOptions()
    .layout(dispatchKeyToLayout(dispatch_key))
    .device(dispatchKeyToDeviceType(dispatch_key));
      .layout(dispatchKeyToLayout(dispatch_key))
      .device(dispatchKeyToDeviceType(dispatch_key));
}

} // namespace c10
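The inverse direction composes with the mapping above; a minimal sketch:

#include <c10/core/TensorOptions.h>

void roundtrip_sketch() {
  auto opts = c10::dispatchKeyToTensorOptions(c10::DispatchKey::SparseCUDA);
  // opts.layout() == c10::Layout::Sparse and the device type is CUDA, so
  // computeDispatchKey on these options recovers SparseCUDA.
}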
@ -5,7 +5,7 @@ namespace c10 {

// should this use the globalContext? Can it get a context passed in somehow?
UndefinedTensorImpl::UndefinedTensorImpl()
: TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) {
    : TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) {
  set_storage_access_should_throw();
}

@ -19,7 +19,8 @@ int64_t UndefinedTensorImpl::stride(int64_t d) const {

#ifdef DEBUG
bool UndefinedTensorImpl::has_storage() const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "UndefinedTensorImpl assumes that storage_ is never set");
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      !storage_, "UndefinedTensorImpl assumes that storage_ is never set");
  return false;
}
#endif

@ -39,4 +40,4 @@ const char* UndefinedTensorImpl::tensorimpl_type_name() const {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
UndefinedTensorImpl UndefinedTensorImpl::_singleton;

}
} // namespace c10

@ -7,13 +7,14 @@ namespace c10 {
struct C10_API UndefinedTensorImpl final : public TensorImpl {
 public:
  // Without this, we get:
  // error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in device code
  // error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in
  // device code
  // (ostensibly because the constexpr tricks MSVC into trying to compile this
  // function for device as well).
#ifdef _WIN32
  static inline TensorImpl * singleton() {
  static inline TensorImpl* singleton() {
#else
  static constexpr inline TensorImpl * singleton() {
  static constexpr inline TensorImpl* singleton() {
#endif
    return &_singleton;
  }

@ -24,7 +25,8 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl {
  bool has_storage() const override;
#endif
  void set_storage_offset(int64_t offset) override;
private:

 private:
  UndefinedTensorImpl();
  static UndefinedTensorImpl _singleton;
  const char* tensorimpl_type_name() const override;
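Usage note, as a sketch (`at::Tensor` is outside this diff): the singleton
backs every default-constructed tensor, which is what `defined()` reports on.

#include <ATen/ATen.h>

bool undefined_sketch() {
  at::Tensor t; // wraps UndefinedTensorImpl::singleton()
  return t.defined(); // false until a real TensorImpl is assigned
}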
@ -4,10 +4,17 @@

namespace c10 {

static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar=true) {
static inline int64_t maybe_wrap_dim(
    int64_t dim,
    int64_t dim_post_expr,
    bool wrap_scalar = true) {
  if (dim_post_expr <= 0) {
    if (!wrap_scalar) {
      TORCH_CHECK_INDEX(false, "dimension specified as ", dim, " but tensor has no dimensions");
      TORCH_CHECK_INDEX(
          false,
          "dimension specified as ",
          dim,
          " but tensor has no dimensions");
    }
    dim_post_expr = 1; // this will make range [-1, 0]
  }

@ -15,12 +22,19 @@ static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wr
  int64_t min = -dim_post_expr;
  int64_t max = dim_post_expr - 1;
  if (dim < min || dim > max) {
    TORCH_CHECK_INDEX(false,
      "Dimension out of range (expected to be in range of [",
      min, ", ", max, "], but got ", dim, ")");
    TORCH_CHECK_INDEX(
        false,
        "Dimension out of range (expected to be in range of [",
        min,
        ", ",
        max,
        "], but got ",
        dim,
        ")");
  }
  if (dim < 0) dim += dim_post_expr;
  if (dim < 0)
    dim += dim_post_expr;
  return dim;
}

}
} // namespace c10
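A worked example of the wrapping rule above:

#include <c10/core/WrapDimMinimal.h>

void wrap_dim_sketch() {
  // For dim_post_expr == 3 the valid range is [-3, 2]; negative dims count
  // from the end.
  int64_t a = c10::maybe_wrap_dim(-1, 3); // -> 2
  int64_t b = c10::maybe_wrap_dim(1, 3); // -> 1, already in range
  // c10::maybe_wrap_dim(3, 3) would raise an IndexError: out of range.
}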
@ -5,11 +5,15 @@ namespace impl {

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
std::atomic<const DeviceGuardImplInterface*>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
device_guard_impl_registry[static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
    device_guard_impl_registry[static_cast<size_t>(
        DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];

DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(DeviceType type, const DeviceGuardImplInterface* impl) {
DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(
    DeviceType type,
    const DeviceGuardImplInterface* impl) {
  device_guard_impl_registry[static_cast<size_t>(type)].store(impl);
}

}} // namespace c10::impl
} // namespace impl
} // namespace c10
@ -30,16 +30,16 @@ class DataPtr;
 * should map one-to-one with actual event flags for those backends.
 */
enum class EventFlag {
  PYTORCH_DEFAULT,
  BACKEND_DEFAULT,
  // CUDA flags
  CUDA_EVENT_DEFAULT,
  CUDA_EVENT_DISABLE_TIMING, // PyTorch-default for CUDA
  // HIP flags
  HIP_EVENT_DEFAULT,
  HIP_EVENT_DISABLE_TIMING, // PyTorch-default for HIP
  // FOR TESTING ONLY
  INVALID
  PYTORCH_DEFAULT,
  BACKEND_DEFAULT,
  // CUDA flags
  CUDA_EVENT_DEFAULT,
  CUDA_EVENT_DISABLE_TIMING, // PyTorch-default for CUDA
  // HIP flags
  HIP_EVENT_DEFAULT,
  HIP_EVENT_DISABLE_TIMING, // PyTorch-default for HIP
  // FOR TESTING ONLY
  INVALID
};

namespace impl {

@ -126,47 +126,44 @@ struct C10_API DeviceGuardImplInterface {
   */
  virtual Stream exchangeStream(Stream) const noexcept = 0;

  /**
   * Destroys the given event.
   */
  virtual void destroyEvent (
    void* event,
    const DeviceIndex device_index) const noexcept { }
  /**
   * Destroys the given event.
   */
  virtual void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept {}

  /**
   * Increments the event's version and enqueues a job with this version
   * in the stream's work queue. When the stream process that job
   * it notifies all streams waiting on / blocked by that version of the
   * event to continue and marks that version as recorded.
   * */
  /**
   * Increments the event's version and enqueues a job with this version
   * in the stream's work queue. When the stream process that job
   * it notifies all streams waiting on / blocked by that version of the
   * event to continue and marks that version as recorded.
   * */
  virtual void record(
    void** event,
    const Stream& stream,
    const DeviceIndex device_index,
    const c10::EventFlag flag) const {
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const c10::EventFlag flag) const {
    TORCH_CHECK(false, "Backend doesn't support events.");
  }

  /**
   * Does nothing if the event has not been scheduled to be recorded.
   * If the event was previously enqueued to be recorded, a command
   * to wait for the version of the event that exists at the time of this call
   * is inserted in the stream's work queue.
   * When the stream reaches this command it will stop processing
   * additional commands until that version of the event is marked as recorded.
   */
  virtual void block(
    void* event,
    const Stream& stream) const {
  /**
   * Does nothing if the event has not been scheduled to be recorded.
   * If the event was previously enqueued to be recorded, a command
   * to wait for the version of the event that exists at the time of this call
   * is inserted in the stream's work queue.
   * When the stream reaches this command it will stop processing
   * additional commands until that version of the event is marked as recorded.
   */
  virtual void block(void* event, const Stream& stream) const {
    TORCH_CHECK(false, "Backend doesn't support events.");
  }

  /**
   * Returns true if (and only if)
   * (1) the event has never been scheduled to be recorded
   * (2) the current version is marked as recorded.
   * Returns false otherwise.
   */
  /**
   * Returns true if (and only if)
   * (1) the event has never been scheduled to be recorded
   * (2) the current version is marked as recorded.
   * Returns false otherwise.
   */
  virtual bool queryEvent(void* event) const {
    TORCH_CHECK(false, "Backend doesn't support events.");
  }
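The record/block hooks above are what cross-stream synchronization is built
on. A sketch of that pattern via the public c10::Event wrapper (illustrative
only; it needs a backend, e.g. CUDA, whose guard impl actually implements
events, since the fallbacks above TORCH_CHECK-fail):

#include <c10/core/Event.h>
#include <c10/core/Stream.h>

void sync_sketch(const c10::Stream& producer, const c10::Stream& consumer) {
  c10::Event ev(producer.device_type());
  ev.record(producer); // enqueue a record of the event's next version
  ev.block(consumer); // consumer stalls until that version is recorded
}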
@ -178,13 +175,13 @@ struct C10_API DeviceGuardImplInterface {
   */
  virtual DeviceIndex deviceCount() const noexcept = 0;

  /**
   * Ensure the caching allocator (if any) is aware that the given DataPtr is
   * being used on the given stream, and that it should thus avoid recycling the
   * DataPtr until all work on that stream is done.
   */
  virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const { }
  virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const {
  }

  /**
   * Intended use of this class is to leak the DeviceGuardImpl at program end.

@ -228,23 +225,21 @@ struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
  }

  // Event-related functions
  void record(void** event,
    const Stream& stream,
    const DeviceIndex device_index,
    const EventFlag flag) const override {
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(false, D, " backend doesn't support events.");
  }
  void block(
    void* event,
    const Stream& stream) const override {
  void block(void* event, const Stream& stream) const override {
    TORCH_CHECK(false, D, " backend doesn't support events.")
  }
  bool queryEvent(void* event) const override {
    TORCH_CHECK(false, D, " backend doesn't support events.")
  }
  void destroyEvent(
    void* event,
    const DeviceIndex device_index) const noexcept override { }
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {}
};

// The registry is NON-owning. Each stored pointer is std::atomic so

@ -262,7 +257,8 @@ struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
// putting them in the registry. This is done by deleting the destructor
// on DeviceGuardImplInterface.
extern C10_API std::atomic<const DeviceGuardImplInterface*>
device_guard_impl_registry[static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
    device_guard_impl_registry[static_cast<size_t>(
        DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];

// I can't conveniently use c10/util/Registry.h for the following reason:
// c10/util/Registry.h gives me a slow way of Create'ing a object of some

@ -273,12 +269,13 @@ device_guard_impl_registry[static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVI
// into device_guard_impl_registry.

class C10_API DeviceGuardImplRegistrar {
public:
 public:
  DeviceGuardImplRegistrar(DeviceType, const DeviceGuardImplInterface*);
};

#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl)              \
  static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE(g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl());
#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl)               \
  static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE(  \
      g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl());

inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) {
  // Two adjacent int16_t fields DeviceType and DeviceIndex has field access
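How a backend hooks in, as a sketch: it implements DeviceGuardImplInterface
and registers a static instance once in a .cpp file, e.g. the CUDA backend
does roughly C10_REGISTER_GUARD_IMPL(CUDA, CUDAGuardImpl). Afterwards lookups
resolve through the registry (guarded here with hasDeviceGuardImpl, since
getDeviceGuardImpl asserts on unregistered types):

#include <c10/core/impl/DeviceGuardImplInterface.h>

void lookup_sketch() {
  if (c10::impl::hasDeviceGuardImpl(c10::DeviceType::CPU)) {
    const c10::impl::DeviceGuardImplInterface* impl =
        c10::impl::getDeviceGuardImpl(c10::DeviceType::CPU);
    (void)impl; // dispatch device-generic code through this interface
  }
}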
@ -301,4 +298,5 @@ inline bool hasDeviceGuardImpl(DeviceType type) {
  return device_guard_impl_registry[static_cast<size_t>(type)].load();
}

}} // namespace c10::impl
} // namespace impl
} // namespace c10
@ -60,17 +60,16 @@ struct FakeGuardImpl final : public DeviceGuardImplInterface {

  // Event-related functions
  void record(
    void** event,
    const Stream& stream,
    const DeviceIndex device_index,
    const EventFlag flag) const override { }
  void block(
    void* event,
    const Stream& stream) const override { }
  bool queryEvent(void* event) const override { return true; }
  void destroyEvent(
    void* event,
    const DeviceIndex device_index) const noexcept override { }
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {}
  void block(void* event, const Stream& stream) const override {}
  bool queryEvent(void* event) const override {
    return true;
  }
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {}

  // Convenience methods for testing
  static DeviceIndex getDeviceIndex() {

@ -87,9 +86,11 @@ struct FakeGuardImpl final : public DeviceGuardImplInterface {
  static void resetStreams() {
    current_streams_.fill(0);
  }
private:

 private:
  thread_local static DeviceIndex current_device_;
  thread_local static std::array<StreamId, kFakeGuardImplMaxDevices> current_streams_;
  thread_local static std::array<StreamId, kFakeGuardImplMaxDevices>
      current_streams_;
};

template <DeviceType T>

@ -99,7 +100,8 @@ template <DeviceType T>
constexpr DeviceType FakeGuardImpl<T>::static_type;

template <DeviceType T>
thread_local std::array<StreamId, kFakeGuardImplMaxDevices> FakeGuardImpl<T>::current_streams_ = {0,0,0,0,0,0,0,0};
thread_local std::array<StreamId, kFakeGuardImplMaxDevices>
    FakeGuardImpl<T>::current_streams_ = {0, 0, 0, 0, 0, 0, 0, 0};

}} // namespace c10::impl
} // namespace impl
} // namespace c10
@ -1,18 +1,17 @@
#pragma once

// This file provides implementations of InlineDeviceGuard and InlineOptionalDeviceGuard.
// This file provides implementations of InlineDeviceGuard and
// InlineOptionalDeviceGuard.

#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/Optional.h>
#include <c10/util/C++17.h>
#include <c10/util/Optional.h>

namespace c10 {
namespace impl {

/**
 * A DeviceGuard is an RAII class that sets a device to some value
 * on construction, and resets the device to its original value on

@ -54,7 +53,7 @@ namespace impl {
 */
template <typename T>
class InlineDeviceGuard {
public:
 public:
  // Note [Omitted default constructor from RAII]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // In principle, we could add a default constructor to

@ -69,25 +68,36 @@ public:

  /// Set the current device to the passed Device.
  explicit InlineDeviceGuard(Device device)
    : impl_(device.type())
    , original_device_(device.index() == -1 ? impl_.getDevice() : impl_.exchangeDevice(device))
    , current_device_(device.index() == -1 ? original_device_ : device)
    {}
      : impl_(device.type()),
        original_device_(
            device.index() == -1 ? impl_.getDevice()
                                 : impl_.exchangeDevice(device)),
        current_device_(device.index() == -1 ? original_device_ : device) {}

  /// Set the current device index to the passed DeviceIndex. (The
  /// device type is inferred from the template parameter T).
  template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
  template <
      typename U = T,
      typename = typename std::enable_if<
          !std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineDeviceGuard(DeviceIndex device_index)
    : InlineDeviceGuard(Device(U::static_type, device_index)) {}
      : InlineDeviceGuard(Device(U::static_type, device_index)) {}

  /// Construct an InlineDeviceGuard using VirtualGuardImpl with an explicit
  /// DeviceGuardImplInterface pointer.
  template <typename U=T, typename=typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineDeviceGuard(Device device, const DeviceGuardImplInterface* impl)
    : impl_(VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type())))
    , original_device_(device.index() == -1 ? impl_.getDevice() : impl_.exchangeDevice(device))
    , current_device_(device.index() == -1 ? original_device_ : device)
    {}
  template <
      typename U = T,
      typename = typename std::enable_if<
          std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineDeviceGuard(
      Device device,
      const DeviceGuardImplInterface* impl)
      : impl_(
            VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type()))),
        original_device_(
            device.index() == -1 ? impl_.getDevice()
                                 : impl_.exchangeDevice(device)),
        current_device_(device.index() == -1 ? original_device_ : device) {}

  /// Copy is disallowed
  InlineDeviceGuard(const InlineDeviceGuard<T>&) = delete;
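A sketch of the RAII behavior documented above, via the public c10::DeviceGuard
(which wraps InlineDeviceGuard<impl::VirtualGuardImpl>):

#include <c10/core/DeviceGuard.h>

void raii_sketch() {
  // suppose the current device is cuda:0
  {
    c10::DeviceGuard guard(c10::Device(c10::DeviceType::CUDA, 1));
    // current device is now cuda:1 for the lifetime of `guard`
  }
  // guard destroyed: current device restored to cuda:0
}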
@ -103,12 +113,17 @@ public:
  }

  /// Sets the device to the given one.
  template <typename U=T, typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value, int>::type = 0>
  template <
      typename U = T,
      typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value, int>::
          type = 0>
  void set_device(at::Device device) {
    AT_ASSERT((U::static_type == DeviceType::HIP && device.is_cuda()) ||
              device.type() == U::static_type);
    AT_ASSERT(
        (U::static_type == DeviceType::HIP && device.is_cuda()) ||
        device.type() == U::static_type);
    auto index = device.index();
    if (index == -1) return;
    if (index == -1)
      return;
    impl_.setDevice(device);
    current_device_ = device;
  }

@ -116,8 +131,8 @@ public:
  /// Resets the currently set device to its original device, and then sets the
  /// current device to the passed device. This is effectively equivalent to
  /// set_device when a guard supports only a single device type.
  template <typename U=T>
  typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value >::type
  template <typename U = T>
  typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type
  reset_device(at::Device device) {
    set_device(device);
  }

@ -139,11 +154,14 @@ public:
  /// that it is unnecessary.
  ///
  /// Optional argument is for testing only.
  template <typename U=T>
  typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value >::type
  reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl = nullptr) {
  template <typename U = T>
  typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type
  reset_device(
      at::Device device,
      const impl::DeviceGuardImplInterface* impl = nullptr) {
    auto index = device.index();
    if (index == -1) return;
    if (index == -1)
      return;
    if (device.type() == original_device_.type()) {
      AT_ASSERT(impl == nullptr || impl->type() == device.type());
      impl_.setDevice(device);

@ -175,10 +193,10 @@ public:
    return current_device_;
  }

protected:
 protected:
  T impl_;

private:
 private:
  Device original_device_;
  Device current_device_;
};

@ -193,29 +211,33 @@ private:
 */
template <typename T>
class InlineOptionalDeviceGuard {
public:
 public:
  // Note [Explicit initialization of optional fields]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Explicit initialization of optional fields
  // required to workaround an nvcc bug; see https://github.com/pytorch/pytorch/issues/12117
  // required to workaround an nvcc bug; see
  // https://github.com/pytorch/pytorch/issues/12117

  /// Creates an uninitialized OptionalDeviceGuard.
  explicit InlineOptionalDeviceGuard()
    : guard_() // See Note [Explicit initialization of optional fields]
    {}
      : guard_() // See Note [Explicit initialization of optional fields]
  {}

  /// Set the current device to the passed Device, if it is not nullopt.
  explicit InlineOptionalDeviceGuard(optional<Device> device_opt)
    : guard_() { // See Note [Explicit initialization of optional fields]
      : guard_() { // See Note [Explicit initialization of optional fields]
    if (device_opt.has_value()) {
      guard_.emplace(device_opt.value());
    }
  }

  /// Set the current device to the passed DeviceIndex, if it is not nullopt.
  template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
  template <
      typename U = T,
      typename = typename std::enable_if<
          !std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineOptionalDeviceGuard(optional<DeviceIndex> device_index_opt)
    : guard_() { // See Note [Explicit initialization of optional fields]
      : guard_() { // See Note [Explicit initialization of optional fields]
    if (device_index_opt.has_value()) {
      guard_.emplace(device_index_opt.value());
    }

@ -225,7 +247,7 @@ public:
  /// and result in initialized OptionalDeviceGuard.
  template <typename... Args>
  explicit InlineOptionalDeviceGuard(Args&&... args)
    : guard_(in_place, std::forward<Args>(args)...) {}
      : guard_(in_place, std::forward<Args>(args)...) {}

  // TODO: Consider readding Tensor and TensorList constructors here, when
  // Tensor moves to c10. (These are only valid on OptionalDeviceGuard,

@ -313,11 +335,15 @@ public:
  //
  // We could solve this with an extra thread-local variable. But no one is
  // actually using move-assignment. So just get rid of it.
  InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) = delete;
  InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) =
      delete;

  /// Sets the device to the given one. Initializes OptionalDeviceGuard if it
  /// is not already initialized.
  template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
  template <
      typename U = T,
      typename = typename std::enable_if<
          !std::is_same<U, VirtualGuardImpl>::value>::type>
  void set_device(at::Device device) {
    if (!guard_.has_value()) {
      guard_.emplace(device);

@ -333,8 +359,13 @@ public:
  /// See notes on why this is called reset_device on InlineDeviceGuard.
  ///
  /// Optional argument is for testing only.
  template <typename U=T, typename=typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type>
  void reset_device(at::Device device, const DeviceGuardImplInterface* impl = nullptr) {
  template <
      typename U = T,
      typename = typename std::enable_if<
          std::is_same<U, VirtualGuardImpl>::value>::type>
  void reset_device(
      at::Device device,
      const DeviceGuardImplInterface* impl = nullptr) {
    if (!guard_.has_value()) {
      guard_.emplace(device, impl);
    } else {

@ -346,7 +377,10 @@ public:
  /// current device to the passed device. Initializes the guard if it is
  /// not already initialized. This is effectively equivalent to set_device
  /// when a guard supports only a single device type.
  template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
  template <
      typename U = T,
      typename = typename std::enable_if<
          !std::is_same<U, VirtualGuardImpl>::value>::type>
  void reset_device(at::Device device) {
    if (!guard_.has_value()) {
      guard_.emplace(device);

@ -357,7 +391,10 @@ public:

  /// Sets the device index to the given one. The device type is statically
  /// known.
  template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value >::type>
  template <
      typename U = T,
      typename = typename std::enable_if<
          !std::is_same<U, VirtualGuardImpl>::value>::type>
  void set_index(DeviceIndex index) {
    if (!guard_.has_value()) {
      guard_.emplace(index);

@ -366,17 +403,19 @@ public:
    }
  }

  /// Returns the device that was set immediately prior to initialization of the,
  /// guard, or nullopt if the guard is uninitialized.
  /// Returns the device that was set immediately prior to initialization of
  /// the, guard, or nullopt if the guard is uninitialized.
  optional<Device> original_device() const {
    return guard_.has_value() ? make_optional(guard_->original_device()) : nullopt;
    return guard_.has_value() ? make_optional(guard_->original_device())
                              : nullopt;
  }

  /// Returns the most recent device that was set using this device guard,
  /// either from construction, or via set_device, if the guard is initialized,
  /// or nullopt if the guard is uninitialized.
  optional<Device> current_device() const {
    return guard_.has_value() ? make_optional(guard_->current_device()) : nullopt;
    return guard_.has_value() ? make_optional(guard_->current_device())
                              : nullopt;
  }

  /// Restore the original device, resetting this guard to uninitialized state.

@ -384,8 +423,9 @@ public:
    guard_.reset();
  }

private:
 private:
  optional<InlineDeviceGuard<T>> guard_;
};

}} // namespace c10::impl
} // namespace impl
} // namespace c10
@ -2,22 +2,19 @@

#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/Exception.h>

namespace c10 {
namespace impl {

template <typename T>
struct InlineEvent final {
  InlineEvent() = delete;
  InlineEvent(
      const DeviceType _device_type,
      const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
      : backend_{_device_type},
        device_type_{_device_type},
        flag_{_flag} { }
      const DeviceType _device_type,
      const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
      : backend_{_device_type}, device_type_{_device_type}, flag_{_flag} {}

  // Copy constructor and copy assignment operator (deleted)
  InlineEvent(const InlineEvent&) = delete;

@ -25,7 +22,7 @@ struct InlineEvent final {

  // Move constructor and move assignment operator
  InlineEvent(InlineEvent&& other)
    : InlineEvent(other.device_type_, other.flag_) {
      : InlineEvent(other.device_type_, other.flag_) {
    swap(std::move(other));
  }
  InlineEvent& operator=(InlineEvent&& other) {

@ -43,27 +40,36 @@ struct InlineEvent final {
  }

  ~InlineEvent() noexcept {
    if (event_) backend_.destroyEvent(event_, device_index_);
    if (event_)
      backend_.destroyEvent(event_, device_index_);
  }

  DeviceType device_type() const noexcept { return device_type_; }
  DeviceIndex device_index() const noexcept { return device_index_; }
  EventFlag flag() const noexcept { return flag_; }
  bool was_marked_for_recording() const noexcept { return was_marked_for_recording_; }

  DeviceType device_type() const noexcept {
    return device_type_;
  }
  DeviceIndex device_index() const noexcept {
    return device_index_;
  }
  EventFlag flag() const noexcept {
    return flag_;
  }
  bool was_marked_for_recording() const noexcept {
    return was_marked_for_recording_;
  }

  void recordOnce(const Stream& stream) {
    if (!was_marked_for_recording_) record(stream);
    if (!was_marked_for_recording_)
      record(stream);
  }

  void record(const Stream& stream) {
    TORCH_CHECK(
      stream.device_type() == device_type_,
      "Event device type ",
      DeviceTypeName(device_type_),
      " does not match recording stream's device type ",
      DeviceTypeName(stream.device_type()),
      ".");
        stream.device_type() == device_type_,
        "Event device type ",
        DeviceTypeName(device_type_),
        " does not match recording stream's device type ",
        DeviceTypeName(stream.device_type()),
        ".");

    backend_.record(&event_, stream, device_index_, flag_);
    was_marked_for_recording_ = true;

@ -71,25 +77,27 @@ struct InlineEvent final {
  }

  void block(const Stream& stream) const {
    if (!was_marked_for_recording_) return;
    if (!was_marked_for_recording_)
      return;

    TORCH_CHECK(
      stream.device_type() == device_type_,
      "Event device type ",
      DeviceTypeName(device_type_),
      " does not match blocking stream's device type ",
      DeviceTypeName(stream.device_type()),
      ".");
        stream.device_type() == device_type_,
        "Event device type ",
        DeviceTypeName(device_type_),
        " does not match blocking stream's device type ",
        DeviceTypeName(stream.device_type()),
        ".");

    backend_.block(event_, stream);
  }

  bool query() const {
    if (!was_marked_for_recording_) return true;
    if (!was_marked_for_recording_)
      return true;
    return backend_.queryEvent(event_);
  }

private:
 private:
  void* event_ = nullptr;
  T backend_;
  DeviceType device_type_;

@ -98,5 +106,5 @@ private:
  bool was_marked_for_recording_ = false;
};

} // impl
} // c10
} // namespace impl
} // namespace c10
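
A sketch of the record/block/query protocol formatted above, using the VirtualGuardImpl backend from later in this diff; the streams come from the CUDA pool API also touched below:

  c10::impl::InlineEvent<c10::impl::VirtualGuardImpl> ev(c10::DeviceType::CUDA);
  c10::Stream producer = c10::cuda::getStreamFromPool();
  c10::Stream consumer = c10::cuda::getStreamFromPool();
  // ... enqueue work on producer ...
  ev.record(producer);    // mark the event after producer's queued work
  ev.block(consumer);     // consumer won't run past the event early
  bool done = ev.query(); // non-blocking check (true if never recorded)
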
@ -16,27 +16,34 @@ namespace impl {
 */
template <typename T>
class InlineStreamGuard : private InlineDeviceGuard<T> {
public:
 public:
  /// No default constructor, see Note [Omitted default constructor from RAII]
  explicit InlineStreamGuard() = delete;

  /// Set the current device to the device associated with the passed stream,
  /// and set the current stream on that device to the passed stream.
  explicit InlineStreamGuard(Stream stream)
    : InlineDeviceGuard<T>(stream.device())
    , original_stream_of_original_device_(this->impl_.getStream(original_device()))
    , original_stream_of_current_device_(this->impl_.exchangeStream(stream))
    , current_stream_(stream)
    {}
      : InlineDeviceGuard<T>(stream.device()),
        original_stream_of_original_device_(
            this->impl_.getStream(original_device())),
        original_stream_of_current_device_(this->impl_.exchangeStream(stream)),
        current_stream_(stream) {}

  /// This constructor exists purely for testing
  template <typename U=T, typename=typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineStreamGuard(Stream stream, const DeviceGuardImplInterface* impl)
    : InlineDeviceGuard<T>(stream.device(), impl ? impl : getDeviceGuardImpl(stream.device_type()))
    , original_stream_of_original_device_(this->impl_.getStream(original_device()))
    , original_stream_of_current_device_(this->impl_.exchangeStream(stream))
    , current_stream_(stream)
    {}
  template <
      typename U = T,
      typename = typename std::enable_if<
          std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineStreamGuard(
      Stream stream,
      const DeviceGuardImplInterface* impl)
      : InlineDeviceGuard<T>(
            stream.device(),
            impl ? impl : getDeviceGuardImpl(stream.device_type())),
        original_stream_of_original_device_(
            this->impl_.getStream(original_device())),
        original_stream_of_current_device_(this->impl_.exchangeStream(stream)),
        current_stream_(stream) {}

  /// Copy is disallowed
  InlineStreamGuard(const InlineStreamGuard<T>&) = delete;

@ -110,8 +117,9 @@ public:
    return InlineDeviceGuard<T>::original_device();
  }

private:
  Stream original_stream_of_original_device_; // what the user probably cares about
 private:
  Stream
      original_stream_of_original_device_; // what the user probably cares about
  Stream original_stream_of_current_device_; // what we need to restore
  Stream current_stream_;
};

@ -123,17 +131,16 @@ private:
 */
template <typename T>
class InlineOptionalStreamGuard {
public:
 public:
  /// Creates an uninitialized stream guard.
  explicit InlineOptionalStreamGuard()
    : guard_() // See Note [Explicit initialization of optional fields]
    {}
      : guard_() // See Note [Explicit initialization of optional fields]
  {}

  /// Set the current device to the device associated with the passed stream,
  /// and set the current stream on that device to the passed stream,
  /// if the passed stream is not nullopt.
  explicit InlineOptionalStreamGuard(optional<Stream> stream_opt)
    : guard_() {
  explicit InlineOptionalStreamGuard(optional<Stream> stream_opt) : guard_() {
    if (stream_opt.has_value()) {
      guard_.emplace(stream_opt.value());
    }

@ -142,13 +149,14 @@ public:
  /// All constructors of StreamGuard are valid for OptionalStreamGuard
  template <typename... Args>
  explicit InlineOptionalStreamGuard(Args&&... args)
    : guard_(in_place, std::forward<Args>(args)...) {}
      : guard_(in_place, std::forward<Args>(args)...) {}

  // See Note [Move construction for RAII guards is tricky]
  InlineOptionalStreamGuard(InlineOptionalStreamGuard<T>&& other) = delete;

  // See Note [Move assignment for RAII guards is tricky]
  InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) = delete;
  InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) =
      delete;

  /// Resets the currently set stream to the original stream and
  /// the currently set device to the original device. Then,

@ -166,28 +174,31 @@ public:
  /// Returns the stream that was set at the time the guard was most recently
  /// initialized, or nullopt if the guard is uninitialized.
  optional<Stream> original_stream() const {
    return guard_.has_value() ? make_optional(guard_->original_stream()) : nullopt;
    return guard_.has_value() ? make_optional(guard_->original_stream())
                              : nullopt;
  }

  /// Returns the most recent stream that was set using this stream guard,
  /// either from construction, or via reset_stream, if the guard is initialized,
  /// or nullopt if the guard is uninitialized.
  /// either from construction, or via reset_stream, if the guard is
  /// initialized, or nullopt if the guard is uninitialized.
  optional<Stream> current_stream() const {
    return guard_.has_value() ? make_optional(guard_->current_stream()) : nullopt;
    return guard_.has_value() ? make_optional(guard_->current_stream())
                              : nullopt;
  }

  /// Restore the original device and stream, resetting this guard to uninitialized state.
  /// Restore the original device and stream, resetting this guard to
  /// uninitialized state.
  void reset() {
    guard_.reset();
  }

private:
 private:
  optional<InlineStreamGuard<T>> guard_;
};

template <typename T>
class InlineMultiStreamGuard {
public:
 public:
  /// Calls `set_stream` on each of the streams in the list.
  /// This may be useful if you need to set different streams
  /// for different devices.

@ -216,10 +227,10 @@ public:
    }
  }

protected:
 protected:
  optional<T> impl_;

private:
 private:
  /// The original streams that were active on all devices.
  std::vector<Stream> original_streams_;

@ -228,13 +239,17 @@ private:
    DeviceType type = streams[0].device_type();
    for (size_t idx = 1; idx < streams.size(); idx++) {
      TORCH_CHECK_VALUE(
        streams[idx].device_type() == type,
        "Streams have a mix of device types: stream 0 is on ",
        streams[0].device(), " while stream ", idx, " is on device ",
        streams[idx].device());
          streams[idx].device_type() == type,
          "Streams have a mix of device types: stream 0 is on ",
          streams[0].device(),
          " while stream ",
          idx,
          " is on device ",
          streams[idx].device());
    }
    return type;
  }
};

}} // namespace c10::impl
} // namespace impl
} // namespace c10
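
A sketch of the exchange-and-restore behavior of InlineStreamGuard (the stream source is illustrative):

  c10::Stream s = c10::cuda::getStreamFromPool(/*isHighPriority=*/false, /*device=*/0);
  {
    c10::impl::InlineStreamGuard<c10::impl::VirtualGuardImpl> g(s);
    // current device is now s.device(); its current stream is s
  }
  // on scope exit: the prior stream of s.device() is restored first,
  // then the prior device, matching the two saved streams above
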
@ -8,12 +8,12 @@ namespace impl {

// NB: POD, must be zero initialized!
// Note [TLS Initialization]
// We wanted raw_local_dispatch_key_set to be initialized with non-zero state
// e.g. BackendSelect and InplaceOrView in included set. But certain Windows compiler (e.g the one
// used in ARVR tests) only allow TLS to be zero-initialized.
// To preserve the invariant that raw TLS storage of the default state is zero,
// we obtain the actual include keyset by XORing raw_local_dispatch_key_set.included_
// with c10::default_included_set. This logic is encapsulated in struct
// PODLocalDispatchKeySet.
// e.g. BackendSelect and InplaceOrView in included set. But certain Windows
// compiler (e.g the one used in ARVR tests) only allow TLS to be
// zero-initialized. To preserve the invariant that raw TLS storage of the
// default state is zero, we obtain the actual include keyset by XORing
// raw_local_dispatch_key_set.included_ with c10::default_included_set. This
// logic is encapsulated in struct PODLocalDispatchKeySet.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;

@ -28,26 +28,29 @@ void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) {
  raw_local_dispatch_key_set.set_excluded(key_set.excluded_);
}

// An RAII guard could snapshot and restore the entire state (entire DispatchKeySet) as
// opposed to only snapshotting and restoring the state of its assigned DispatchKeySet.
// I'm not sure which is better. If only the RAII API is used, the two choices are
// not distinguishable.
// An RAII guard could snapshot and restore the entire state (entire
// DispatchKeySet) as opposed to only snapshotting and restoring the state of
// its assigned DispatchKeySet. I'm not sure which is better. If only the RAII
// API is used, the two choices are not distinguishable.
//
// However, if the guard chooses to snapshot and restore the entire DispatchKeySet,
// the interaction with the non-RAII API changes. Consider this sequence of events:
// - An RAII guard is declared for a particular DispatchKeySet, but snapshots the entire
// However, if the guard chooses to snapshot and restore the entire
// DispatchKeySet, the interaction with the non-RAII API changes. Consider this
// sequence of events:
// - An RAII guard is declared for a particular DispatchKeySet, but snapshots
// the entire
//   current DispatchKeySet.
// - A call to the non-RAII API changes the state for DispatchKeys outside the assigned
// - A call to the non-RAII API changes the state for DispatchKeys outside the
// assigned
//   set.
// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it snapshotted
//   (which restores the state for its own assigned DispatchKey and wipes out the state
//   for the other DispatchKeys set by the non-RAII API).
// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it
// snapshotted
//   (which restores the state for its own assigned DispatchKey and wipes out
//   the state for the other DispatchKeys set by the non-RAII API).

// RAII API

IncludeDispatchKeyGuard::IncludeDispatchKeyGuard(DispatchKeySet include)
    : tls_(&raw_local_dispatch_key_set)
    , include_(include - tls_->included()) {
    : tls_(&raw_local_dispatch_key_set), include_(include - tls_->included()) {
  if (!include_.empty()) {
    tls_->set_included(tls_->included() | include_);
  }

@ -60,8 +63,7 @@ IncludeDispatchKeyGuard::~IncludeDispatchKeyGuard() {
}

ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(DispatchKeySet exclude)
    : tls_(&raw_local_dispatch_key_set)
    , exclude_(exclude - tls_->excluded()) {
    : tls_(&raw_local_dispatch_key_set), exclude_(exclude - tls_->excluded()) {
  if (!exclude_.empty()) {
    tls_->set_excluded(tls_->excluded() | exclude_);
  }

@ -74,7 +76,8 @@ ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() {
}

// Non-RAII API
// Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h for details.
// Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h
// for details.

bool tls_is_dispatch_key_excluded(DispatchKey x) {
  return raw_local_dispatch_key_set.excluded().has(x);

@ -115,4 +118,5 @@ bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks) {
bool tls_is_dispatch_keyset_included(DispatchKeySet ks) {
  return raw_local_dispatch_key_set.included().isSupersetOf(ks);
}
}} // namespace c10::impl
} // namespace impl
} // namespace c10
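
A sketch of the delta-based RAII behavior implemented above; BackendSelect is the key named in Note [TLS Initialization], used here purely as an example:

  {
    c10::impl::ExcludeDispatchKeyGuard no_bs(c10::DispatchKey::BackendSelect);
    // exclude_ holds only the delta (exclude - tls->excluded()), so keys
    // that were already excluded are left untouched on destruction.
    TORCH_INTERNAL_ASSERT(
        c10::impl::tls_is_dispatch_key_excluded(c10::DispatchKey::BackendSelect));
  }
  // only the delta is removed here; pre-existing exclusions survive
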
@ -36,10 +36,12 @@ struct C10_API PODLocalDispatchKeySet {

  // See Note [TLS Initialization]
  DispatchKeySet included() const {
    return DispatchKeySet(DispatchKeySet::RAW, included_) ^ c10::default_included_set;
    return DispatchKeySet(DispatchKeySet::RAW, included_) ^
        c10::default_included_set;
  }
  DispatchKeySet excluded() const {
    return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^ c10::default_excluded_set;
    return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^
        c10::default_excluded_set;
  }

  void set_included(DispatchKeySet x) {

@ -49,11 +51,13 @@ struct C10_API PODLocalDispatchKeySet {
    excluded_ = (x ^ c10::default_excluded_set).raw_repr();
  }
};
static_assert(std::is_pod<PODLocalDispatchKeySet>::value, "PODLocalDispatchKeySet must be a POD type.");
static_assert(
    std::is_pod<PODLocalDispatchKeySet>::value,
    "PODLocalDispatchKeySet must be a POD type.");

struct C10_API LocalDispatchKeySet {
  /* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x)
      : included_(x.included()), excluded_(x.excluded()) {}
  DispatchKeySet included_;
  DispatchKeySet excluded_;
};

@ -63,7 +67,7 @@ struct C10_API LocalDispatchKeySet {
#if defined(_MSC_VER) || defined(C10_ANDROID)
C10_API LocalDispatchKeySet tls_local_dispatch_key_set();
#else // defined(_MSC_VER) || defined(C10_ANDROID)
extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;

inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() {
  // Don't let people fiddle with the thread_local directly just

@ -78,15 +82,17 @@ C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set);
// RAII API for manipulating the thread-local dispatch state.

class C10_API IncludeDispatchKeyGuard {
public:
 public:
  IncludeDispatchKeyGuard(DispatchKeySet);
  IncludeDispatchKeyGuard(DispatchKey k) : IncludeDispatchKeyGuard(DispatchKeySet(k)) {}
  IncludeDispatchKeyGuard(DispatchKey k)
      : IncludeDispatchKeyGuard(DispatchKeySet(k)) {}
  IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete;
  IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete;
  IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete;
  IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete;
  ~IncludeDispatchKeyGuard();
private:

 private:
  // A little micro-optimization to save us from tls_get_addr call
  // on destruction
  PODLocalDispatchKeySet* tls_;

@ -94,15 +100,17 @@ private:
};

class C10_API ExcludeDispatchKeyGuard {
public:
 public:
  ExcludeDispatchKeyGuard(DispatchKeySet);
  ExcludeDispatchKeyGuard(DispatchKey k) : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {}
  ExcludeDispatchKeyGuard(DispatchKey k)
      : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {}
  ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete;
  ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete;
  ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete;
  ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete;
  ~ExcludeDispatchKeyGuard();
private:

 private:
  // A little micro-optimization to save us from tls_get_addr call
  // on destruction
  PODLocalDispatchKeySet* tls_;

@ -120,7 +128,8 @@ private:
// through that DispatchKey's registered overrides.
//
// The non-RAII API is less efficient than the RAII guards because both the
// getter and setter will do a tls_getaddr lookup (the RAII struct only needs one!)
// getter and setter will do a tls_getaddr lookup (the RAII struct only needs
// one!)

C10_API bool tls_is_dispatch_key_excluded(DispatchKey x);
C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state);

@ -129,4 +138,5 @@ C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state);
C10_API bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks);
C10_API bool tls_is_dispatch_keyset_included(DispatchKeySet ks);

}} // namespace c10::impl
} // namespace impl
} // namespace c10
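
A worked example of the XOR trick from Note [TLS Initialization]: zero raw storage must decode to the default keyset, and encoding the default keyset must store zero:

  c10::impl::PODLocalDispatchKeySet pod{}; // zero-initialized, as TLS requires
  // included() == DispatchKeySet(RAW, 0) ^ default_included_set
  //            == c10::default_included_set
  TORCH_INTERNAL_ASSERT(pod.included() == c10::default_included_set);
  pod.set_included(c10::default_included_set); // stores raw 0 back into included_
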
@ -3,9 +3,13 @@
namespace c10 {
namespace impl {

void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize) {
void SizesAndStrides::resizeSlowPath(
    const size_t newSize,
    const size_t oldSize) {
  if (newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline(), "resizeSlowPath called when fast path should have been hit!");
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        !isInline(),
        "resizeSlowPath called when fast path should have been hit!");
    int64_t* tempStorage = outOfLineStorage_;
    memcpy(
        &inlineStorage_[0],

@ -24,15 +28,23 @@ void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize)
    // CANNOT USE allocateOutOfLineStorage(newSize) HERE! WOULD
    // OVERWRITE inlineStorage_!
    // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
    int64_t* tempStorage = static_cast<int64_t *>(malloc(storageBytes(newSize)));
    TORCH_CHECK(tempStorage, "Could not allocate memory to change Tensor SizesAndStrides!");
    int64_t* tempStorage =
        static_cast<int64_t*>(malloc(storageBytes(newSize)));
    TORCH_CHECK(
        tempStorage,
        "Could not allocate memory to change Tensor SizesAndStrides!");
    const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]);
    const auto bytesToZero = (newSize > oldSize) ? (newSize - oldSize) * sizeof(tempStorage[0]) : 0;
    const auto bytesToZero = (newSize > oldSize)
        ? (newSize - oldSize) * sizeof(tempStorage[0])
        : 0;
    memcpy(&tempStorage[0], &inlineStorage_[0], bytesToCopy);
    if (bytesToZero) {
      memset(&tempStorage[oldSize], 0, bytesToZero);
    }
    memcpy(&tempStorage[newSize], &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], bytesToCopy);
    memcpy(
        &tempStorage[newSize],
        &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE],
        bytesToCopy);
    if (bytesToZero) {
      memset(&tempStorage[newSize + oldSize], 0, bytesToZero);
    }

@ -55,7 +67,8 @@ void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize)
      resizeOutOfLineStorage(newSize);
    } else {
      // Zero the end of the sizes portion.
      const auto bytesToZero = (newSize - oldSize) * sizeof(outOfLineStorage_[0]);
      const auto bytesToZero =
          (newSize - oldSize) * sizeof(outOfLineStorage_[0]);
      memset(&outOfLineStorage_[oldSize], 0, bytesToZero);
      memset(&outOfLineStorage_[newSize + oldSize], 0, bytesToZero);
    }
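
The offsets in the copies above all come from one layout rule: strides live at a fixed offset of C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE when inline, and at offset `size` when out of line. A hypothetical helper capturing that arithmetic:

  // Hypothetical, for illustration only: index of the first stride slot.
  size_t stridesOffset(bool isInline, size_t size) {
    return isInline ? C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE : size;
  }
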
@ -19,7 +19,8 @@ namespace impl {
// the number of strides. The memory layout is as follows:
//
// 1 size_t for the size
// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer to out-of-line array
// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer
// to out-of-line array
class C10_API SizesAndStrides {
 public:
  // TODO: different iterator types for sizes & strides to prevent

@ -238,11 +239,16 @@ class C10_API SizesAndStrides {
    if (newSize == oldSize) {
      return;
    }
    if (C10_LIKELY(newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) {
    if (C10_LIKELY(
            newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) {
      if (oldSize < newSize) {
        const auto bytesToZero = (newSize - oldSize) * sizeof(inlineStorage_[0]);
        const auto bytesToZero =
            (newSize - oldSize) * sizeof(inlineStorage_[0]);
        memset(&inlineStorage_[oldSize], 0, bytesToZero);
        memset(&inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], 0, bytesToZero);
        memset(
            &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize],
            0,
            bytesToZero);
      }
      size_ = newSize;
    } else {

@ -267,14 +273,19 @@ class C10_API SizesAndStrides {
  }

  void allocateOutOfLineStorage(size_t size) {
    outOfLineStorage_ = static_cast<int64_t *>(malloc(storageBytes(size)));
    TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!");
    outOfLineStorage_ = static_cast<int64_t*>(malloc(storageBytes(size)));
    TORCH_CHECK(
        outOfLineStorage_,
        "Could not allocate memory for Tensor SizesAndStrides!");
  }

  void resizeOutOfLineStorage(size_t newSize) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline());
    outOfLineStorage_ = static_cast<int64_t *>(realloc(outOfLineStorage_, storageBytes(newSize)));
    TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!");
    outOfLineStorage_ = static_cast<int64_t*>(
        realloc(outOfLineStorage_, storageBytes(newSize)));
    TORCH_CHECK(
        outOfLineStorage_,
        "Could not allocate memory for Tensor SizesAndStrides!");
  }

  void copyDataOutline(const SizesAndStrides& rhs) noexcept {

@ -283,10 +294,9 @@ class C10_API SizesAndStrides {

  size_t size_;
  union {
    int64_t *outOfLineStorage_;
    int64_t* outOfLineStorage_;
    int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{};
  };
};

} // namespace impl
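
A sketch of how the resize paths above divide, given the five-slot inline capacity from the layout comment:

  c10::impl::SizesAndStrides sz;
  sz.resize(4); // fast path: inline, new size/stride slots zero-filled
  sz.resize(8); // slow path: spills to malloc'd out-of-line storage
  sz.resize(3); // slow path again: copies the data back inline
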
@ -10,12 +10,11 @@ namespace impl {
 * to virtual dispatch on the DeviceGuardImpl registry.
 */
class VirtualGuardImpl final : public DeviceGuardImplInterface {
public:
 public:
  VirtualGuardImpl(DeviceType device_type)
      : impl_(getDeviceGuardImpl(device_type)) {}
  // This constructor exists purely for testing
  VirtualGuardImpl(const DeviceGuardImplInterface* impl)
    : impl_(impl) {}
  VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {}

  // Copying and moving is OK!

@ -51,34 +50,32 @@ public:
  }

  // Event functions
  void record(void** event,
              const Stream& stream,
              const DeviceIndex device_index,
              const EventFlag flag) const override {
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    impl_->record(event, stream, device_index, flag);
  }
  void block(
      void* event,
      const Stream& stream) const override {
  void block(void* event, const Stream& stream) const override {
    impl_->block(event, stream);
  }
  bool queryEvent(void* event) const override {
    return impl_->queryEvent(event);
  }
  void destroyEvent(
      void* event,
      const DeviceIndex device_index) const noexcept override {
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    impl_->destroyEvent(event, device_index);
  }

  void recordDataPtrOnStream(
      const c10::DataPtr& data_ptr,
      const Stream& stream) const override {
  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    impl_->recordDataPtrOnStream(data_ptr, stream);
  }

private:
 private:
  const DeviceGuardImplInterface* impl_ = nullptr;
};

}} // namespace c10::impl
} // namespace impl
} // namespace c10
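
For reference, a sketch of the type erasure above: the registry is consulted once at construction, after which every call is a single virtual hop to the chosen backend:

  c10::impl::VirtualGuardImpl impl(c10::DeviceType::CUDA); // one registry lookup
  c10::Stream s = impl.getStream(c10::Device(c10::DeviceType::CUDA, 0));
  void* event = nullptr;
  impl.record(&event, s, /*device_index=*/0, c10::EventFlag::PYTORCH_DEFAULT);
  impl.block(event, s);
  impl.destroyEvent(event, /*device_index=*/0);
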
@ -3,9 +3,9 @@
namespace c10 {

ThreadPool::ThreadPool(
  int pool_size,
  int numa_node_id,
  std::function<void()> init_thread)
    int pool_size,
    int numa_node_id,
    std::function<void()> init_thread)
    : threads_(pool_size < 0 ? defaultNumThreads() : pool_size),
      running_(true),
      complete_(true),

@ -13,7 +13,7 @@ ThreadPool::ThreadPool(
      total_(threads_.size()),
      numa_node_id_(numa_node_id) {
  for (std::size_t i = 0; i < threads_.size(); ++i) {
    threads_[i] = std::thread([this, i, init_thread](){
    threads_[i] = std::thread([this, i, init_thread]() {
      if (init_thread) {
        init_thread();
      }
@ -50,9 +50,9 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
    const std::function<void(std::size_t)> with_id;

    explicit task_element_t(std::function<void()> f)
        : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {}
    explicit task_element_t(std::function<void(std::size_t)> f)
        : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {}
  };

  std::queue<task_element_t> tasks_;

@ -105,13 +105,11 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {

class C10_API TaskThreadPool : public c10::ThreadPool {
 public:
  explicit TaskThreadPool(
      std::size_t pool_size,
      int numa_node_id = -1)
      : ThreadPool(pool_size, numa_node_id, [numa_node_id](){
          setThreadName("CaffeTaskThread");
          NUMABind(numa_node_id);
        }) {}
  explicit TaskThreadPool(std::size_t pool_size, int numa_node_id = -1)
      : ThreadPool(pool_size, numa_node_id, [numa_node_id]() {
          setThreadName("CaffeTaskThread");
          NUMABind(numa_node_id);
        }) {}
};
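
A sketch of what the compacted constructor does; the init lambda runs on each worker thread before it starts servicing the queue (namespace qualification of the two helpers is assumed here):

  c10::TaskThreadPool pool(/*pool_size=*/4); // numa_node_id defaults to -1
  // roughly equivalent to constructing the base pool with an init hook:
  c10::ThreadPool base(4, -1, [] {
    c10::setThreadName("CaffeTaskThread");
    c10::NUMABind(-1);
  });
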

C10_DECLARE_SHARED_REGISTRY(

(File diff suppressed because it is too large)

@ -1,9 +1,9 @@
#ifndef THC_DEVICE_ALLOCATOR_INC
#define THC_DEVICE_ALLOCATOR_INC

#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Registry.h>

#include <array>

@ -19,7 +19,7 @@ class C10_CUDA_API CUDAOutOfMemoryError : public c10::Error {
// block inside of already allocated area.
class C10_CUDA_API FreeMemoryCallback {
 public:
  virtual ~FreeMemoryCallback() {};
  virtual ~FreeMemoryCallback(){};
  virtual bool Execute() = 0;
};

@ -55,7 +55,7 @@ enum struct StatType : uint64_t {
  AGGREGATE = 0,
  SMALL_POOL = 1,
  LARGE_POOL = 2,
  NUM_TYPES = 3 // remember to update this whenever a new stat type is added
  NUM_TYPES = 3 // remember to update this whenever a new stat type is added
};

typedef std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)> StatArray;

@ -68,7 +68,8 @@ struct DeviceStats {
  StatArray segment;
  // COUNT: number of active memory blocks (allocated or used by stream)
  StatArray active;
  // COUNT: number of inactive, split memory blocks (unallocated but can't be released via cudaFree)
  // COUNT: number of inactive, split memory blocks (unallocated but can't be
  // released via cudaFree)
  StatArray inactive_split;

  // SUM: bytes requested by client code

@ -80,14 +81,16 @@ struct DeviceStats {
  // SUM: bytes within inactive, split memory blocks
  StatArray inactive_split_bytes;

  // COUNT: total number of failed calls to CUDA malloc necessitating cache flushes.
  // COUNT: total number of failed calls to CUDA malloc necessitating cache
  // flushes.
  int64_t num_alloc_retries = 0;

  // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush)
  int64_t num_ooms = 0;
};

// Struct containing info of an allocation block (i.e. a fractional part of a cudaMalloc)..
// Struct containing info of an allocation block (i.e. a fractional part of a
// cudaMalloc)..
struct BlockInfo {
  int64_t size = 0;
  bool allocated = false;

@ -113,8 +116,11 @@ C10_CUDA_API Allocator* get();
C10_CUDA_API void init(int device_count);
C10_CUDA_API void setMemoryFraction(double fraction, int device);
C10_CUDA_API void emptyCache();
C10_CUDA_API void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock);
C10_CUDA_API void* getBaseAllocation(void *ptr, size_t *size);
C10_CUDA_API void cacheInfo(
    int dev_id,
    size_t* cachedAndFree,
    size_t* largestBlock);
C10_CUDA_API void* getBaseAllocation(void* ptr, size_t* size);
C10_CUDA_API void recordStream(const DataPtr&, CUDAStream stream);
C10_CUDA_API DeviceStats getDeviceStats(int device);
C10_CUDA_API void resetAccumulatedStats(int device);

@ -122,9 +128,10 @@ C10_CUDA_API void resetPeakStats(int device);
C10_CUDA_API std::vector<SegmentInfo> snapshot();

// CUDAGraph interactions
C10_CUDA_API void notifyCaptureBegin(int device,
                                     CaptureId_t graph_id,
                                     MempoolId_t mempool_id);
C10_CUDA_API void notifyCaptureBegin(
    int device,
    CaptureId_t graph_id,
    MempoolId_t mempool_id);
C10_CUDA_API void notifyCaptureEnd(int device, CaptureId_t graph_id);
C10_CUDA_API void notifyCaptureDestroy(int device, MempoolId_t mempool_id);

@ -133,6 +140,7 @@ C10_CUDA_API std::mutex* getFreeMutex();
C10_CUDA_API std::shared_ptr<void> getIpcDevPtr(std::string handle);
} // namespace CUDACachingAllocator

}} // namespace c10::cuda
} // namespace cuda
} // namespace c10

#endif
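
A sketch of how the per-type stat arrays above are consumed; each metric keeps one Stat per StatType, so the aggregate and per-pool views index the same array:

  using namespace c10::cuda::CUDACachingAllocator;
  DeviceStats stats = getDeviceStats(/*device=*/0);
  const Stat& agg =
      stats.inactive_split_bytes[static_cast<size_t>(StatType::AGGREGATE)];
  // SMALL_POOL and LARGE_POOL index the corresponding per-pool breakdowns.
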
@ -1,7 +1,7 @@
#pragma once

#include <c10/util/Exception.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <cuda.h>

// Note [CHECK macro]

@ -21,7 +21,7 @@
  }                               \
} while (0)

#define C10_CUDA_CHECK_WARN(EXPR) \
  do {                            \
    cudaError_t __err = EXPR;     \
    if (__err != cudaSuccess) {   \
@ -101,7 +101,9 @@ DeviceIndex device_count() noexcept {
  static int count = []() {
    try {
      auto result = device_count_impl(/*fail_if_no_driver=*/false);
      TORCH_INTERNAL_ASSERT(result <= std::numeric_limits<DeviceIndex>::max(), "Too many CUDA devices, DeviceIndex overflowed");
      TORCH_INTERNAL_ASSERT(
          result <= std::numeric_limits<DeviceIndex>::max(),
          "Too many CUDA devices, DeviceIndex overflowed");
      return result;
    } catch (const c10::Error& ex) {
      // We don't want to fail, but still log the warning
@ -27,7 +27,7 @@ struct C10_CUDA_API CUDAStreamCaptureModeGuard {
    C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
  }

private:
 private:
  cudaStreamCaptureMode strictness_;
};
#endif

@ -35,15 +35,18 @@ struct C10_CUDA_API CUDAStreamCaptureModeGuard {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// Protects against enum cudaStreamCaptureStatus implementation changes.
// Some compilers seem not to like static_assert without the messages.
static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
              "unexpected int(cudaStreamCaptureStatusNone) value");
static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
              "unexpected int(cudaStreamCaptureStatusActive) value");
static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
              "unexpected int(cudaStreamCaptureStatusInvalidated) value");
static_assert(
    int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
    "unexpected int(cudaStreamCaptureStatusNone) value");
static_assert(
    int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
    "unexpected int(cudaStreamCaptureStatusActive) value");
static_assert(
    int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
    "unexpected int(cudaStreamCaptureStatusInvalidated) value");
#endif

enum class CaptureStatus: int {
enum class CaptureStatus : int {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
  Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),

@ -54,7 +57,7 @@ enum class CaptureStatus: int {
};

inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
  switch(status) {
  switch (status) {
    case CaptureStatus::None:
      os << "cudaStreamCaptureStatusNone";
      break;

@ -67,9 +70,8 @@ inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
      break;
#endif
    default:
      TORCH_INTERNAL_ASSERT(false,
                            "Unknown CUDA graph CaptureStatus",
                            int(status));
      TORCH_INTERNAL_ASSERT(
          false, "Unknown CUDA graph CaptureStatus", int(status));
  }
  return os;
}

@ -78,13 +80,13 @@ inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  cudaStreamCaptureStatus is_capturing;
  C10_CUDA_CHECK(cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(),
                                       &is_capturing));
  C10_CUDA_CHECK(
      cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
  return CaptureStatus(is_capturing);
#else
  return CaptureStatus::None;
#endif
}

} // namespace c10
} // namespace cuda
} // namespace c10
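
A sketch of a typical call site for the helper above (the branch body is illustrative; the allocator's CUDAGraph interactions declared earlier are one real consumer):

  c10::cuda::CaptureStatus st =
      c10::cuda::currentStreamCaptureStatusMayInitCtx();
  if (st != c10::cuda::CaptureStatus::None) {
    // the current stream is capturing a CUDA graph: avoid operations that
    // are illegal during capture, e.g. synchronizing the device
  }
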
@ -1,16 +1,18 @@
#pragma once

#include <c10/cuda/impl/CUDAGuardImpl.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/impl/CUDAGuardImpl.h>

#include <cstddef>

namespace c10 { namespace cuda {
namespace c10 {
namespace cuda {

// This code is kind of boilerplatey. See Note [Whither the DeviceGuard boilerplate]
// This code is kind of boilerplatey. See Note [Whither the DeviceGuard
// boilerplate]

/// A variant of DeviceGuard that is specialized for CUDA. It accepts
/// integer indices (interpreting them as CUDA devices) and is a little

@ -38,22 +40,32 @@ struct CUDAGuard {

  /// Sets the CUDA device to the given device. Errors if the given device
  /// is not a CUDA device.
  void set_device(Device device) { guard_.set_device(device); }
  void set_device(Device device) {
    guard_.set_device(device);
  }

  /// Sets the CUDA device to the given device. Errors if the given device
  /// is not a CUDA device. (This method is provided for uniformity with
  /// DeviceGuard).
  void reset_device(Device device) { guard_.reset_device(device); }
  void reset_device(Device device) {
    guard_.reset_device(device);
  }

  /// Sets the CUDA device to the given device index.
  void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
  void set_index(DeviceIndex device_index) {
    guard_.set_index(device_index);
  }

  /// Returns the device that was set upon construction of the guard
  Device original_device() const { return guard_.original_device(); }
  Device original_device() const {
    return guard_.original_device();
  }

  /// Returns the last device that was set via `set_device`, if any, otherwise the
  /// device passed during construction.
  Device current_device() const { return guard_.current_device(); }
  /// Returns the last device that was set via `set_device`, if any, otherwise
  /// the device passed during construction.
  Device current_device() const {
    return guard_.current_device();
  }

 private:
  /// The guard for the current device.

@ -67,11 +79,13 @@ struct OptionalCUDAGuard {
  explicit OptionalCUDAGuard() : guard_() {}

  /// Set the current CUDA device to the passed Device, if it is not nullopt.
  explicit OptionalCUDAGuard(optional<Device> device_opt) : guard_(device_opt) {}
  explicit OptionalCUDAGuard(optional<Device> device_opt)
      : guard_(device_opt) {}

  /// Set the current CUDA device to the passed device index, if it is not
  /// nullopt
  explicit OptionalCUDAGuard(optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}
  explicit OptionalCUDAGuard(optional<DeviceIndex> device_index_opt)
      : guard_(device_index_opt) {}

  // Copy is not allowed
  OptionalCUDAGuard(const OptionalCUDAGuard&) = delete;

@ -84,31 +98,45 @@ struct OptionalCUDAGuard {
  OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete;

  /// Sets the CUDA device to the given device, initializing the guard if it
  /// is not already initialized. Errors if the given device is not a CUDA device.
  void set_device(Device device) { guard_.set_device(device); }
  /// is not already initialized. Errors if the given device is not a CUDA
  /// device.
  void set_device(Device device) {
    guard_.set_device(device);
  }

  /// Sets the CUDA device to the given device, initializing the guard if it is
  /// not already initialized. Errors if the given device is not a CUDA device.
  /// (This method is provided for uniformity with OptionalDeviceGuard).
  void reset_device(Device device) { guard_.reset_device(device); }
  void reset_device(Device device) {
    guard_.reset_device(device);
  }

  /// Sets the CUDA device to the given device index, initializing the guard if
  /// it is not already initialized.
  void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
  void set_index(DeviceIndex device_index) {
    guard_.set_index(device_index);
  }

  /// Returns the device that was set immediately prior to initialization of the
  /// guard, or nullopt if the guard is uninitialized.
  optional<Device> original_device() const { return guard_.original_device(); }
  optional<Device> original_device() const {
    return guard_.original_device();
  }

  /// Returns the most recent device that was set using this device guard,
  /// either from construction, or via set_device, if the guard is initialized,
  /// or nullopt if the guard is uninitialized.
  optional<Device> current_device() const { return guard_.current_device(); }
  optional<Device> current_device() const {
    return guard_.current_device();
  }

  /// Restore the original CUDA device, resetting this guard to uninitialized state.
  void reset() { guard_.reset(); }
  /// Restore the original CUDA device, resetting this guard to uninitialized
  /// state.
  void reset() {
    guard_.reset();
  }

private:
 private:
  c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl> guard_;
};
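
A sketch contrasting the two guards above:

  void run_on(c10::optional<c10::Device> maybe_dev) {
    c10::cuda::OptionalCUDAGuard opt(maybe_dev); // no-op when nullopt
    // device only switched if maybe_dev had a value
  }
  void run_on_one() {
    c10::cuda::CUDAGuard guard(/*device_index=*/1); // always switches
    // work here runs with CUDA device 1 current; restored on scope exit
  }
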

@ -118,17 +146,17 @@ struct CUDAStreamGuard {
  /// No default constructor, see Note [Omitted default constructor from RAII]
  explicit CUDAStreamGuard() = delete;

  /// Set the current CUDA device to the device associated with the passed stream,
  /// and set the current CUDA stream on that device to the passed stream.
  /// Errors if the Stream is not a CUDA stream.
  /// Set the current CUDA device to the device associated with the passed
  /// stream, and set the current CUDA stream on that device to the passed
  /// stream. Errors if the Stream is not a CUDA stream.
  explicit CUDAStreamGuard(Stream stream) : guard_(stream) {}

  /// Copy is disallowed
  CUDAStreamGuard(const CUDAStreamGuard&) = delete;
  CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete;

  /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized state,
  /// which is required for moves on types with nontrivial destructors.
  /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized
  /// state, which is required for moves on types with nontrivial destructors.
  CUDAStreamGuard(CUDAStreamGuard&& other) = delete;
  CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete;

@ -144,9 +172,12 @@ struct CUDAStreamGuard {
  /// WARNING: reset_stream does NOT preserve previously set streams on
  /// different devices. If you need to set streams on multiple devices
  /// on CUDA, use CUDAMultiStreamGuard instead.
  void reset_stream(Stream stream) { guard_.reset_stream(stream); }
  void reset_stream(Stream stream) {
    guard_.reset_stream(stream);
  }

  /// Returns the CUDA stream that was set at the time the guard was constructed.
  /// Returns the CUDA stream that was set at the time the guard was
  /// constructed.
  CUDAStream original_stream() const {
    return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream());
  }

@ -159,31 +190,36 @@ struct CUDAStreamGuard {

  /// Returns the most recent CUDA device that was set using this device guard,
  /// either from construction, or via set_device/reset_device/set_index.
  Device current_device() const { return guard_.current_device(); }
  Device current_device() const {
    return guard_.current_device();
  }

  /// Returns the CUDA device that was set at the most recent reset_stream(),
  /// or otherwise the device at construction time.
  Device original_device() const { return guard_.original_device(); }
  Device original_device() const {
    return guard_.original_device();
  }

private:
 private:
  c10::impl::InlineStreamGuard<impl::CUDAGuardImpl> guard_;
};

/// A variant of OptionalStreamGuard that is specialized for CUDA. See CUDAGuard
/// for when you can use this.
/// A variant of OptionalStreamGuard that is specialized for CUDA. See
/// CUDAGuard for when you can use this.
struct OptionalCUDAStreamGuard {
  /// Create an uninitialized guard.
  explicit OptionalCUDAStreamGuard() : guard_() {}

  /// Set the current CUDA device to the device associated with the passed stream,
  /// and set the current CUDA stream on that device to the passed stream.
  /// Errors if the Stream is not a CUDA stream.
  /// Set the current CUDA device to the device associated with the passed
  /// stream, and set the current CUDA stream on that device to the passed
  /// stream. Errors if the Stream is not a CUDA stream.
  explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {}

  /// Set the current device to the device associated with the passed stream,
  /// and set the current stream on that device to the passed stream,
  /// if the passed stream is not nullopt.
  explicit OptionalCUDAStreamGuard(optional<Stream> stream_opt) : guard_(stream_opt) {}
  explicit OptionalCUDAStreamGuard(optional<Stream> stream_opt)
      : guard_(stream_opt) {}

  /// Copy is disallowed
  OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete;

@ -200,10 +236,12 @@ struct OptionalCUDAStreamGuard {
  /// set the current device to the device associated with the passed stream,
  /// and set the current stream on that device to the passed stream.
  /// Initializes the guard if it was not previously initialized.
  void reset_stream(Stream stream) { guard_.reset_stream(stream); }
  void reset_stream(Stream stream) {
    guard_.reset_stream(stream);
  }

  /// Returns the CUDA stream that was set at the time the guard was most recently
  /// initialized, or nullopt if the guard is uninitialized.
  /// Returns the CUDA stream that was set at the time the guard was most
  /// recently initialized, or nullopt if the guard is uninitialized.
  optional<CUDAStream> original_stream() const {
    auto r = guard_.original_stream();
    if (r.has_value()) {

@ -214,8 +252,8 @@ struct OptionalCUDAStreamGuard {
  }

  /// Returns the most recent CUDA stream that was set using this stream guard,
  /// either from construction, or via reset_stream, if the guard is initialized,
  /// or nullopt if the guard is uninitialized.
  /// either from construction, or via reset_stream, if the guard is
  /// initialized, or nullopt if the guard is uninitialized.
  optional<CUDAStream> current_stream() const {
    auto r = guard_.current_stream();
    if (r.has_value()) {

@ -225,17 +263,20 @@ struct OptionalCUDAStreamGuard {
    }
  }

  /// Restore the original CUDA device and stream, resetting this guard to uninitialized state.
  void reset() { guard_.reset(); }
  /// Restore the original CUDA device and stream, resetting this guard to
  /// uninitialized state.
  void reset() {
    guard_.reset();
  }

private:
 private:
  c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl> guard_;
};
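
A sketch of the stream guard above in use, with a pool stream (getStreamFromPool is the CUDAStream.h API formatted below):

  c10::cuda::CUDAStream s = c10::cuda::getStreamFromPool();
  {
    c10::cuda::CUDAStreamGuard g(s); // switches device and stream together
    // kernels launched here target s
  }
  // prior stream on s.device() restored first, then the prior device
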

/// A variant of MultiStreamGuard that is specialized for CUDA.
struct CUDAMultiStreamGuard {
  explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams)
      : guard_(unwrapStreams(streams)) {}

  /// Copy is disallowed
  CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete;

@ -247,7 +288,7 @@ struct CUDAMultiStreamGuard {
  // See Note [Move assignment for RAII guards is tricky]
  CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete;

private:
 private:
  c10::impl::InlineMultiStreamGuard<impl::CUDAGuardImpl> guard_;

  static std::vector<Stream> unwrapStreams(ArrayRef<CUDAStream> cudaStreams) {
@ -1,6 +1,6 @@
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

@ -204,7 +204,7 @@ static void initGlobalStreamState() {
      "). Increase that and recompile.");

  // Initializes default streams
  for (const auto i: c10::irange(num_gpus)) {
  for (const auto i : c10::irange(num_gpus)) {
    default_streams[i].device_index = i;
    low_priority_counters[i] = 0;
    high_priority_counters[i] = 0;

@ -218,7 +218,7 @@ static void initDeviceStreamState(DeviceIndex device_index) {
  // with it.
  CUDAGuard device_guard{device_index};

  for (const auto i: c10::irange(kStreamsPerPool)) {
  for (const auto i : c10::irange(kStreamsPerPool)) {
    auto& lowpri_stream = low_priority_streams[device_index][i];
    auto& hipri_stream = high_priority_streams[device_index][i];

@ -244,7 +244,7 @@ static void initCUDAStreamsOnce() {
  // Inits current streams (thread local) to default streams
  current_streams =
      (LeakyStreamInternals**)malloc(num_gpus * sizeof(LeakyStreamInternals*));
  for (const auto i: c10::irange(num_gpus)) {
  for (const auto i : c10::irange(num_gpus)) {
    current_streams[i] = &default_streams[i];
  }
}
@ -5,53 +5,53 @@
|
|||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include <c10/core/DeviceGuard.h>
|
||||
#include <c10/core/Stream.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/cuda/CUDAMacros.h>
|
||||
#include <c10/core/DeviceGuard.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/core/Stream.h>
|
||||
|
||||
/*
|
||||
* Stream pool note.
|
||||
*
|
||||
* A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams
|
||||
* are backed by cuStreams, but they use several pools to minimize the costs
|
||||
* associated with creating, retaining, and destroying cuStreams.
|
||||
*
|
||||
* There are three pools per device, and a device's pools are lazily created.
|
||||
*
|
||||
* The first pool contains only the default stream. When the default stream
|
||||
* is requested it's returned.
|
||||
*
|
||||
* The second pool is the "low priority" or "default priority" streams. In
|
||||
* HIP builds there is no distinction between streams in this pool and streams
|
||||
* in the third pool (below). There are 32 of these streams per device, and
|
||||
* when a stream is requested one of these streams is returned round-robin.
|
||||
* That is, the first stream requested is at index 0, the second at index 1...
|
||||
* to index 31, then index 0 again.
|
||||
*
|
||||
* This means that if 33 low priority streams are requested, the first and
|
||||
* last streams requested are actually the same stream (under the covers)
|
||||
* and kernels enqueued on them cannot run concurrently.
*
* The third pool is the "high priority" streams. The third pool acts like
* the second pool except the streams are created with a higher priority.
*
* These pools suggest that stream users should prefer many short-lived streams,
* as the cost of acquiring and releasing streams is effectively zero. If
* many longer-lived streams are required in performance critical scenarios
* then the functionality here may need to be extended to allow, for example,
* "reserving" a subset of the pool so that other streams do not accidentally
* overlap the performance critical streams.
*
* Note: although the notion of "current stream for device" is thread local
* (every OS thread has a separate current stream, as one might expect),
* the stream pool is global across all threads; stream 0 is always stream 0
* no matter which thread you use it on. Multiple threads can synchronize
* on the same stream. Although the CUDA documentation is not very clear
* on the matter, streams are thread safe; e.g., it is safe to enqueue
* a kernel on the same stream from two different threads.
*/
* Stream pool note.
*
* A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams
* are backed by cuStreams, but they use several pools to minimize the costs
* associated with creating, retaining, and destroying cuStreams.
*
* There are three pools per device, and a device's pools are lazily created.
*
* The first pool contains only the default stream. When the default stream
* is requested it's returned.
*
* The second pool is the "low priority" or "default priority" streams. In
* HIP builds there is no distinction between streams in this pool and streams
* in the third pool (below). There are 32 of these streams per device, and
* when a stream is requested one of these streams is returned round-robin.
* That is, the first stream requested is at index 0, the second at index 1...
* to index 31, then index 0 again.
*
* This means that if 33 low priority streams are requested, the first and
* last streams requested are actually the same stream (under the covers)
* and kernels enqueued on them cannot run concurrently.
*
* The third pool is the "high priority" streams. The third pool acts like
* the second pool except the streams are created with a higher priority.
*
* These pools suggest that stream users should prefer many short-lived streams,
* as the cost of acquiring and releasing streams is effectively zero. If
* many longer-lived streams are required in performance critical scenarios
* then the functionality here may need to be extended to allow, for example,
* "reserving" a subset of the pool so that other streams do not accidentally
* overlap the performance critical streams.
*
* Note: although the notion of "current stream for device" is thread local
* (every OS thread has a separate current stream, as one might expect),
* the stream pool is global across all threads; stream 0 is always stream 0
* no matter which thread you use it on. Multiple threads can synchronize
* on the same stream. Although the CUDA documentation is not very clear
* on the matter, streams are thread safe; e.g., it is safe to enqueue
* a kernel on the same stream from two different threads.
*/
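
For illustration, a minimal sketch of how the pooled streams described above are typically used; it assumes the getStreamFromPool and setCurrentCUDAStream accessors that this header declares alongside CUDAStream:

// Grab a stream from the "high priority" pool and make it current.
c10::cuda::CUDAStream myStream =
    c10::cuda::getStreamFromPool(/*isHighPriority=*/true);
c10::cuda::setCurrentCUDAStream(myStream);
// ... enqueue kernels; they run on myStream ...
// The implicit conversion to cudaStream_t lets it flow into raw CUDA APIs:
C10_CUDA_CHECK(cudaStreamSynchronize(myStream));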

namespace c10 {
namespace cuda {
@ -61,8 +61,7 @@ namespace cuda {
// functionality (conversion to cudaStream_t), and a guarantee that
// the wrapped c10::Stream really is a CUDA stream.
class C10_CUDA_API CUDAStream {
public:

public:
enum Unchecked { UNCHECKED };

/// Construct a CUDAStream from a Stream. This construction is checked,

@ -85,21 +84,31 @@ public:
}

/// Implicit conversion to cudaStream_t.
operator cudaStream_t() const { return stream(); }
operator cudaStream_t() const {
return stream();
}

/// Implicit conversion to Stream (a.k.a., forget that the stream is a
/// CUDA stream).
operator Stream() const { return unwrap(); }
operator Stream() const {
return unwrap();
}

/// Get the CUDA device index that this stream is associated with.
DeviceIndex device_index() const { return stream_.device_index(); }
DeviceIndex device_index() const {
return stream_.device_index();
}

/// Get the full Device that this stream is associated with. The Device
/// is guaranteed to be a CUDA device.
Device device() const { return Device(DeviceType::CUDA, device_index()); }
Device device() const {
return Device(DeviceType::CUDA, device_index());
}

/// Return the stream ID corresponding to this particular stream.
StreamId id() const { return stream_.id(); }
StreamId id() const {
return stream_.id();
}

bool query() const {
DeviceGuard guard{stream_.device()};

@ -120,17 +129,19 @@ public:
}

int priority() const {
DeviceGuard guard{stream_.device()};
int priority = 0;
C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
return priority;
DeviceGuard guard{stream_.device()};
int priority = 0;
C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
return priority;
}

/// Explicit conversion to cudaStream_t.
cudaStream_t stream() const;

/// Explicit conversion to Stream.
Stream unwrap() const { return stream_; }
Stream unwrap() const {
return stream_;
}

/// Reversibly pack a CUDAStream into a uint64_t representation. This may
/// be helpful when storing a CUDAStream in a C struct, where you cannot

@ -150,22 +161,24 @@ public:
}

static std::tuple<int, int> priority_range() {
// Note: this returns the range of priority **supported by PyTorch**, not
// the range of priority **supported by CUDA**. The former is a subset of
// the latter. Currently PyTorch only supports 0 and -1, which are "low" and
// "high" priority.
int least_priority, greatest_priority;
C10_CUDA_CHECK(
// Note: this returns the range of priority **supported by PyTorch**, not
// the range of priority **supported by CUDA**. The former is a subset of
// the latter. Currently PyTorch only supports 0 and -1, which are "low" and
// "high" priority.
int least_priority, greatest_priority;
C10_CUDA_CHECK(
cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
TORCH_INTERNAL_ASSERT(least_priority >= 0, "Unexpected CUDA stream priority range");
TORCH_INTERNAL_ASSERT(greatest_priority <= -1, "Unexpected CUDA stream priority range");
return std::make_tuple(0, -1);
TORCH_INTERNAL_ASSERT(
least_priority >= 0, "Unexpected CUDA stream priority range");
TORCH_INTERNAL_ASSERT(
greatest_priority <= -1, "Unexpected CUDA stream priority range");
return std::make_tuple(0, -1);
}

// Deleted for now; use CUDAEvent::block instead
// void synchronize_with(const CUDAEvent& event) const;

private:
private:
Stream stream_;
};

@ -214,13 +227,13 @@ TORCH_API void setCurrentCUDAStream(CUDAStream stream);
C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s);

} // namespace cuda
} // namespace at
} // namespace c10

namespace std {
template <>
struct hash<c10::cuda::CUDAStream> {
size_t operator()(c10::cuda::CUDAStream s) const noexcept {
return std::hash<c10::Stream>{}(s.unwrap());
}
};
template <>
struct hash<c10::cuda::CUDAStream> {
size_t operator()(c10::cuda::CUDAStream s) const noexcept {
return std::hash<c10::Stream>{}(s.unwrap());
}
};
} // namespace std

@ -8,4 +8,6 @@ constexpr DeviceType CUDAGuardImpl::static_type;

C10_REGISTER_GUARD_IMPL(CUDA, CUDAGuardImpl);

}}} // namespace c10::cuda::detail
} // namespace impl
} // namespace cuda
} // namespace c10

@ -4,10 +4,10 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>

#include <cuda_runtime_api.h>

@ -82,9 +82,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
}

// Event-related functions
void createEvent(
cudaEvent_t* cuda_event,
const EventFlag flag) const {
void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
// Maps PyTorch's Event::Flag to CUDA flag
auto cuda_flag = cudaEventDefault;
switch (flag) {

@ -103,10 +101,10 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
}

void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override {
if (!event) return;
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override {
if (!event)
return;
auto cuda_event = static_cast<cudaEvent_t>(event);
int orig_device;
C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device));

@ -116,16 +114,17 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
}

void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(
device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");

cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
CUDAStream cuda_stream{stream};

@ -135,7 +134,8 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
setDevice(stream.device());

// Creates the event (lazily)
if (!cuda_event) createEvent(&cuda_event, flag);
if (!cuda_event)
createEvent(&cuda_event, flag);
C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
// Makes the void* point to the (possibly just allocated) CUDA event
*event = cuda_event;

@ -144,24 +144,24 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
setDevice(orig_device);
}

void block(
void* event,
const Stream& stream) const override {
if (!event) return;
void block(void* event, const Stream& stream) const override {
if (!event)
return;
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
CUDAStream cuda_stream{stream};
const auto orig_device = getDevice();
setDevice(stream.device());
C10_CUDA_CHECK(cudaStreamWaitEvent(
cuda_stream,
cuda_event,
/*flags (must be zero)=*/ 0));
cuda_stream,
cuda_event,
/*flags (must be zero)=*/0));
setDevice(orig_device);
}

// May be called from any device
bool queryEvent(void* event) const override {
if (!event) return true;
if (!event)
return true;
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
const cudaError_t err = cudaEventQuery(cuda_event);
if (err != cudaErrorNotReady) {

@ -170,12 +170,13 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
return (err == cudaSuccess);
}

void recordDataPtrOnStream(
const c10::DataPtr& data_ptr,
const Stream& stream) const override {
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
const override {
CUDAStream cuda_stream{stream};
CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
}
};

}}} // namespace c10::cuda::impl
} // namespace impl
} // namespace cuda
} // namespace c10

@ -29,4 +29,6 @@ int c10_cuda_private_test() {
return 2;
}

}}} // namespace c10::cuda::impl
} // namespace impl
} // namespace cuda
} // namespace c10

@ -8,4 +8,6 @@ namespace impl {

C10_CUDA_API int c10_cuda_test();

}}} /// namespace c10::cuda::impl
}
} // namespace cuda
} // namespace c10

@ -102,17 +102,17 @@
// two pieces with confusing names?
// Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we
// tried to compile PyTorch for CUDA 11.1, which ran into relocation marker
// issues when linking big binaries. (https://github.com/pytorch/pytorch/issues/39968)
// We had two choices:
// issues when linking big binaries.
// (https://github.com/pytorch/pytorch/issues/39968) We had two choices:
// (1) Stop supporting so many GPU architectures
// (2) Do something else
// We chose #2 and decided to split the behemoth that was torch_cuda into two
// smaller libraries, one with most of the core kernel functions (torch_cuda_cu)
// and the other that had..well..everything else (torch_cuda_cpp). The idea was this:
// instead of linking our static libraries (like the hefty libcudnn_static.a) with
// another huge library, torch_cuda, and run into pesky relocation marker issues,
// we could link our static libraries to a smaller part of torch_cuda (torch_cuda_cpp)
// and avoid the issues.
// and the other that had..well..everything else (torch_cuda_cpp). The idea was
// this: instead of linking our static libraries (like the hefty
// libcudnn_static.a) with another huge library, torch_cuda, and run into pesky
// relocation marker issues, we could link our static libraries to a smaller
// part of torch_cuda (torch_cuda_cpp) and avoid the issues.

// libtorch_cuda_cu.so
#ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB

@ -128,7 +128,8 @@
#define TORCH_CUDA_CPP_API C10_IMPORT
#endif

// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the same api)
// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the
// same api)
#ifdef TORCH_CUDA_BUILD_MAIN_LIB
#define TORCH_CUDA_CPP_API C10_EXPORT
#define TORCH_CUDA_CU_API C10_EXPORT

@ -24,16 +24,17 @@
#include <c10/macros/Export.h>

#if defined(__clang__)
#define __ubsan_ignore_float_divide_by_zero__ __attribute__((no_sanitize("float-divide-by-zero")))
#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined")))
#define __ubsan_ignore_signed_int_overflow__ __attribute__((no_sanitize("signed-integer-overflow")))
#define __ubsan_ignore_float_divide_by_zero__ \
__attribute__((no_sanitize("float-divide-by-zero")))
#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined")))
#define __ubsan_ignore_signed_int_overflow__ \
__attribute__((no_sanitize("signed-integer-overflow")))
#else
#define __ubsan_ignore_float_divide_by_zero__
#define __ubsan_ignore_undefined__
#define __ubsan_ignore_signed_int_overflow__
#define __ubsan_ignore_float_divide_by_zero__
#define __ubsan_ignore_undefined__
#define __ubsan_ignore_signed_int_overflow__
#endif

// Detect address sanitizer as some stuff doesn't work with it
#undef C10_ASAN_ENABLED

@ -57,7 +58,6 @@
#define C10_ASAN_ENABLED 0
#endif

// Disable the copy and assignment operator for a class. Note that this will
// disable the usage of the class in std containers.
#define C10_DISABLE_COPY_AND_ASSIGN(classname) \

@ -84,7 +84,6 @@
#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__)
#endif

/// C10_NODISCARD - Warn if a type or return value is discarded.

// Technically, we should check if __cplusplus > 201402L here, because

@ -108,24 +107,25 @@
// - gcc 8.3: https://godbolt.org/z/4tLMQS (always advertises support)
#define C10_NODISCARD
#if defined(__has_cpp_attribute)
# if __has_cpp_attribute(nodiscard)
# undef C10_NODISCARD
# define C10_NODISCARD [[nodiscard]]
# endif
#if __has_cpp_attribute(nodiscard)
#undef C10_NODISCARD
#define C10_NODISCARD [[nodiscard]]
#endif
// Workaround for llvm.org/PR23435, since clang 3.6 and below emit a spurious
// error when __has_cpp_attribute is given a scoped attribute in C mode.
#elif __cplusplus && defined(__has_cpp_attribute)
# if __has_cpp_attribute(clang::warn_unused_result)
// TODO: It's possible this is still triggering https://github.com/pytorch/pytorch/issues/13118
// on Windows; if it is, better fix it.
# undef C10_NODISCARD
# define C10_NODISCARD [[clang::warn_unused_result]]
# endif
#if __has_cpp_attribute(clang::warn_unused_result)
// TODO: It's possible this is still triggering
// https://github.com/pytorch/pytorch/issues/13118 on Windows; if it is, better
// fix it.
#undef C10_NODISCARD
#define C10_NODISCARD [[clang::warn_unused_result]]
#endif
#endif
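
As a quick sketch of what the macro buys you once it resolves to [[nodiscard]] or [[clang::warn_unused_result]] (hypothetical helper, for illustration only):

C10_NODISCARD bool try_reserve() { // hypothetical helper, for illustration
  return true;
}

void caller() {
  // try_reserve();        // a compiler honoring nodiscard warns here
  bool ok = try_reserve(); // fine: the result is consumed
  (void)ok;
}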

// suppress an unused variable.
#if defined(_MSC_VER) && !defined(__clang__)
#define C10_UNUSED __pragma(warning(suppress: 4100 4101))
#define C10_UNUSED __pragma(warning(suppress : 4100 4101))
#else
#define C10_UNUSED __attribute__((__unused__))
#endif //_MSC_VER

@ -135,16 +135,28 @@
// Simply define the namespace, in case a dependent library wants to refer to
// the c10 namespace but not any nontrivial files.
namespace c10 {} // namespace c10
namespace c10 { namespace cuda {} }
namespace c10 { namespace hip {} }
namespace c10 {
namespace cuda {}
} // namespace c10
namespace c10 {
namespace hip {}
} // namespace c10

// Since C10 is the core library for caffe2 (and aten), we will simply reroute
// all abstractions defined in c10 to be available in caffe2 as well.
// This is only for backwards compatibility. Please use the symbols from the
// c10 namespace where possible.
namespace caffe2 { using namespace c10; }
namespace at { using namespace c10; }
namespace at { namespace cuda { using namespace c10::cuda; }}
namespace caffe2 {
using namespace c10;
}
namespace at {
using namespace c10;
}
namespace at {
namespace cuda {
using namespace c10::cuda;
}
} // namespace at

// WARNING!!! THIS IS A GIANT HACK!!!
// This line means you cannot simultaneously include c10/hip

@ -154,7 +166,11 @@ namespace at { namespace cuda { using namespace c10::cuda; }}
// from at::cuda. This namespace makes that happen. When
// HIPIFY is no longer out-of-place, we can switch the cuda
// here to hip and everyone is happy.
namespace at { namespace cuda { using namespace c10::hip; }}
namespace at {
namespace cuda {
using namespace c10::hip;
}
} // namespace at

// C10_LIKELY/C10_UNLIKELY
//

@ -169,11 +185,11 @@ namespace at { namespace cuda { using namespace c10::hip; }}
// without it.
//
#if defined(__GNUC__) || defined(__ICL) || defined(__clang__)
#define C10_LIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 1))
#define C10_UNLIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 0))
#define C10_LIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 1))
#define C10_UNLIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 0))
#else
#define C10_LIKELY(expr) (expr)
#define C10_UNLIKELY(expr) (expr)
#define C10_LIKELY(expr) (expr)
#define C10_UNLIKELY(expr) (expr)
#endif
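
A minimal sketch of how these hint macros are meant to be used (hypothetical function, for illustration):

int checked_div(int num, int den) {
  if (C10_UNLIKELY(den == 0)) {
    return 0; // cold path: the hint keeps the hot path straight-line
  }
  return num / den;
}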

/// C10_NOINLINE - Functions whose declaration is annotated with this will not

@ -209,11 +225,13 @@ namespace at { namespace cuda { using namespace c10::hip; }}
#define C10_HOST_DEVICE __host__ __device__
#define C10_DEVICE __device__
#define C10_HOST __host__
// constants from (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications)
// The maximum number of threads per multiprocessor is 1024 for Turing architecture (7.5),
// 1536 for Geforce Ampere (8.6),
// and 2048 for all other architectures. You'll get warnings if you exceed these constants.
// Hence, the following macros adjust the input values from the user to resolve potential warnings.
// constants from
// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications)
// The maximum number of threads per multiprocessor is 1024 for Turing
// architecture (7.5), 1536 for Geforce Ampere (8.6), and 2048 for all other
// architectures. You'll get warnings if you exceed these constants. Hence, the
// following macros adjust the input values from the user to resolve potential
// warnings.
#if __CUDA_ARCH__ == 750
constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024;
#elif __CUDA_ARCH__ == 860

@ -223,25 +241,39 @@ constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048;
#endif
// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently
constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024;
// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block size.
// 256 is a good number for this fallback and should give good occupancy and
// versatility across all architectures.
// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block
// size. 256 is a good number for this fallback and should give good occupancy
// and versatility across all architectures.
constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
// NOTE: if you are thinking of constexpr-ifying the inputs to launch bounds, it
// turns out that although __launch_bounds__ can take constexpr, it
// can't take a constexpr that has anything to do with templates.
// Currently we use launch_bounds that depend on template arguments in
// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK and
// C10_MIN_BLOCKS_PER_SM are kept as macros.
// Suppose you were planning to write __launch_bounds__(a, b), based on your performance tuning on a modern GPU.
// Instead, you should write __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)),
// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK
// and C10_MIN_BLOCKS_PER_SM are kept as macros.
// Suppose you were planning to write __launch_bounds__(a, b), based on your
// performance tuning on a modern GPU. Instead, you should write
// __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)),
// which will also properly respect limits on old architectures.
#define C10_MAX_THREADS_PER_BLOCK(val) (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PER_BLOCK_FALLBACK)
#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) ((((threads_per_block)*(blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) ? (blocks_per_sm) : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / (threads_per_block))))
#define C10_MAX_THREADS_PER_BLOCK(val) \
(((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \
: CUDA_THREADS_PER_BLOCK_FALLBACK)
#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \
((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \
? (blocks_per_sm) \
: ((CUDA_MAX_THREADS_PER_SM + (threads_per_block)-1) / \
(threads_per_block))))
// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__
#define C10_LAUNCH_BOUNDS_0 __launch_bounds__(256, 4) // default launch bounds that should give good occupancy and versatility across all architectures.
#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))))
#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm))))
#define C10_LAUNCH_BOUNDS_0 \
__launch_bounds__( \
256, 4) // default launch bounds that should give good occupancy and
// versatility across all architectures.
#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) \
__launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))))
#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \
__launch_bounds__( \
(C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \
(C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm))))
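
A sketch of the substitution the note above recommends (hypothetical kernel, for illustration): instead of hand-writing __launch_bounds__(1024, 2) tuned on a modern GPU, write

C10_LAUNCH_BOUNDS_2(1024, 2)
__global__ void fill_kernel(float* out, float value, int64_t n) {
  // On architectures where 1024 threads/block * 2 blocks/SM would exceed
  // the per-SM thread limit, C10_MIN_BLOCKS_PER_SM scales the second
  // argument down instead of producing invalid launch bounds.
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = value;
  }
}
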
#else
#define C10_HOST_DEVICE
#define C10_HOST

@ -273,14 +305,12 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
#elif defined(_MSC_VER)
#if defined(NDEBUG)
extern "C" {
C10_IMPORT
C10_IMPORT
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) || defined(__HIP__)
__host__ __device__
__host__ __device__
#endif // __CUDA_ARCH__
void _wassert(
wchar_t const* _Message,
wchar_t const* _File,
unsigned _Line);
void
_wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line);
}
#endif
#define CUDA_KERNEL_ASSERT(cond) \

@ -304,8 +334,8 @@ __host__ __device__
#endif // NDEBUG
#define CUDA_KERNEL_ASSERT(cond) \
if (C10_UNLIKELY(!(cond))) { \
__assert_fail(#cond, __FILE__, static_cast<unsigned int>(__LINE__), \
__func__); \
__assert_fail( \
#cond, __FILE__, static_cast<unsigned int>(__LINE__), __func__); \
}
#endif // __APPLE__

@ -19,16 +19,15 @@
* See below for more information.
* Why?
* It has been observed that some mobile platforms, such as pixel 3, return
* memory aggressively to the system. This results in page faults in some cases
* and ends up hurting performance. This caching allocator aims to address that.
* Furthermore it also allows users to specify their own allocator by implementing
* allocate/free virtual interfaces.
* What are the cons?
* There are some cons that were observed where use of caching allocator led to
* worse performance on some platforms. Reason being that the caching mechanism
* used by this allocator left us worse off compared to the corresponding platform's
* tuned memory allocator. In that case it seemed better to not use this allocator.
* Note there are some ideas to fix this in the works.
* memory aggressively to the system. This results in page faults in some
* cases and ends up hurting performance. This caching allocator aims to address
* that. Furthermore it also allows users to specify their own allocator by
* implementing allocate/free virtual interfaces. What are the cons? There are
* some cons that were observed where use of caching allocator led to worse
* performance on some platforms. Reason being that the caching mechanism used
* by this allocator left us worse off compared to the corresponding platform's
* tuned memory allocator. In that case it seemed better to not use this
* allocator. Note there are some ideas to fix this in the works.
*
* Usage:
* Usage pattern:

@ -53,42 +52,44 @@ class C10_API CPUCachingAllocator {
* What it does not do:
* No speculative allocation for any future allocations.
*/
private:
inline void* allocate_and_cache(const size_t bytes);
void free_cached();
protected:
// Invariants.
// 1. If memory is ever allocated via this allocator then
// the pointer will exist in allocation_map_, unless the allocator
// returned the memory to OS via free_cached.
// 1.1. Therefore even when the said memory is "freed" via this
// allocator (and thus cached), it will continue to stay
// in allocation_map_. Furthermore it will also exist in
// available_map_. Thus an allocated memory pointer can be in both
// allocation_map_ and available_map_ simultaneously.
// 2. Memory pointer may be removed from allocation_map_, when it
// is freed outside of the scope of this allocator, but was allocated
// by this allocator.
// 3. Available map only contains that memory which was allocated
// by this allocator and subsequently freed by this allocator.
// As a result of above invariants, allocated memory ptr cannot be in
// available_map_ unless it is in allocation_map_ as well.
ska::flat_hash_map<size_t, c10::SmallVector<void*, 16>> available_map_;
static ska::flat_hash_map<void*, size_t> allocation_map_;
// Since allocation_map, which is a global instance, is mutated/read via
// all public APIs we need a global mutex.
static std::mutex mutex_;
public:
static void record_free(void* ptr);
virtual ~CPUCachingAllocator();
// Checks the cache to see if allocation of size bytes can be found.
// If so return cached memory, else
// allocates memory, records it for caching and returns.
virtual void* allocate(const size_t bytes);
// Checks if the memory being freed was marked for allocation by
// an earlier call to allocate. If so cache the allocation.
// Otherwise free.
virtual void free(void* ptr);
private:
inline void* allocate_and_cache(const size_t bytes);
void free_cached();

protected:
// Invariants.
// 1. If memory is ever allocated via this allocator then
// the pointer will exist in allocation_map_, unless the allocator
// returned the memory to OS via free_cached.
// 1.1. Therefore even when the said memory is "freed" via this
// allocator (and thus cached), it will continue to stay
// in allocation_map_. Furthermore it will also exist in
// available_map_. Thus an allocated memory pointer can be in both
// allocation_map_ and available_map_ simultaneously.
// 2. Memory pointer may be removed from allocation_map_, when it
// is freed outside of the scope of this allocator, but was allocated
// by this allocator.
// 3. Available map only contains that memory which was allocated
// by this allocator and subsequently freed by this allocator.
// As a result of above invariants, allocated memory ptr cannot be in
// available_map_ unless it is in allocation_map_ as well.
ska::flat_hash_map<size_t, c10::SmallVector<void*, 16>> available_map_;
static ska::flat_hash_map<void*, size_t> allocation_map_;
// Since allocation_map, which is a global instance, is mutated/read via
// all public APIs we need a global mutex.
static std::mutex mutex_;

public:
static void record_free(void* ptr);
virtual ~CPUCachingAllocator();
// Checks the cache to see if allocation of size bytes can be found.
// If so return cached memory, else
// allocates memory, records it for caching and returns.
virtual void* allocate(const size_t bytes);
// Checks if the memory being freed was marked for allocation by
// an earlier call to allocate. If so cache the allocation.
// Otherwise free.
virtual void free(void* ptr);
};

CPUCachingAllocator* GetDefaultCPUCachingAllocator();

@ -97,11 +98,12 @@ bool ThreadLocalCachingAllocatorEnabled();
CPUCachingAllocator* GetThreadLocalCachingAllocator();

class C10_API WithCPUCachingAllocatorGuard {
public:
WithCPUCachingAllocatorGuard(CPUCachingAllocator* allocator);
~WithCPUCachingAllocatorGuard();
private:
CPUCachingAllocator* prev_caching_allocator_ptr_{nullptr};
public:
WithCPUCachingAllocatorGuard(CPUCachingAllocator* allocator);
~WithCPUCachingAllocatorGuard();

private:
CPUCachingAllocator* prev_caching_allocator_ptr_{nullptr};
};

} // namespace c10
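
For illustration, a minimal sketch of routing a thread's allocations through a caching allocator using the guard declared above (run_inference() is a hypothetical workload):

c10::CPUCachingAllocator* caching_alloc = c10::GetDefaultCPUCachingAllocator();
{
  c10::WithCPUCachingAllocatorGuard guard(caching_alloc);
  // Allocations made on this thread in this scope may be served from the
  // cache and are returned to it on free, instead of going back to the OS.
  run_inference(); // hypothetical workload
}
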
@ -19,27 +19,23 @@ struct MemBlock {
}
};

enum class EventType {
Allocate = 0,
Free,
Invalid
};
enum class EventType { Allocate = 0, Free, Invalid };

struct MemEvent {
uint64_t time;
uint64_t allocation_id;
uint64_t size;
EventType type{EventType::Invalid};
MemEvent(uint64_t t, uint64_t id, uint64_t s, EventType e) :
time(t), allocation_id(id), size(s), type(e) {}
MemEvent(uint64_t t, uint64_t id, uint64_t s, EventType e)
: time(t), allocation_id(id), size(s), type(e) {}
};

bool overlaps(const MemBlock& a, const MemBlock& b) {
// two blocks don't overlap if
// |---a--------|--------------b--------|
// start_a end_a <= start_b end_b
return
!((a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset));
return !(
(a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset));
}
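
// Illustrative values (not from this diff): overlaps({0, 4}, {4, 8}) is
// false, since the first block ends exactly where the second starts, while
// overlaps({0, 5}, {4, 8}) is true, since both blocks contain offset 4.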

bool validate_allocation_plan(

@ -72,7 +68,8 @@ bool validate_allocation_plan(
allocations.emplace(mem_block);
} else if (event.type == EventType::Free) {
auto it = allocations.find(mem_block);
TORCH_CHECK((*it).end_offset == end_offset,
TORCH_CHECK(
(*it).end_offset == end_offset,
"End offset of allocation being freed must match the one recorded.");
TORCH_CHECK(
it != allocations.end(),

@ -98,13 +95,15 @@ std::vector<MemEvent> create_and_sort_mem_events(
continue;
}
events.emplace_back(i, i, allocation_sizes[i], EventType::Allocate);
events.emplace_back(allocation_lifetimes[i], i, allocation_sizes[i], EventType::Free);
events.emplace_back(
allocation_lifetimes[i], i, allocation_sizes[i], EventType::Free);
}
std::sort(
events.begin(),
events.end(),
[](const MemEvent& a,
const MemEvent& b) -> bool {return a.time < b.time;});
[](const MemEvent& a, const MemEvent& b) -> bool {
return a.time < b.time;
});
return events;
}

@ -132,8 +131,10 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
std::map<uint64_t, uint64_t> free_size_to_offset;
// This provides fast lookup when we want to insert freed block
// back, especially when we want to merge blocks.
ska::flat_hash_map<uint64_t, std::map<uint64_t, uint64_t>::iterator> free_start_offset_to_size_iter;
ska::flat_hash_map<uint64_t, std::map<uint64_t, uint64_t>::iterator> free_end_offset_to_size_iter;
ska::flat_hash_map<uint64_t, std::map<uint64_t, uint64_t>::iterator>
free_start_offset_to_size_iter;
ska::flat_hash_map<uint64_t, std::map<uint64_t, uint64_t>::iterator>
free_end_offset_to_size_iter;
// Upon free end_ptr = offset + size
// If end_ptr exists merge freed allocation
// Also find corresponding offset in size_to_offset
@ -146,7 +147,8 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(

std::vector<uint64_t> allocation_offsets(
allocation_sizes.size(), std::numeric_limits<uint64_t>::max());
auto mem_events = create_and_sort_mem_events(allocation_sizes, allocation_lifetimes);
auto mem_events =
create_and_sort_mem_events(allocation_sizes, allocation_lifetimes);
uint64_t max_offset{0};
for (const auto& mem_event : mem_events) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)

@ -221,13 +223,14 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
free_start_offset_to_size_iter.erase(freed_offset);
}
auto freed_block_it =
free_size_to_offset.emplace(freed_size, freed_offset).first;
free_size_to_offset.emplace(freed_size, freed_offset).first;
free_start_offset_to_size_iter.emplace(freed_offset, freed_block_it);
free_end_offset_to_size_iter.emplace(
freed_offset + freed_size, freed_block_it);
}
}
TORCH_CHECK(validate_allocation_plan(mem_events, allocation_offsets),
TORCH_CHECK(
validate_allocation_plan(mem_events, allocation_offsets),
"ProfilingAllocator: Allocation plan invalid.");
return allocation_offsets;
}

@ -241,7 +244,8 @@ void AllocationPlan::clear() {
}

void AllocationPlanner::record_allocation(
const uint64_t size, const void* ptr) {
const uint64_t size,
const void* ptr) {
if (validation_mode_) {
validation_success = validation_success && validate_allocation(size, ptr);
return;

@ -264,13 +268,15 @@ void AllocationPlanner::record_free(const void* ptr) {
return;
}
auto id = it->second;
TORCH_CHECK(id < allocation_plan_->allocation_lifetimes.size(),
TORCH_CHECK(
id < allocation_plan_->allocation_lifetimes.size(),
"Allocation must have been recorded during record_allocation.");
allocation_plan_->allocation_lifetimes[id] = allocation_id_;
}

bool AllocationPlanner::validate_allocation(
const uint64_t size, const void* ptr) {
const uint64_t size,
const void* ptr) {
if (allocation_id_ >= allocation_plan_->allocation_sizes.size() ||
allocation_plan_->allocation_sizes[allocation_id_] != size) {
TORCH_WARN(

@ -286,7 +292,7 @@ bool AllocationPlanner::validate_allocation(

return false;
}
allocation_ptr_to_id_[ptr] = allocation_id_;
allocation_ptr_to_id_[ptr] = allocation_id_;
allocation_id_++;
return true;
}

@ -298,24 +304,27 @@ bool AllocationPlanner::validate_free(const void* ptr) {
return true;
}
auto id = (*it).second;
TORCH_CHECK(id < allocation_plan_->allocation_lifetimes.size(),
TORCH_CHECK(
id < allocation_plan_->allocation_lifetimes.size(),
"Allocation must have been recorded during validate_allocation.");
auto lifetime_id = allocation_plan_->allocation_lifetimes[id];
return (lifetime_id == allocation_id_);
}

void AllocationPlanner::formulate_plan() {
allocation_plan_->allocation_offsets =
formulate_greedy_allocation_plan(
allocation_plan_->allocation_sizes, allocation_plan_->allocation_lifetimes);
allocation_plan_->allocation_offsets = formulate_greedy_allocation_plan(
allocation_plan_->allocation_sizes,
allocation_plan_->allocation_lifetimes);
allocation_plan_->total_size = 0;
for (const auto i : c10::irange(allocation_plan_->allocation_sizes.size())) {
if (allocation_plan_->allocation_lifetimes[i] ==
std::numeric_limits<uint64_t>::max()) {
continue;
}
auto limit = allocation_plan_->allocation_offsets[i] + allocation_plan_->allocation_sizes[i];
allocation_plan_->total_size = std::max(allocation_plan_->total_size, limit);
auto limit = allocation_plan_->allocation_offsets[i] +
allocation_plan_->allocation_sizes[i];
allocation_plan_->total_size =
std::max(allocation_plan_->total_size, limit);
}
}

@ -344,7 +353,8 @@ void CPUProfilingAllocator::unset_plan() {
}

void* CPUProfilingAllocator::allocate(const size_t bytes) {
TORCH_CHECK(bytes == plan_->allocation_sizes[allocation_id_],
TORCH_CHECK(
bytes == plan_->allocation_sizes[allocation_id_],
"Got allocation request that does not match with the plan.");
if (plan_->allocation_lifetimes[allocation_id_] ==
std::numeric_limits<uint64_t>::max()) {

@ -352,9 +362,8 @@ void* CPUProfilingAllocator::allocate(const size_t bytes) {
allocation_id_++;
return c10::alloc_cpu(bytes);
}
void* ptr =
reinterpret_cast<uint8_t*>(blob_) +
plan_->allocation_offsets[allocation_id_];
void* ptr = reinterpret_cast<uint8_t*>(blob_) +
plan_->allocation_offsets[allocation_id_];
allocation_ptr_to_id_[ptr] = allocation_id_;
allocation_id_++;
return ptr;

@ -364,8 +373,8 @@ void CPUProfilingAllocator::free(void* const ptr) {
auto it = allocation_ptr_to_id_.find(ptr);
if (it == allocation_ptr_to_id_.end()) {
// Either
// 1. Allocation that was made outside the validation scope is being freed here
// or
// 1. Allocation that was made outside the validation scope is being freed
// here or
// 2. Allocation that is not managed by profiling allocator is being freed.
// Example of the second type
// Tensor out;

@ -380,7 +389,8 @@ void CPUProfilingAllocator::free(void* const ptr) {
return;
}
auto id = it->second;
TORCH_CHECK(id < plan_->allocation_lifetimes.size(),
TORCH_CHECK(
id < plan_->allocation_lifetimes.size(),
"Freeing allocation that is not according to the plan.");
auto lifetime_id = plan_->allocation_lifetimes[id];
TORCH_CHECK(

@ -397,10 +407,10 @@ CPUProfilingAllocator::~CPUProfilingAllocator() {
c10::free_cpu(blob_);
}

WithProfileAllocationsGuard::WithProfileAllocationsGuard(
AllocationPlan* plan) {
WithProfileAllocationsGuard::WithProfileAllocationsGuard(AllocationPlan* plan) {
// Nesting of allocation profiling does not seem meaningful.
TORCH_CHECK(allocation_planner == nullptr,
TORCH_CHECK(
allocation_planner == nullptr,
"Nesting profiling allocations is not supported.");
planner_ = std::make_unique<AllocationPlanner>(plan);
planner_->clear();

@ -413,9 +423,11 @@ WithProfileAllocationsGuard::~WithProfileAllocationsGuard() {
}

WithValidateAllocationPlanGuard::WithValidateAllocationPlanGuard(
AllocationPlan* plan, bool* success) {
AllocationPlan* plan,
bool* success) {
// Nesting of allocation profiling does not seem meaningful.
TORCH_CHECK(allocation_planner == nullptr,
TORCH_CHECK(
allocation_planner == nullptr,
"Nesting profiling allocations is not supported.");
planner_ = std::make_unique<AllocationPlanner>(plan, true);
success_ = success;

@ -432,9 +444,11 @@ AllocationPlanner* GetThreadLocalAllocationPlanner() {
}

WithProfilingAllocatorGuard::WithProfilingAllocatorGuard(
CPUProfilingAllocator* allocator, const AllocationPlan* plan) {
CPUProfilingAllocator* allocator,
const AllocationPlan* plan) {
// Nesting of profiling allocator is not supported.
TORCH_CHECK(profiling_allocator == nullptr,
TORCH_CHECK(
profiling_allocator == nullptr,
"Nesting profiling allocators is not supported.");
profiling_allocator = allocator;
profiling_allocator->set_plan(plan);

@ -16,29 +16,30 @@ namespace c10 {
* Given a sequence of allocations in a thread, AllocationPlan records
* 1. size of each allocation
* 2. Lifetime of each allocation.
* 3. allocation offsets: Memory offset for each allocation in a single blob of memory
* 3. allocation offsets: Memory offset for each allocation in a single blob of
* memory
* 4. Total size of a blob of memory required to satisfy all the allocations.
*/
class C10_API AllocationPlan {
private:
// Records size of each allocation by their sequential allocation ids.
std::vector<uint64_t> allocation_sizes;
// This maps one allocation id (X) to another allocation id (Y).
// Allocation X is alive until allocation Y. From allocation Y onwards
// allocation X is not referenced.
// Thus Y is the id of the first allocation after X is freed.
// NB: When an allocation is recorded, along with recording its size,
// we also set the lifetime to be numeric_limits::max()
// This is to track allocations that are made during the scope of
// profiling but were not freed until after the scope ended.
// Such allocations are not managed by profiling allocator.
std::vector<uint64_t> allocation_lifetimes;
// Maps an allocation to some offset in a blob of memory.
std::vector<uint64_t> allocation_offsets;
uint64_t total_size{0};
void clear();
friend class AllocationPlanner;
friend class CPUProfilingAllocator;
private:
// Records size of each allocation by their sequential allocation ids.
std::vector<uint64_t> allocation_sizes;
// This maps one allocation id (X) to another allocation id (Y).
// Allocation X is alive until allocation Y. From allocation Y onwards
// allocation X is not referenced.
// Thus Y is the id of the first allocation after X is freed.
// NB: When an allocation is recorded, along with recording its size,
// we also set the lifetime to be numeric_limits::max()
// This is to track allocations that are made during the scope of
// profiling but were not freed until after the scope ended.
// Such allocations are not managed by profiling allocator.
std::vector<uint64_t> allocation_lifetimes;
// Maps an allocation to some offset in a blob of memory.
std::vector<uint64_t> allocation_offsets;
uint64_t total_size{0};
void clear();
friend class AllocationPlanner;
friend class CPUProfilingAllocator;
};

/*

@ -46,43 +47,45 @@ class C10_API AllocationPlan {
* used to establish lifetime of allocations.
*/
class C10_API AllocationPlanner {
private:
AllocationPlan* allocation_plan_{nullptr};
// Maps allocated ptr to its allocation id.
// This is used when freeing the memory to lookup the allocation id
// in order to establish the lifetime of a particular allocation.
ska::flat_hash_map<const void*, uint64_t> allocation_ptr_to_id_;
uint64_t allocation_id_{0};
bool validation_mode_{false};
private:
AllocationPlan* allocation_plan_{nullptr};
// Maps allocated ptr to its allocation id.
// This is used when freeing the memory to lookup the allocation id
// in order to establish the lifetime of a particular allocation.
ska::flat_hash_map<const void*, uint64_t> allocation_ptr_to_id_;
uint64_t allocation_id_{0};
bool validation_mode_{false};

bool validate_allocation(const uint64_t size, const void* ptr);
bool validate_free(const void* ptr);
public:
bool validation_success{true};
bool validate_allocation(const uint64_t size, const void* ptr);
bool validate_free(const void* ptr);

AllocationPlanner() = delete;
AllocationPlanner(AllocationPlan* plan, bool validate = false) :
allocation_plan_(plan), validation_mode_(validate) {}
void record_allocation(const uint64_t size, const void* ptr);
void record_free(const void* ptr);
void formulate_plan();
void clear();
public:
bool validation_success{true};

AllocationPlanner() = delete;
AllocationPlanner(AllocationPlan* plan, bool validate = false)
: allocation_plan_(plan), validation_mode_(validate) {}
void record_allocation(const uint64_t size, const void* ptr);
void record_free(const void* ptr);
void formulate_plan();
void clear();
};

// NOT THREAD SAFE profiling allocator.
class C10_API CPUProfilingAllocator {
private:
const AllocationPlan* plan_{nullptr};
uint64_t allocation_id_{0};
uint64_t current_size_{0};
void* blob_{nullptr};
ska::flat_hash_map<const void*, uint64_t> allocation_ptr_to_id_;
public:
~CPUProfilingAllocator();
void set_plan(const AllocationPlan* plan);
void unset_plan();
void* allocate(const size_t bytes);
void free(void* const ptr);
private:
const AllocationPlan* plan_{nullptr};
uint64_t allocation_id_{0};
uint64_t current_size_{0};
void* blob_{nullptr};
ska::flat_hash_map<const void*, uint64_t> allocation_ptr_to_id_;

public:
~CPUProfilingAllocator();
void set_plan(const AllocationPlan* plan);
void unset_plan();
void* allocate(const size_t bytes);
void free(void* const ptr);
};

/*

@ -95,11 +98,12 @@ class C10_API CPUProfilingAllocator {
* plan now contains allocation plan.
*/
class C10_API WithProfileAllocationsGuard {
public:
WithProfileAllocationsGuard(AllocationPlan* plan);
~WithProfileAllocationsGuard();
private:
std::unique_ptr<AllocationPlanner> planner_;
public:
WithProfileAllocationsGuard(AllocationPlan* plan);
~WithProfileAllocationsGuard();

private:
std::unique_ptr<AllocationPlanner> planner_;
};
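
For illustration, a minimal sketch of the profile-then-validate flow using the guards declared in this header (run_model() is a hypothetical workload):

c10::AllocationPlan plan;
{
  c10::WithProfileAllocationsGuard profile_guard(&plan);
  run_model(); // allocations and frees on this thread are recorded
}
bool plan_still_valid = false;
{
  c10::WithValidateAllocationPlanGuard validate_guard(&plan, &plan_still_valid);
  run_model(); // re-run to check the allocation pattern did not change
}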

/*

@ -115,12 +119,13 @@ class C10_API WithProfileAllocationsGuard {
* else for some inputs allocation pattern changed.
*/
class C10_API WithValidateAllocationPlanGuard {
public:
WithValidateAllocationPlanGuard(AllocationPlan* plan, bool* success);
~WithValidateAllocationPlanGuard();
private:
std::unique_ptr<AllocationPlanner> planner_;
bool* success_;
public:
WithValidateAllocationPlanGuard(AllocationPlan* plan, bool* success);
~WithValidateAllocationPlanGuard();

private:
std::unique_ptr<AllocationPlanner> planner_;
bool* success_;
};

AllocationPlanner* GetThreadLocalAllocationPlanner();

@ -138,10 +143,11 @@ AllocationPlanner* GetThreadLocalAllocationPlanner();
* }
*/
class C10_API WithProfilingAllocatorGuard {
public:
WithProfilingAllocatorGuard(
CPUProfilingAllocator* allocator, const AllocationPlan* plan);
~WithProfilingAllocatorGuard();
public:
WithProfilingAllocatorGuard(
CPUProfilingAllocator* allocator,
const AllocationPlan* plan);
~WithProfilingAllocatorGuard();
};

CPUProfilingAllocator* GetThreadLocalProfilingAllocator();

@ -5,63 +5,71 @@ namespace test_is_compile_time_function_pointer {
static_assert(!c10::is_compile_time_function_pointer<void()>::value, "");

void dummy() {}
static_assert(c10::is_compile_time_function_pointer<TORCH_FN_TYPE(dummy)>::value, "");
}
static_assert(
c10::is_compile_time_function_pointer<TORCH_FN_TYPE(dummy)>::value,
"");
} // namespace test_is_compile_time_function_pointer

namespace test_access_through_type {
void dummy() {}
using dummy_ptr = TORCH_FN_TYPE(dummy);
static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "");
static_assert(dummy_ptr::func_ptr() == &dummy, "");
static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "");
}
void dummy() {}
using dummy_ptr = TORCH_FN_TYPE(dummy);
static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "");
static_assert(dummy_ptr::func_ptr() == &dummy, "");
static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "");
} // namespace test_access_through_type

namespace test_access_through_value {
void dummy() {}
constexpr auto dummy_ptr = TORCH_FN(dummy);
static_assert(dummy_ptr.func_ptr() == &dummy, "");
static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "");
}
void dummy() {}
constexpr auto dummy_ptr = TORCH_FN(dummy);
static_assert(dummy_ptr.func_ptr() == &dummy, "");
static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "");
} // namespace test_access_through_value

namespace test_access_through_type_also_works_if_specified_as_pointer {
void dummy() {}
using dummy_ptr = TORCH_FN_TYPE(&dummy);
static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "");
static_assert(dummy_ptr::func_ptr() == &dummy, "");
static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "");
}
void dummy() {}
using dummy_ptr = TORCH_FN_TYPE(&dummy);
static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "");
static_assert(dummy_ptr::func_ptr() == &dummy, "");
static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "");
} // namespace test_access_through_type_also_works_if_specified_as_pointer

namespace test_access_through_value_also_works_if_specified_as_pointer {
void dummy() {}
constexpr auto dummy_ptr = TORCH_FN(&dummy);
static_assert(dummy_ptr.func_ptr() == &dummy, "");
static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "");
}
void dummy() {}
constexpr auto dummy_ptr = TORCH_FN(&dummy);
static_assert(dummy_ptr.func_ptr() == &dummy, "");
static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "");
} // namespace test_access_through_value_also_works_if_specified_as_pointer

namespace test_run_through_type {
int add(int a, int b) {return a + b;}
using Add = TORCH_FN_TYPE(add);
template<class Func> struct Executor {
int execute(int a, int b) {
return Func::func_ptr()(a, b);
}
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CompileTimeFunctionPointerTest, runFunctionThroughType) {
Executor<Add> executor;
EXPECT_EQ(3, executor.execute(1, 2));
}
int add(int a, int b) {
return a + b;
}
using Add = TORCH_FN_TYPE(add);
template <class Func>
struct Executor {
int execute(int a, int b) {
return Func::func_ptr()(a, b);
}
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CompileTimeFunctionPointerTest, runFunctionThroughType) {
Executor<Add> executor;
EXPECT_EQ(3, executor.execute(1, 2));
}
} // namespace test_run_through_type

namespace test_run_through_value {
int add(int a, int b) {return a + b;}
template<class Func> int execute(Func, int a, int b) {
return Func::func_ptr()(a, b);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CompileTimeFunctionPointerTest, runFunctionThroughValue) {
EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
}
int add(int a, int b) {
return a + b;
}
template <class Func>
int execute(Func, int a, int b) {
return Func::func_ptr()(a, b);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CompileTimeFunctionPointerTest, runFunctionThroughValue) {
EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
}
} // namespace test_run_through_value

@@ -9,7 +9,8 @@ using namespace c10;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, Empty) {
DispatchKeySet empty_set;
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
ASSERT_FALSE(empty_set.has(tid));
}
@@ -21,7 +22,8 @@ TEST(DispatchKeySet, Empty) {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, Singleton) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing);
@@ -37,8 +39,11 @@ TEST(DispatchKeySet, Singleton) {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, Doubleton) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
for (uint8_t j = i + 1; j < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); j++) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
for (uint8_t j = i + 1;
j < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
j++) {
ASSERT_LT(i, j);
auto tid1 = static_cast<DispatchKey>(i);
auto tid2 = static_cast<DispatchKey>(j);
@@ -46,7 +51,7 @@ TEST(DispatchKeySet, Doubleton) {
ASSERT_EQ(doub, DispatchKeySet(tid1) | DispatchKeySet(tid2));
ASSERT_TRUE(doub.has(tid1));
ASSERT_TRUE(doub.has(tid2));
ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j
ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j
}
}
}
@@ -54,7 +59,8 @@ TEST(DispatchKeySet, Doubleton) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, Full) {
DispatchKeySet full(DispatchKeySet::FULL);
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
ASSERT_TRUE(full.has(tid));
}
@@ -103,13 +109,12 @@ TEST(DispatchKeySet, IteratorFull) {
ASSERT_EQ(i, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) - 1);
}


// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(DispatchKeySet, IteratorRangeFull) {
DispatchKeySet full_set(DispatchKeySet::FULL);
uint8_t i = 0;

for (DispatchKey dispatch_key: full_set) {
for (DispatchKey dispatch_key : full_set) {
i++;
ASSERT_TRUE(dispatch_key == static_cast<DispatchKey>(i));
}
@@ -126,17 +131,20 @@ TEST(DispatchKeySet, SpecificKeys) {
static_cast<DispatchKey>(10),
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_cast<DispatchKey>(15),
});
});
std::unordered_set<DispatchKey> visited_keys;

for(DispatchKey key: keyset) {
for (DispatchKey key : keyset) {
visited_keys.insert(key);
}

ASSERT_EQ(visited_keys.size(), 3);
ASSERT_TRUE(visited_keys.find(static_cast<DispatchKey>(4)) != visited_keys.end());
ASSERT_TRUE(visited_keys.find(static_cast<DispatchKey>(10)) != visited_keys.end());
ASSERT_TRUE(visited_keys.find(static_cast<DispatchKey>(15)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(4)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(10)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(15)) != visited_keys.end());
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -145,9 +153,8 @@ TEST(DispatchKeySet, FailAtEndIterator) {
uint64_t raw_repr = full_set.raw_repr();

// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(DispatchKeySet::iterator(
&raw_repr,
static_cast<uint8_t>(DispatchKey::NumDispatchKeys) + 1
),
c10::Error);
EXPECT_THROW(
DispatchKeySet::iterator(
&raw_repr, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) + 1),
c10::Error);
}

@@ -1,7 +1,7 @@
#include <gtest/gtest.h>

#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/FakeGuardImpl.h>
#include <c10/core/impl/InlineDeviceGuard.h>

using namespace c10;
using namespace c10::impl;
@@ -55,8 +55,8 @@ TEST(InlineDeviceGuard, Constructor) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InlineDeviceGuard, ConstructorError) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_ANY_THROW(InlineDeviceGuard<FakeGuardImpl<DeviceType::CUDA>>
g(Device(DeviceType::HIP, 1)));
EXPECT_ANY_THROW(InlineDeviceGuard<FakeGuardImpl<DeviceType::CUDA>> g(
Device(DeviceType::HIP, 1)));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -110,7 +110,8 @@ TEST(InlineDeviceGuard, SetIndex) {
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i2);
}

// -- InlineOptionalDeviceGuard --------------------------------------------------
// -- InlineOptionalDeviceGuard
// --------------------------------------------------

using MaybeTestGuard = InlineOptionalDeviceGuard<TestGuardImpl>;

@@ -1,7 +1,7 @@
#include <gtest/gtest.h>

#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/core/impl/FakeGuardImpl.h>
#include <c10/core/impl/InlineStreamGuard.h>

using namespace c10;
using namespace c10::impl;
@@ -40,7 +40,6 @@ TEST(InlineStreamGuard, Constructor) {
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
}


// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InlineStreamGuard, ResetStreamSameSameDevice) {
TestGuardImpl::setDeviceIndex(0);
@@ -101,7 +100,8 @@ TEST(InlineStreamGuard, ResetStreamDifferentDevice) {
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
}

// -- OptionalInlineStreamGuard -------------------------------------------------------
// -- OptionalInlineStreamGuard
// -------------------------------------------------------

using OptionalTestGuard = InlineOptionalStreamGuard<TestGuardImpl>;

@@ -180,7 +180,8 @@ TEST(InlineOptionalStreamGuard, ResetStreamDifferentDevice) {
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
}

// -- InlineMultiStreamGuard -------------------------------------------------------
// -- InlineMultiStreamGuard
// -------------------------------------------------------

using MultiTestGuard = InlineMultiStreamGuard<TestGuardImpl>;

@@ -6,12 +6,16 @@
using namespace c10;
using namespace c10::impl;

static void checkData(const SizesAndStrides& sz, IntArrayRef sizes, IntArrayRef strides) {
EXPECT_EQ(sizes.size(), strides.size()) << "bad test case: size() of sizes and strides don't match";
static void checkData(
const SizesAndStrides& sz,
IntArrayRef sizes,
IntArrayRef strides) {
EXPECT_EQ(sizes.size(), strides.size())
<< "bad test case: size() of sizes and strides don't match";
EXPECT_EQ(sz.size(), sizes.size());

int idx = 0;
for (auto x: sizes) {
for (auto x : sizes) {
EXPECT_EQ(sz.size_at_unchecked(idx), x) << "index: " << idx;
EXPECT_EQ(sz.size_at(idx), x) << "index: " << idx;
EXPECT_EQ(sz.sizes_data()[idx], x) << "index: " << idx;
@@ -21,7 +25,7 @@ static void checkData(const SizesAndStrides& sz, IntArrayRef sizes, IntArrayRef
EXPECT_EQ(sz.sizes_arrayref(), sizes);

idx = 0;
for (auto x: strides) {
for (auto x : strides) {
EXPECT_EQ(sz.stride_at_unchecked(idx), x) << "index: " << idx;
EXPECT_EQ(sz.stride_at(idx), x) << "index: " << idx;
EXPECT_EQ(sz.strides_data()[idx], x) << "index: " << idx;

@@ -6,90 +6,92 @@ using c10::guts::to_array;

namespace {
namespace test_equals {
static_assert(array<int, 0>{{}} == array<int, 0>{{}}, "");
static_assert(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 3, 4}}, "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{1, 3, 4}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 1, 4}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 3, 1}}), "");
}
static_assert(array<int, 0>{{}} == array<int, 0>{{}}, "");
static_assert(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 3, 4}}, "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{1, 3, 4}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 1, 4}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} == array<int, 3>{{2, 3, 1}}), "");
} // namespace test_equals

namespace test_notequals {
static_assert(!(array<int, 0>{{}} != array<int, 0>{{}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 3, 4}}), "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{1, 3, 4}}, "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 1, 4}}, "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 3, 1}}, "");
}
static_assert(!(array<int, 0>{{}} != array<int, 0>{{}}), "");
static_assert(!(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 3, 4}}), "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{1, 3, 4}}, "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 1, 4}}, "");
static_assert(array<int, 3>{{2, 3, 4}} != array<int, 3>{{2, 3, 1}}, "");
} // namespace test_notequals

namespace test_lessthan {
static_assert(!(array<int, 0>{{}} < array<int, 0>{{}}), "");
static_assert(!(array<int, 1>{{2}} < array<int, 1>{{1}}), "");
static_assert(array<int, 1>{{1}} < array<int, 1>{{2}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{2, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{0, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 3, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 1, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 4}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 2}}), "");
}
static_assert(!(array<int, 0>{{}} < array<int, 0>{{}}), "");
static_assert(!(array<int, 1>{{2}} < array<int, 1>{{1}}), "");
static_assert(array<int, 1>{{1}} < array<int, 1>{{2}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{2, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{0, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 3, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 1, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 4}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} < array<int, 3>{{1, 2, 2}}), "");
} // namespace test_lessthan

namespace test_greaterthan {
static_assert(!(array<int, 0>{{}} > array<int, 0>{{}}), "");
static_assert(!(array<int, 1>{{1}} > array<int, 1>{{2}}), "");
static_assert(array<int, 1>{{2}} > array<int, 1>{{1}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{2, 2, 3}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{0, 2, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 3, 3}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 1, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 4}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 2}} > array<int, 3>{{1, 2, 3}}), "");
}
static_assert(!(array<int, 0>{{}} > array<int, 0>{{}}), "");
static_assert(!(array<int, 1>{{1}} > array<int, 1>{{2}}), "");
static_assert(array<int, 1>{{2}} > array<int, 1>{{1}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{2, 2, 3}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{0, 2, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 3, 3}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 1, 3}} > array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 4}} > array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 2}} > array<int, 3>{{1, 2, 3}}), "");
} // namespace test_greaterthan

namespace test_lessequals {
static_assert(array<int, 0>{{}} <= array<int, 0>{{}}, "");
static_assert(!(array<int, 1>{{2}} <= array<int, 1>{{1}}), "");
static_assert(array<int, 1>{{1}} <= array<int, 1>{{2}}, "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 3}}, "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{2, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{0, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 3, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 1, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 4}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 2}}), "");
}
static_assert(array<int, 0>{{}} <= array<int, 0>{{}}, "");
static_assert(!(array<int, 1>{{2}} <= array<int, 1>{{1}}), "");
static_assert(array<int, 1>{{1}} <= array<int, 1>{{2}}, "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 3}}, "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{2, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{0, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 3, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 1, 3}}), "");
static_assert(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 4}}, "");
static_assert(!(array<int, 3>{{1, 2, 3}} <= array<int, 3>{{1, 2, 2}}), "");
} // namespace test_lessequals

namespace test_greaterequals {
static_assert(array<int, 0>{{}} >= array<int, 0>{{}}, "");
static_assert(!(array<int, 1>{{1}} >= array<int, 1>{{2}}), "");
static_assert(array<int, 1>{{2}} >= array<int, 1>{{1}}, "");
static_assert(array<int, 3>{{1, 2, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(array<int, 3>{{2, 2, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{0, 2, 3}} >= array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 3, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 1, 3}} >= array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 4}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 2}} >= array<int, 3>{{1, 2, 3}}), "");
}
static_assert(array<int, 0>{{}} >= array<int, 0>{{}}, "");
static_assert(!(array<int, 1>{{1}} >= array<int, 1>{{2}}), "");
static_assert(array<int, 1>{{2}} >= array<int, 1>{{1}}, "");
static_assert(array<int, 3>{{1, 2, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(array<int, 3>{{2, 2, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{0, 2, 3}} >= array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 3, 3}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 1, 3}} >= array<int, 3>{{1, 2, 3}}), "");
static_assert(array<int, 3>{{1, 2, 4}} >= array<int, 3>{{1, 2, 3}}, "");
static_assert(!(array<int, 3>{{1, 2, 2}} >= array<int, 3>{{1, 2, 3}}), "");
} // namespace test_greaterequals

namespace test_tail {
static_assert(array < int, 2 > {{3, 4}} == tail(array < int, 3 > {{2, 3, 4}}), "");
static_assert(array < int, 0 > {{}} == tail(array < int, 1 > {{3}}), "");
}
static_assert(array<int, 2>{{3, 4}} == tail(array<int, 3>{{2, 3, 4}}), "");
static_assert(array<int, 0>{{}} == tail(array<int, 1>{{3}}), "");
} // namespace test_tail

namespace test_prepend {
static_assert(array < int, 3 > {{2, 3, 4}} == prepend(2, array < int, 2 > {{3, 4}}), "");
static_assert(array < int, 1 > {{3}} == prepend(3, array < int, 0 > {{}}), "");
}
static_assert(
array<int, 3>{{2, 3, 4}} == prepend(2, array<int, 2>{{3, 4}}),
"");
static_assert(array<int, 1>{{3}} == prepend(3, array<int, 0>{{}}), "");
} // namespace test_prepend

namespace test_to_std_array {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr int obj2[3] = {3, 5, 6};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(array < int, 3 > {{3, 5, 6}} == to_array(obj2), "");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(array < int, 3 > {{3, 5, 6}} == to_array<int, 3>({3, 5, 6}), "");
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr int obj2[3] = {3, 5, 6};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(array<int, 3>{{3, 5, 6}} == to_array(obj2), "");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(array<int, 3>{{3, 5, 6}} == to_array<int, 3>({3, 5, 6}), "");
} // namespace test_to_std_array

}
} // namespace

@@ -12,7 +12,7 @@ static_assert(min(5, 3) == 3, "");
static_assert(min(3, 3) == 3, "");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(min(3.0, 3.1) == 3.0, "");
}
} // namespace test_min

namespace test_max {
using c10::guts::max;
@@ -23,7 +23,7 @@ static_assert(max(5, 3) == 5, "");
static_assert(max(3, 3) == 3, "");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(max(3.0, 3.1) == 3.1, "");
}
} // namespace test_max

namespace test_if_constexpr {

@@ -31,129 +31,125 @@ using c10::guts::if_constexpr;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, whenIsTrue_thenReturnsTrueCase) {
EXPECT_EQ(4, if_constexpr<true>([](auto) { return 4; }, [](auto) { return 5; }));
EXPECT_EQ(
4, if_constexpr<true>([](auto) { return 4; }, [](auto) { return 5; }));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, whenIsFalse_thenReturnsFalseCase) {
EXPECT_EQ(5, if_constexpr<false>([](auto) { return 4; }, [](auto) { return 5; }));
EXPECT_EQ(
5, if_constexpr<false>([](auto) { return 4; }, [](auto) { return 5; }));
}

struct MovableOnly final {
int value;
int value;

MovableOnly(int v) : value(v) {}
MovableOnly(MovableOnly&&) = default;
MovableOnly(const MovableOnly&) = delete;
MovableOnly& operator=(MovableOnly&&) = default;
MovableOnly& operator=(const MovableOnly&) = delete;
MovableOnly(int v) : value(v) {}
MovableOnly(MovableOnly&&) = default;
MovableOnly(const MovableOnly&) = delete;
MovableOnly& operator=(MovableOnly&&) = default;
MovableOnly& operator=(const MovableOnly&) = delete;
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, worksWithMovableOnlyTypes_withIdentityArg) {
EXPECT_EQ(
4,
if_constexpr<true>([](auto) { return MovableOnly(4); }, [](auto) { return MovableOnly(5); })
.value);
EXPECT_EQ(
5,
if_constexpr<false>([](auto) { return MovableOnly(4); }, [](auto) { return MovableOnly(5); })
.value);
EXPECT_EQ(
4,
if_constexpr<true>(
[](auto) { return MovableOnly(4); },
[](auto) { return MovableOnly(5); })
.value);
EXPECT_EQ(
5,
if_constexpr<false>(
[](auto) { return MovableOnly(4); },
[](auto) { return MovableOnly(5); })
.value);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, worksWithMovableOnlyTypes_withoutIdentityArg) {
EXPECT_EQ(
4,
if_constexpr<true>([] { return MovableOnly(4); }, [] { return MovableOnly(5); })
.value);
EXPECT_EQ(
5,
if_constexpr<false>([] { return MovableOnly(4); }, [] { return MovableOnly(5); })
.value);
EXPECT_EQ(
4,
if_constexpr<true>(
[] { return MovableOnly(4); }, [] { return MovableOnly(5); })
.value);
EXPECT_EQ(
5,
if_constexpr<false>(
[] { return MovableOnly(4); }, [] { return MovableOnly(5); })
.value);
}

struct MyClass1 {
int value;
int value;
};

struct MyClass2 {
int val;
int val;
};

template<class T>
template <class T>
int func(T t) {
return if_constexpr<std::is_same<T, MyClass1>::value>(
[&](auto _) { return _(t).value; }, // this code is invalid for T == MyClass2
[&](auto _) { return _(t).val; } // this code is invalid for T == MyClass1
);
return if_constexpr<std::is_same<T, MyClass1>::value>(
[&](auto _) {
return _(t).value;
}, // this code is invalid for T == MyClass2
[&](auto _) { return _(t).val; } // this code is invalid for T == MyClass1
);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, otherCaseCanHaveInvalidCode) {
EXPECT_EQ(8, func(MyClass1{/* .value = */ 8}));
EXPECT_EQ(4, func(MyClass2{/* .val = */ 4}));
EXPECT_EQ(8, func(MyClass1{/* .value = */ 8}));
EXPECT_EQ(4, func(MyClass2{/* .val = */ 4}));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, worksWithoutElseCase_withIdentityArg) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
int var = 5;
if_constexpr<false>(
[&](auto) { var = 3; }
);
EXPECT_EQ(5, var);
if_constexpr<true>(
[&](auto) { var = 3; }
);
EXPECT_EQ(3, var);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
int var = 5;
if_constexpr<false>([&](auto) { var = 3; });
EXPECT_EQ(5, var);
if_constexpr<true>([&](auto) { var = 3; });
EXPECT_EQ(3, var);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, worksWithoutElseCase_withoutIdentityArg) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
int var = 5;
if_constexpr<false>(
[&] { var = 3; }
);
EXPECT_EQ(5, var);
if_constexpr<true>(
[&] { var = 3; }
);
EXPECT_EQ(3, var);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
int var = 5;
if_constexpr<false>([&] { var = 3; });
EXPECT_EQ(5, var);
if_constexpr<true>([&] { var = 3; });
EXPECT_EQ(3, var);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, returnTypeCanDiffer_withIdentityArg) {
auto a_string = if_constexpr<false>(
[&](auto) -> int64_t { return 3; },
[&](auto) -> std::string { return "3"; }
);
static_assert(std::is_same<std::string, decltype(a_string)>::value, "");
auto a_string = if_constexpr<false>(
[&](auto) -> int64_t { return 3; },
[&](auto) -> std::string { return "3"; });
static_assert(std::is_same<std::string, decltype(a_string)>::value, "");

// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto an_int = if_constexpr<true>(
[&](auto) -> int64_t { return 3; },
[&](auto) -> std::string { return "3"; }
);
static_assert(std::is_same<int64_t, decltype(an_int)>::value, "");
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto an_int = if_constexpr<true>(
[&](auto) -> int64_t { return 3; },
[&](auto) -> std::string { return "3"; });
static_assert(std::is_same<int64_t, decltype(an_int)>::value, "");
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(if_constexpr, returnTypeCanDiffer_withoutIdentityArg) {
auto a_string = if_constexpr<false>(
[&] () -> int64_t { return 3; },
[&] () -> std::string { return "3"; }
);
static_assert(std::is_same<std::string, decltype(a_string)>::value, "");
auto a_string = if_constexpr<false>(
[&]() -> int64_t { return 3; }, [&]() -> std::string { return "3"; });
static_assert(std::is_same<std::string, decltype(a_string)>::value, "");

// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto an_int = if_constexpr<true>(
[&] () -> int64_t { return 3; },
[&] () -> std::string { return "3"; }
);
static_assert(std::is_same<int64_t, decltype(an_int)>::value, "");
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto an_int = if_constexpr<true>(
[&]() -> int64_t { return 3; }, [&]() -> std::string { return "3"; });
static_assert(std::is_same<int64_t, decltype(an_int)>::value, "");
}

}
}
} // namespace test_if_constexpr
} // namespace

@@ -10,259 +10,257 @@ TEST(LeftRightTest, givenInt_whenWritingAndReading_thenChangesArePresent) {
LeftRight<int> obj;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([] (int& obj) {obj = 5;});
int read = obj.read([] (const int& obj) {return obj;});
obj.write([](int& obj) { obj = 5; });
int read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(5, read);

// check changes are also present in background copy
obj.write([] (int&) {}); // this switches to the background copy
read = obj.read([] (const int& obj) {return obj;});
obj.write([](int&) {}); // this switches to the background copy
read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(5, read);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, givenVector_whenWritingAndReading_thenChangesArePresent) {
LeftRight<vector<int>> obj;
LeftRight<vector<int>> obj;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([] (vector<int>& obj) {obj.push_back(5);});
vector<int> read = obj.read([] (const vector<int>& obj) {return obj;});
EXPECT_EQ((vector<int>{5}), read);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](vector<int>& obj) { obj.push_back(5); });
vector<int> read = obj.read([](const vector<int>& obj) { return obj; });
EXPECT_EQ((vector<int>{5}), read);

// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([] (vector<int>& obj) {obj.push_back(6);});
read = obj.read([] (const vector<int>& obj) {return obj;});
EXPECT_EQ((vector<int>{5, 6}), read);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](vector<int>& obj) { obj.push_back(6); });
read = obj.read([](const vector<int>& obj) { return obj; });
EXPECT_EQ((vector<int>{5, 6}), read);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, givenVector_whenWritingReturnsValue_thenValueIsReturned) {
LeftRight<vector<int>> obj;
LeftRight<vector<int>> obj;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto a = obj.write([] (vector<int>&) -> int {return 5;});
static_assert(std::is_same<int, decltype(a)>::value, "");
EXPECT_EQ(5, a);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto a = obj.write([](vector<int>&) -> int { return 5; });
static_assert(std::is_same<int, decltype(a)>::value, "");
EXPECT_EQ(5, a);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, readsCanBeConcurrent) {
LeftRight<int> obj;
std::atomic<int> num_running_readers{0};
LeftRight<int> obj;
std::atomic<int> num_running_readers{0};

std::thread reader1([&] () {
obj.read([&] (const int&) {
++num_running_readers;
while(num_running_readers.load() < 2) {}
});
std::thread reader1([&]() {
obj.read([&](const int&) {
++num_running_readers;
while (num_running_readers.load() < 2) {
}
});
});

std::thread reader2([&] () {
obj.read([&] (const int&) {
++num_running_readers;
while(num_running_readers.load() < 2) {}
});
std::thread reader2([&]() {
obj.read([&](const int&) {
++num_running_readers;
while (num_running_readers.load() < 2) {
}
});
});

// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
reader1.join();
reader2.join();
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
reader1.join();
reader2.join();
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, writesCanBeConcurrentWithReads_readThenWrite) {
LeftRight<int> obj;
std::atomic<bool> reader_running{false};
std::atomic<bool> writer_running{false};
LeftRight<int> obj;
std::atomic<bool> reader_running{false};
std::atomic<bool> writer_running{false};

std::thread reader([&] () {
obj.read([&] (const int&) {
reader_running = true;
while(!writer_running.load()) {}
});
std::thread reader([&]() {
obj.read([&](const int&) {
reader_running = true;
while (!writer_running.load()) {
}
});
});

std::thread writer([&] () {
// run read first, write second
while (!reader_running.load()) {}
std::thread writer([&]() {
// run read first, write second
while (!reader_running.load()) {
}

obj.write([&] (int&) {
writer_running = true;
});
});
obj.write([&](int&) { writer_running = true; });
});

// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
reader.join();
writer.join();
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
reader.join();
writer.join();
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, writesCanBeConcurrentWithReads_writeThenRead) {
LeftRight<int> obj;
std::atomic<bool> writer_running{false};
std::atomic<bool> reader_running{false};
LeftRight<int> obj;
std::atomic<bool> writer_running{false};
std::atomic<bool> reader_running{false};

std::thread writer([&] () {
obj.read([&] (const int&) {
writer_running = true;
while(!reader_running.load()) {}
});
std::thread writer([&]() {
obj.read([&](const int&) {
writer_running = true;
while (!reader_running.load()) {
}
});
});

std::thread reader([&] () {
// run write first, read second
while (!writer_running.load()) {}
std::thread reader([&]() {
// run write first, read second
while (!writer_running.load()) {
}

obj.read([&] (const int&) {
reader_running = true;
});
});
obj.read([&](const int&) { reader_running = true; });
});

// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
writer.join();
reader.join();
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
writer.join();
reader.join();
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, writesCannotBeConcurrentWithWrites) {
LeftRight<int> obj;
std::atomic<bool> first_writer_started{false};
std::atomic<bool> first_writer_finished{false};
LeftRight<int> obj;
std::atomic<bool> first_writer_started{false};
std::atomic<bool> first_writer_finished{false};

std::thread writer1([&] () {
obj.write([&] (int&) {
first_writer_started = true;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::this_thread::sleep_for(std::chrono::milliseconds(50));
first_writer_finished = true;
});
std::thread writer1([&]() {
obj.write([&](int&) {
first_writer_started = true;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
std::this_thread::sleep_for(std::chrono::milliseconds(50));
first_writer_finished = true;
});
});

std::thread writer2([&] () {
// make sure the other writer runs first
while (!first_writer_started.load()) {}
std::thread writer2([&]() {
// make sure the other writer runs first
while (!first_writer_started.load()) {
}

obj.write([&] (int&) {
// expect the other writer finished before this one starts
EXPECT_TRUE(first_writer_finished.load());
});
obj.write([&](int&) {
// expect the other writer finished before this one starts
EXPECT_TRUE(first_writer_finished.load());
});
});

writer1.join();
writer2.join();
writer1.join();
writer2.join();
}

namespace {
class MyException : public std::exception {};
}
} // namespace

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, whenReadThrowsException_thenThrowsThrough) {
LeftRight<int> obj;
LeftRight<int> obj;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.read([](const int&) {throw MyException();}),
MyException
);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(obj.read([](const int&) { throw MyException(); }), MyException);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, whenWriteThrowsException_thenThrowsThrough) {
LeftRight<int> obj;
LeftRight<int> obj;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([](int&) {throw MyException();}),
MyException
);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(obj.write([](int&) { throw MyException(); }), MyException);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, givenInt_whenWriteThrowsExceptionOnFirstCall_thenResetsToOldState) {
LeftRight<int> obj;
TEST(
LeftRightTest,
givenInt_whenWriteThrowsExceptionOnFirstCall_thenResetsToOldState) {
LeftRight<int> obj;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](int& obj) {obj = 5;});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](int& obj) { obj = 5; });

// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([](int& obj) {
obj = 6;
throw MyException();
}),
MyException
);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([](int& obj) {
obj = 6;
throw MyException();
}),
MyException);

// check reading it returns old value
int read = obj.read([] (const int& obj) {return obj;});
EXPECT_EQ(5, read);
// check reading it returns old value
int read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(5, read);

// check changes are also present in background copy
obj.write([] (int&) {}); // this switches to the background copy
read = obj.read([] (const int& obj) {return obj;});
EXPECT_EQ(5, read);
// check changes are also present in background copy
obj.write([](int&) {}); // this switches to the background copy
read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(5, read);
}

// note: each write is executed twice, on the foreground and background copy.
// We need to test a thrown exception in either call is handled correctly.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, givenInt_whenWriteThrowsExceptionOnSecondCall_thenKeepsNewState) {
LeftRight<int> obj;
TEST(
LeftRightTest,
givenInt_whenWriteThrowsExceptionOnSecondCall_thenKeepsNewState) {
LeftRight<int> obj;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](int& obj) {obj = 5;});
bool write_called = false;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](int& obj) { obj = 5; });
bool write_called = false;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([&](int& obj) {
obj = 6;
if (write_called) {
// this is the second time the write callback is executed
throw MyException();
} else {
write_called = true;
}
}),
MyException
);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([&](int& obj) {
obj = 6;
if (write_called) {
// this is the second time the write callback is executed
throw MyException();
} else {
write_called = true;
}
}),
MyException);

// check reading it returns new value
int read = obj.read([] (const int& obj) {return obj;});
EXPECT_EQ(6, read);
// check reading it returns new value
int read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(6, read);

// check changes are also present in background copy
obj.write([] (int&) {}); // this switches to the background copy
read = obj.read([] (const int& obj) {return obj;});
EXPECT_EQ(6, read);
// check changes are also present in background copy
obj.write([](int&) {}); // this switches to the background copy
read = obj.read([](const int& obj) { return obj; });
EXPECT_EQ(6, read);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LeftRightTest, givenVector_whenWriteThrowsException_thenResetsToOldState) {
LeftRight<vector<int>> obj;
LeftRight<vector<int>> obj;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](vector<int>& obj) {obj.push_back(5);});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
obj.write([](vector<int>& obj) { obj.push_back(5); });

// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([](vector<int>& obj) {
obj.push_back(6);
throw MyException();
}),
MyException
);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
obj.write([](vector<int>& obj) {
obj.push_back(6);
throw MyException();
}),
MyException);

// check reading it returns old value
vector<int> read = obj.read([] (const vector<int>& obj) {return obj;});
EXPECT_EQ((vector<int>{5}), read);
// check reading it returns old value
vector<int> read = obj.read([](const vector<int>& obj) { return obj; });
EXPECT_EQ((vector<int>{5}), read);

// check changes are also present in background copy
obj.write([] (vector<int>&) {}); // this switches to the background copy
read = obj.read([] (const vector<int>& obj) {return obj;});
EXPECT_EQ((vector<int>{5}), read);
// check changes are also present in background copy
obj.write([](vector<int>&) {}); // this switches to the background copy
read = obj.read([](const vector<int>& obj) { return obj; });
EXPECT_EQ((vector<int>{5}), read);
}

File diff suppressed because it is too large

@@ -61,12 +61,10 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, TopLevelName) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Dummy>().find("Dummy")
);
}
EXPECT_NE(
string_view::npos, get_fully_qualified_type_name<Dummy>().find("Dummy"));
}
} // namespace test_top_level_name

namespace test_nested_name {
struct Dummy final {};
@@ -79,10 +77,9 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, NestedName) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Dummy>().find("test_nested_name::Dummy")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Dummy>().find("test_nested_name::Dummy"));
}
} // namespace test_nested_name

@@ -105,16 +102,14 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, TypeTemplateParameter) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Outer<Inner>>().find(
"test_type_template_parameter::Outer")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Outer<Inner>>().find(
"test_type_template_parameter::Inner")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Outer<Inner>>().find(
"test_type_template_parameter::Outer"));
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Outer<Inner>>().find(
"test_type_template_parameter::Inner"));
}
} // namespace test_type_template_parameter

@@ -131,10 +126,9 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, NonTypeTemplateParameter) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Class<38474355>>().find("38474355")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Class<38474355>>().find("38474355"));
}
} // namespace test_nontype_template_parameter

@@ -164,21 +158,18 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, TypeComputationsAreResolved) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<typename Type<int>::type>().find("int")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<typename Type<int>::type>().find("*")
);
// but with remove_pointer applied, there is no '*' in the type name anymore
EXPECT_EQ(
string_view::npos,
get_fully_qualified_type_name<
typename std::remove_pointer<typename Type<int>::type>::type>()
.find("*")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<typename Type<int>::type>().find("int"));
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<typename Type<int>::type>().find("*"));
// but with remove_pointer applied, there is no '*' in the type name anymore
EXPECT_EQ(
string_view::npos,
get_fully_qualified_type_name<
typename std::remove_pointer<typename Type<int>::type>::type>()
.find("*"));
}

struct Functor final {
@@ -193,11 +184,10 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, FunctionTypeComputationsAreResolved) {
EXPECT_EQ(
get_fully_qualified_type_name<std::string(int64_t, const Type<int>&)>(),
get_fully_qualified_type_name<
typename c10::guts::infer_function_traits_t<Functor>::func_type>()
);
EXPECT_EQ(
get_fully_qualified_type_name<std::string(int64_t, const Type<int>&)>(),
get_fully_qualified_type_name<
typename c10::guts::infer_function_traits_t<Functor>::func_type>());
}
} // namespace test_type_computations_are_resolved

@@ -218,16 +208,14 @@ static_assert(
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeIndex, FunctionArgumentsAndReturns) {
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Dummy(int)>().find(
"test_function_arguments_and_returns::Dummy")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<void(Dummy)>().find(
"test_function_arguments_and_returns::Dummy")
);
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<Dummy(int)>().find(
"test_function_arguments_and_returns::Dummy"));
EXPECT_NE(
string_view::npos,
get_fully_qualified_type_name<void(Dummy)>().find(
"test_function_arguments_and_returns::Dummy"));
}
} // namespace test_function_arguments_and_returns
} // namespace

@ -5,204 +5,382 @@
|
|||
using namespace c10::guts::typelist;
|
||||
|
||||
namespace test_size {
|
||||
class MyClass {};
|
||||
static_assert(0 == size<typelist<>>::value, "");
|
||||
static_assert(1 == size<typelist<int>>::value, "");
|
||||
static_assert(3 == size<typelist<int, float&, const MyClass&&>>::value, "");
|
||||
}
|
||||
class MyClass {};
|
||||
static_assert(0 == size<typelist<>>::value, "");
|
||||
static_assert(1 == size<typelist<int>>::value, "");
|
||||
static_assert(3 == size<typelist<int, float&, const MyClass&&>>::value, "");
|
||||
} // namespace test_size
|
||||
|
||||
namespace test_from_tuple {
|
||||
class MyClass {};
|
||||
static_assert(std::is_same<typelist<int, float&, const MyClass&&>, from_tuple_t<std::tuple<int, float&, const MyClass&&>>>::value, "");
|
||||
static_assert(std::is_same<typelist<>, from_tuple_t<std::tuple<>>>::value, "");
|
||||
}
|
||||
class MyClass {};
|
||||
static_assert(
|
||||
std::is_same<
|
||||
typelist<int, float&, const MyClass&&>,
|
||||
from_tuple_t<std::tuple<int, float&, const MyClass&&>>>::value,
|
||||
"");
|
||||
static_assert(std::is_same<typelist<>, from_tuple_t<std::tuple<>>>::value, "");
|
||||
} // namespace test_from_tuple
|
||||
|
||||
namespace test_to_tuple {
|
||||
class MyClass {};
|
||||
static_assert(std::is_same<std::tuple<int, float&, const MyClass&&>, to_tuple_t<typelist<int, float&, const MyClass&&>>>::value, "");
|
||||
static_assert(std::is_same<std::tuple<>, to_tuple_t<typelist<>>>::value, "");
|
||||
}
|
||||
class MyClass {};
|
||||
static_assert(
|
||||
std::is_same<
|
||||
std::tuple<int, float&, const MyClass&&>,
|
||||
to_tuple_t<typelist<int, float&, const MyClass&&>>>::value,
|
||||
"");
|
||||
static_assert(std::is_same<std::tuple<>, to_tuple_t<typelist<>>>::value, "");
|
||||
} // namespace test_to_tuple
|
||||
|
||||
namespace test_concat {
|
||||
class MyClass {};
|
||||
static_assert(std::is_same<typelist<>, concat_t<>>::value, "");
|
||||
static_assert(std::is_same<typelist<>, concat_t<typelist<>>>::value, "");
|
||||
static_assert(std::is_same<typelist<>, concat_t<typelist<>, typelist<>>>::value, "");
|
||||
static_assert(std::is_same<typelist<int>, concat_t<typelist<int>>>::value, "");
|
||||
static_assert(std::is_same<typelist<int>, concat_t<typelist<int>, typelist<>>>::value, "");
|
||||
static_assert(std::is_same<typelist<int>, concat_t<typelist<>, typelist<int>>>::value, "");
|
||||
static_assert(std::is_same<typelist<int>, concat_t<typelist<>, typelist<int>, typelist<>>>::value, "");
|
||||
static_assert(std::is_same<typelist<int, float&>, concat_t<typelist<int>, typelist<float&>>>::value, "");
|
||||
static_assert(std::is_same<typelist<int, float&>, concat_t<typelist<>, typelist<int, float&>, typelist<>>>::value, "");
|
||||
static_assert(std::is_same<typelist<int, float&, const MyClass&&>, concat_t<typelist<>, typelist<int, float&>, typelist<const MyClass&&>>>::value, "");
|
||||
}
|
||||
class MyClass {};
|
||||
static_assert(std::is_same<typelist<>, concat_t<>>::value, "");
|
||||
static_assert(std::is_same<typelist<>, concat_t<typelist<>>>::value, "");
|
||||
static_assert(
|
||||
std::is_same<typelist<>, concat_t<typelist<>, typelist<>>>::value,
|
||||
"");
|
||||
static_assert(std::is_same<typelist<int>, concat_t<typelist<int>>>::value, "");
|
||||
static_assert(
|
||||
std::is_same<typelist<int>, concat_t<typelist<int>, typelist<>>>::value,
|
||||
"");
|
||||
static_assert(
|
||||
std::is_same<typelist<int>, concat_t<typelist<>, typelist<int>>>::value,
|
||||
"");
|
||||
static_assert(
|
||||
std::is_same<
|
||||
typelist<int>,
|
||||
concat_t<typelist<>, typelist<int>, typelist<>>>::value,
|
||||
"");
|
||||
static_assert(
|
||||
std::is_same<
|
||||
typelist<int, float&>,
|
||||
concat_t<typelist<int>, typelist<float&>>>::value,
|
||||
"");
|
||||
static_assert(
|
||||
std::is_same<
|
||||
typelist<int, float&>,
|
||||
concat_t<typelist<>, typelist<int, float&>, typelist<>>>::value,
|
||||
"");
|
||||
static_assert(
|
||||
std::is_same<
|
||||
typelist<int, float&, const MyClass&&>,
|
||||
concat_t<
|
||||
typelist<>,
|
||||
typelist<int, float&>,
|
||||
typelist<const MyClass&&>>>::value,
|
||||
"");
|
||||
} // namespace test_concat
|
||||
|
||||
namespace test_filter {
|
||||
class MyClass {};
|
||||
static_assert(std::is_same<typelist<>, filter_t<std::is_reference, typelist<>>>::value, "");
|
||||
static_assert(std::is_same<typelist<>, filter_t<std::is_reference, typelist<int, float, double, MyClass>>>::value, "");
|
||||
static_assert(std::is_same<typelist<float&, const MyClass&&>, filter_t<std::is_reference, typelist<int, float&, double, const MyClass&&>>>::value, "");
|
||||
}
|
||||
class MyClass {};
|
||||
static_assert(
|
||||
std::is_same<typelist<>, filter_t<std::is_reference, typelist<>>>::value,
|
||||
"");
|
||||
static_assert(
|
||||
std::is_same<
|
||||
typelist<>,
|
||||
filter_t<std::is_reference, typelist<int, float, double, MyClass>>>::
|
||||
value,
|
||||
"");
|
||||
static_assert(
|
||||
std::is_same<
|
||||
typelist<float&, const MyClass&&>,
|
||||
filter_t<
|
||||
std::is_reference,
|
||||
typelist<int, float&, double, const MyClass&&>>>::value,
|
||||
"");
|
||||
} // namespace test_filter
|
||||
|
||||
namespace test_count_if {
|
||||
class MyClass final {};
|
||||
static_assert(count_if<std::is_reference, typelist<int, bool&, const MyClass&&, float, double>>::value == 2, "");
|
||||
static_assert(count_if<std::is_reference, typelist<int, bool>>::value == 0, "");
|
||||
static_assert(count_if<std::is_reference, typelist<>>::value == 0, "");
|
||||
}
|
||||
class MyClass final {};
|
||||
static_assert(
|
||||
count_if<
|
||||
std::is_reference,
|
||||
typelist<int, bool&, const MyClass&&, float, double>>::value == 2,
|
||||
"");
|
||||
static_assert(count_if<std::is_reference, typelist<int, bool>>::value == 0, "");
|
||||
static_assert(count_if<std::is_reference, typelist<>>::value == 0, "");
|
||||
} // namespace test_count_if
|
||||
|
||||
namespace test_true_for_each_type {
|
||||
template<class> class Test;
|
||||
class MyClass {};
|
||||
static_assert(all<std::is_reference, typelist<int&, const float&&, const MyClass&>>::value, "");
|
||||
static_assert(!all<std::is_reference, typelist<int&, const float, const MyClass&>>::value, "");
|
||||
static_assert(all<std::is_reference, typelist<>>::value, "");
|
||||
}
|
||||
template <class>
|
||||
class Test;
|
||||
class MyClass {};
|
||||
static_assert(
|
||||
all<std::is_reference,
|
||||
typelist<int&, const float&&, const MyClass&>>::value,
|
||||
"");
|
||||
static_assert(
|
||||
!all<std::is_reference, typelist<int&, const float, const MyClass&>>::value,
|
||||
"");
|
||||
static_assert(all<std::is_reference, typelist<>>::value, "");
|
||||
} // namespace test_true_for_each_type
|
||||
|
||||
namespace test_true_for_any_type {
|
||||
template<class> class Test;
|
||||
class MyClass {};
|
||||
static_assert(true_for_any_type<std::is_reference, typelist<int&, const float&&, const MyClass&>>::value, "");
|
||||
static_assert(true_for_any_type<std::is_reference, typelist<int&, const float, const MyClass&>>::value, "");
|
||||
static_assert(!true_for_any_type<std::is_reference, typelist<int, const float, const MyClass>>::value, "");
|
||||
static_assert(!true_for_any_type<std::is_reference, typelist<>>::value, "");
|
||||
}
|
||||
template <class>
|
||||
class Test;
|
||||
class MyClass {};
|
||||
static_assert(
|
||||
true_for_any_type<
|
||||
std::is_reference,
|
||||
typelist<int&, const float&&, const MyClass&>>::value,
|
||||
"");
|
||||
static_assert(
|
||||
true_for_any_type<
|
||||
std::is_reference,
|
||||
typelist<int&, const float, const MyClass&>>::value,
|
||||
"");
|
||||
static_assert(
|
||||
!true_for_any_type<
|
||||
std::is_reference,
|
||||
typelist<int, const float, const MyClass>>::value,
|
||||
"");
|
||||
static_assert(!true_for_any_type<std::is_reference, typelist<>>::value, "");
|
||||
} // namespace test_true_for_any_type
|
||||
|
||||
namespace test_map {
|
||||
class MyClass {};
|
||||
static_assert(std::is_same<typelist<>, map_t<std::add_lvalue_reference_t, typelist<>>>::value, "");
|
||||
static_assert(std::is_same<typelist<int&>, map_t<std::add_lvalue_reference_t, typelist<int>>>::value, "");
|
||||
static_assert(std::is_same<typelist<int&, double&, const MyClass&>, map_t<std::add_lvalue_reference_t, typelist<int, double, const MyClass>>>::value, "");
|
||||
}
|
||||
class MyClass {};
|
||||
static_assert(
|
||||
std::is_same<typelist<>, map_t<std::add_lvalue_reference_t, typelist<>>>::
|
||||
value,
|
||||
"");
|
||||
static_assert(
|
||||
std::is_same<
|
||||
typelist<int&>,
|
||||
map_t<std::add_lvalue_reference_t, typelist<int>>>::value,
|
||||
"");
|
||||
static_assert(
|
||||
std::is_same<
|
||||
typelist<int&, double&, const MyClass&>,
|
||||
map_t<
|
||||
std::add_lvalue_reference_t,
|
||||
typelist<int, double, const MyClass>>>::value,
|
||||
"");
|
||||
} // namespace test_map
|
||||
|
||||
namespace test_head {
|
||||
class MyClass {};
|
||||
static_assert(std::is_same<int, head_t<typelist<int, double>>>::value, "");
|
||||
static_assert(std::is_same<const MyClass&, head_t<typelist<const MyClass&, double>>>::value, "");
|
||||
static_assert(std::is_same<MyClass&&, head_t<typelist<MyClass&&, MyClass>>>::value, "");
|
||||
static_assert(std::is_same<bool, head_t<typelist<bool>>>::value, "");
|
||||
}
|
||||
class MyClass {};
|
||||
static_assert(std::is_same<int, head_t<typelist<int, double>>>::value, "");
|
||||
static_assert(
|
||||
std::is_same<const MyClass&, head_t<typelist<const MyClass&, double>>>::
|
||||
value,
|
||||
"");
|
||||
static_assert(
|
||||
std::is_same<MyClass&&, head_t<typelist<MyClass&&, MyClass>>>::value,
|
||||
"");
|
||||
static_assert(std::is_same<bool, head_t<typelist<bool>>>::value, "");
|
||||
} // namespace test_head
|
||||
|
||||
namespace test_head_with_default {
class MyClass {};
static_assert(std::is_same<int, head_with_default_t<bool, typelist<int, double>>>::value, "");
static_assert(std::is_same<const MyClass&, head_with_default_t<bool, typelist<const MyClass&, double>>>::value, "");
static_assert(std::is_same<MyClass&&, head_with_default_t<bool, typelist<MyClass&&, MyClass>>>::value, "");
static_assert(std::is_same<int, head_with_default_t<bool, typelist<int>>>::value, "");
static_assert(std::is_same<bool, head_with_default_t<bool, typelist<>>>::value, "");
}
class MyClass {};
static_assert(
    std::is_same<int, head_with_default_t<bool, typelist<int, double>>>::value,
    "");
static_assert(
    std::is_same<
        const MyClass&,
        head_with_default_t<bool, typelist<const MyClass&, double>>>::value,
    "");
static_assert(
    std::is_same<
        MyClass&&,
        head_with_default_t<bool, typelist<MyClass&&, MyClass>>>::value,
    "");
static_assert(
    std::is_same<int, head_with_default_t<bool, typelist<int>>>::value,
    "");
static_assert(
    std::is_same<bool, head_with_default_t<bool, typelist<>>>::value,
    "");
} // namespace test_head_with_default

namespace test_reverse {
class MyClass {};
static_assert(std::is_same<
  typelist<int, double, MyClass*, const MyClass&&>,
  reverse_t<typelist<const MyClass&&, MyClass*, double, int>>
>::value, "");
static_assert(std::is_same<
  typelist<>,
  reverse_t<typelist<>>
>::value, "");
}
class MyClass {};
static_assert(
    std::is_same<
        typelist<int, double, MyClass*, const MyClass&&>,
        reverse_t<typelist<const MyClass&&, MyClass*, double, int>>>::value,
    "");
static_assert(std::is_same<typelist<>, reverse_t<typelist<>>>::value, "");
} // namespace test_reverse

namespace test_map_types_to_values {
struct map_to_size {
  template<class T> constexpr size_t operator()(T) const {return sizeof(typename T::type);}
};
struct map_to_size {
  template <class T>
  constexpr size_t operator()(T) const {
    return sizeof(typename T::type);
  }
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_sametype) {
  auto sizes =
      map_types_to_values<typelist<int64_t, bool, uint32_t>>(map_to_size());
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::tuple<size_t, size_t, size_t> expected(8, 1, 4);
  static_assert(std::is_same<decltype(expected), decltype(sizes)>::value, "");
  EXPECT_EQ(expected, sizes);
}

struct map_make_shared {
  template<class T> std::shared_ptr<typename T::type> operator()(T) {
    return std::make_shared<typename T::type>();
  }
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_differenttypes) {
  auto shared_ptrs =
      map_types_to_values<typelist<int, double>>(map_make_shared());
  static_assert(std::is_same<std::tuple<std::shared_ptr<int>, std::shared_ptr<double>>, decltype(shared_ptrs)>::value, "");
}

struct Class1 {static int func() {return 3;}};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
struct Class2 {static double func() {return 2.0;}};

struct mapper_call_func {
  template<class T> decltype(auto) operator()(T) { return T::type::func(); }
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_members) {
  auto result =
      map_types_to_values<typelist<Class1, Class2>>(mapper_call_func());
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::tuple<int, double> expected(3, 2.0);
  static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
  EXPECT_EQ(expected, result);
}

struct mapper_call_nonexistent_function {
  template<class T> decltype(auto) operator()(T) { return T::type::this_doesnt_exist(); }
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_empty) {
  auto result =
      map_types_to_values<typelist<>>(mapper_call_nonexistent_function());
  std::tuple<> expected;
  static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
  EXPECT_EQ(expected, result);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_sametype) {
  auto sizes =
      map_types_to_values<typelist<int64_t, bool, uint32_t>>(map_to_size());
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::tuple<size_t, size_t, size_t> expected(8, 1, 4);
  static_assert(std::is_same<decltype(expected), decltype(sizes)>::value, "");
  EXPECT_EQ(expected, sizes);
}

struct map_make_shared {
  template <class T>
  std::shared_ptr<typename T::type> operator()(T) {
    return std::make_shared<typename T::type>();
  }
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_differenttypes) {
  auto shared_ptrs =
      map_types_to_values<typelist<int, double>>(map_make_shared());
  static_assert(
      std::is_same<
          std::tuple<std::shared_ptr<int>, std::shared_ptr<double>>,
          decltype(shared_ptrs)>::value,
      "");
}

struct Class1 {
  static int func() {
    return 3;
  }
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
struct Class2 {
  static double func() {
    return 2.0;
  }
};

struct mapper_call_func {
  template <class T>
  decltype(auto) operator()(T) {
    return T::type::func();
  }
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_members) {
  auto result =
      map_types_to_values<typelist<Class1, Class2>>(mapper_call_func());
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::tuple<int, double> expected(3, 2.0);
  static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
  EXPECT_EQ(expected, result);
}

struct mapper_call_nonexistent_function {
  template <class T>
  decltype(auto) operator()(T) {
    return T::type::this_doesnt_exist();
  }
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TypeListTest, MapTypesToValues_empty) {
  auto result =
      map_types_to_values<typelist<>>(mapper_call_nonexistent_function());
  std::tuple<> expected;
  static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
  EXPECT_EQ(expected, result);
}
} // namespace test_map_types_to_values

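Aside: the MapTypesToValues tests above pin down the calling convention of map_types_to_values: the functor receives each type wrapped in an empty object exposing `typename T::type`, and the per-type results come back as a std::tuple. The following self-contained sketch reproduces that pattern (the names type_wrapper and map_types_to_values_sketch are invented for illustration; this is not c10's implementation):

#include <cstddef>
#include <tuple>

// Empty tag object that carries a type as a member typedef, so the functor
// can be called with a value even for types that cannot be instantiated.
template <class T>
struct type_wrapper {
  using type = T;
};

// Call `func` once per type in the pack and collect the results in a tuple.
template <class... Types, class Func>
constexpr auto map_types_to_values_sketch(Func&& func) {
  return std::make_tuple(func(type_wrapper<Types>{})...);
}

struct map_to_size {
  template <class W>
  constexpr std::size_t operator()(W) const {
    return sizeof(typename W::type);
  }
};

// The types are supplied as template arguments; only the functor is a
// runtime value, mirroring the tests above.
static_assert(
    std::get<0>(map_types_to_values_sketch<int, char>(map_to_size{})) ==
        sizeof(int),
    "");
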
namespace test_find_if {
static_assert(0 == find_if<typelist<char&>, std::is_reference>::value, "");
static_assert(0 == find_if<typelist<char&, int, char&, int&>, std::is_reference>::value, "");
static_assert(2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value, "");
static_assert(3 == find_if<typelist<char, int, char, int&>, std::is_reference>::value, "");
}
static_assert(0 == find_if<typelist<char&>, std::is_reference>::value, "");
static_assert(
    0 == find_if<typelist<char&, int, char&, int&>, std::is_reference>::value,
    "");
static_assert(
    2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value,
    "");
static_assert(
    3 == find_if<typelist<char, int, char, int&>, std::is_reference>::value,
    "");
} // namespace test_find_if

namespace test_contains {
static_assert(contains<typelist<double>, double>::value, "");
static_assert(contains<typelist<int, double>, double>::value, "");
static_assert(!contains<typelist<int, double>, float>::value, "");
static_assert(!contains<typelist<>, double>::value, "");
}
static_assert(contains<typelist<double>, double>::value, "");
static_assert(contains<typelist<int, double>, double>::value, "");
static_assert(!contains<typelist<int, double>, float>::value, "");
static_assert(!contains<typelist<>, double>::value, "");
} // namespace test_contains

namespace test_take {
static_assert(std::is_same<typelist<>, take_t<typelist<>, 0>>::value, "");
static_assert(std::is_same<typelist<>, take_t<typelist<int64_t>, 0>>::value, "");
static_assert(std::is_same<typelist<int64_t>, take_t<typelist<int64_t>, 1>>::value, "");
static_assert(std::is_same<typelist<>, take_t<typelist<int64_t, int32_t>, 0>>::value, "");
static_assert(std::is_same<typelist<int64_t>, take_t<typelist<int64_t, int32_t>, 1>>::value, "");
static_assert(std::is_same<typelist<int64_t, int32_t>, take_t<typelist<int64_t, int32_t>, 2>>::value, "");
}
static_assert(std::is_same<typelist<>, take_t<typelist<>, 0>>::value, "");
static_assert(
    std::is_same<typelist<>, take_t<typelist<int64_t>, 0>>::value,
    "");
static_assert(
    std::is_same<typelist<int64_t>, take_t<typelist<int64_t>, 1>>::value,
    "");
static_assert(
    std::is_same<typelist<>, take_t<typelist<int64_t, int32_t>, 0>>::value,
    "");
static_assert(
    std::is_same<typelist<int64_t>, take_t<typelist<int64_t, int32_t>, 1>>::
        value,
    "");
static_assert(
    std::is_same<
        typelist<int64_t, int32_t>,
        take_t<typelist<int64_t, int32_t>, 2>>::value,
    "");
} // namespace test_take

namespace test_drop {
static_assert(std::is_same<typelist<>, drop_t<typelist<>, 0>>::value, "");
static_assert(std::is_same<typelist<int64_t>, drop_t<typelist<int64_t>, 0>>::value, "");
static_assert(std::is_same<typelist<>, drop_t<typelist<int64_t>, 1>>::value, "");
static_assert(std::is_same<typelist<int64_t, int32_t>, drop_t<typelist<int64_t, int32_t>, 0>>::value, "");
static_assert(std::is_same<typelist<int32_t>, drop_t<typelist<int64_t, int32_t>, 1>>::value, "");
static_assert(std::is_same<typelist<>, drop_t<typelist<int64_t, int32_t>, 2>>::value, "");
}
static_assert(std::is_same<typelist<>, drop_t<typelist<>, 0>>::value, "");
static_assert(
    std::is_same<typelist<int64_t>, drop_t<typelist<int64_t>, 0>>::value,
    "");
static_assert(
    std::is_same<typelist<>, drop_t<typelist<int64_t>, 1>>::value,
    "");
static_assert(
    std::is_same<
        typelist<int64_t, int32_t>,
        drop_t<typelist<int64_t, int32_t>, 0>>::value,
    "");
static_assert(
    std::is_same<typelist<int32_t>, drop_t<typelist<int64_t, int32_t>, 1>>::
        value,
    "");
static_assert(
    std::is_same<typelist<>, drop_t<typelist<int64_t, int32_t>, 2>>::value,
    "");
} // namespace test_drop

namespace test_drop_if_nonempty {
static_assert(std::is_same<typelist<>, drop_if_nonempty_t<typelist<>, 0>>::value, "");
static_assert(std::is_same<typelist<int64_t>, drop_if_nonempty_t<typelist<int64_t>, 0>>::value, "");
static_assert(std::is_same<typelist<>, drop_if_nonempty_t<typelist<int64_t>, 1>>::value, "");
static_assert(std::is_same<typelist<int64_t, int32_t>, drop_if_nonempty_t<typelist<int64_t, int32_t>, 0>>::value, "");
static_assert(std::is_same<typelist<int32_t>, drop_if_nonempty_t<typelist<int64_t, int32_t>, 1>>::value, "");
static_assert(std::is_same<typelist<>, drop_if_nonempty_t<typelist<int64_t, int32_t>, 2>>::value, "");
static_assert(std::is_same<typelist<>, drop_if_nonempty_t<typelist<>, 1>>::value, "");
static_assert(std::is_same<typelist<>, drop_if_nonempty_t<typelist<int64_t, int32_t>, 3>>::value, "");
}
static_assert(
    std::is_same<typelist<>, drop_if_nonempty_t<typelist<>, 0>>::value,
    "");
static_assert(
    std::is_same<typelist<int64_t>, drop_if_nonempty_t<typelist<int64_t>, 0>>::
        value,
    "");
static_assert(
    std::is_same<typelist<>, drop_if_nonempty_t<typelist<int64_t>, 1>>::value,
    "");
static_assert(
    std::is_same<
        typelist<int64_t, int32_t>,
        drop_if_nonempty_t<typelist<int64_t, int32_t>, 0>>::value,
    "");
static_assert(
    std::is_same<
        typelist<int32_t>,
        drop_if_nonempty_t<typelist<int64_t, int32_t>, 1>>::value,
    "");
static_assert(
    std::is_same<
        typelist<>,
        drop_if_nonempty_t<typelist<int64_t, int32_t>, 2>>::value,
    "");
static_assert(
    std::is_same<typelist<>, drop_if_nonempty_t<typelist<>, 1>>::value,
    "");
static_assert(
    std::is_same<
        typelist<>,
        drop_if_nonempty_t<typelist<int64_t, int32_t>, 3>>::value,
    "");
} // namespace test_drop_if_nonempty

@@ -6,152 +6,190 @@ using namespace c10::guts;
namespace {

namespace test_is_equality_comparable {
class NotEqualityComparable {};
class EqualityComparable {};
class NotEqualityComparable {};
class EqualityComparable {};

inline bool operator==(const EqualityComparable &, const EqualityComparable &) { return false; }

static_assert(!is_equality_comparable<NotEqualityComparable>::value, "");
static_assert(is_equality_comparable<EqualityComparable>::value, "");
static_assert(is_equality_comparable<int>::value, "");

// v_ just exists to silence a compiler warning about operator==(EqualityComparable, EqualityComparable) not being needed
const bool v_ = EqualityComparable() == EqualityComparable();
inline bool operator==(const EqualityComparable&, const EqualityComparable&) {
  return false;
}

static_assert(!is_equality_comparable<NotEqualityComparable>::value, "");
static_assert(is_equality_comparable<EqualityComparable>::value, "");
static_assert(is_equality_comparable<int>::value, "");

// v_ just exists to silence a compiler warning about
// operator==(EqualityComparable, EqualityComparable) not being needed
const bool v_ = EqualityComparable() == EqualityComparable();
} // namespace test_is_equality_comparable

namespace test_is_hashable {
class NotHashable {};
class Hashable {};
}
}
class NotHashable {};
class Hashable {};
} // namespace test_is_hashable
} // namespace
namespace std {
template<> struct hash<test_is_hashable::Hashable> final {
  size_t operator()(const test_is_hashable::Hashable &) { return 0; }
};
}
template <>
struct hash<test_is_hashable::Hashable> final {
  size_t operator()(const test_is_hashable::Hashable&) {
    return 0;
  }
};
} // namespace std
namespace {
namespace test_is_hashable {
static_assert(is_hashable<int>::value, "");
static_assert(is_hashable<Hashable>::value, "");
static_assert(!is_hashable<NotHashable>::value, "");
}
static_assert(is_hashable<int>::value, "");
static_assert(is_hashable<Hashable>::value, "");
static_assert(!is_hashable<NotHashable>::value, "");
} // namespace test_is_hashable

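Aside: traits like is_equality_comparable and is_hashable exercised above are commonly built with the detection idiom: a partial specialization is selected only when a probe expression is well-formed. A minimal standalone sketch (uses std::void_t from C++17; this is an illustration, not c10's actual definition):

#include <type_traits>
#include <utility>

template <class T, class = void>
struct is_equality_comparable_sketch : std::false_type {};

// This specialization is chosen only if `a == b` compiles for two const T
// lvalues; otherwise SFINAE falls back to the false_type primary template.
template <class T>
struct is_equality_comparable_sketch<
    T,
    std::void_t<decltype(
        std::declval<const T&>() == std::declval<const T&>())>>
    : std::true_type {};

static_assert(is_equality_comparable_sketch<int>::value, "");

class NoEquals {};
static_assert(!is_equality_comparable_sketch<NoEquals>::value, "");
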
namespace test_is_function_type {
class MyClass {};
struct Functor {
  void operator()() {}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
auto lambda = [] () {};
// func() and func__ just exists to silence a compiler warning about lambda being unused
bool func() {
  lambda();
  return true;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
bool func__ = func();

static_assert(is_function_type<void()>::value, "");
static_assert(is_function_type<int()>::value, "");
static_assert(is_function_type<MyClass()>::value, "");
static_assert(is_function_type<void(MyClass)>::value, "");
static_assert(is_function_type<void(int)>::value, "");
static_assert(is_function_type<void(void*)>::value, "");
static_assert(is_function_type<int()>::value, "");
static_assert(is_function_type<int(MyClass)>::value, "");
static_assert(is_function_type<int(const MyClass&)>::value, "");
static_assert(is_function_type<int(MyClass&&)>::value, "");
static_assert(is_function_type<MyClass&&()>::value, "");
static_assert(is_function_type<MyClass&&(MyClass&&)>::value, "");
static_assert(is_function_type<const MyClass&(int, float, MyClass)>::value, "");

static_assert(!is_function_type<void>::value, "");
static_assert(!is_function_type<int>::value, "");
static_assert(!is_function_type<MyClass>::value, "");
static_assert(!is_function_type<void*>::value, "");
static_assert(!is_function_type<const MyClass&>::value, "");
static_assert(!is_function_type<MyClass&&>::value, "");

static_assert(!is_function_type<void (*)()>::value, "function pointers aren't plain functions");
static_assert(!is_function_type<Functor>::value, "Functors aren't plain functions");
static_assert(!is_function_type<decltype(lambda)>::value, "Lambdas aren't plain functions");
class MyClass {};
struct Functor {
  void operator()() {}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
auto lambda = []() {};
// func() and func__ just exists to silence a compiler warning about lambda
// being unused
bool func() {
  lambda();
  return true;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
bool func__ = func();

static_assert(is_function_type<void()>::value, "");
static_assert(is_function_type<int()>::value, "");
static_assert(is_function_type<MyClass()>::value, "");
static_assert(is_function_type<void(MyClass)>::value, "");
static_assert(is_function_type<void(int)>::value, "");
static_assert(is_function_type<void(void*)>::value, "");
static_assert(is_function_type<int()>::value, "");
static_assert(is_function_type<int(MyClass)>::value, "");
static_assert(is_function_type<int(const MyClass&)>::value, "");
static_assert(is_function_type<int(MyClass&&)>::value, "");
static_assert(is_function_type < MyClass && () > ::value, "");
static_assert(is_function_type < MyClass && (MyClass &&) > ::value, "");
static_assert(is_function_type<const MyClass&(int, float, MyClass)>::value, "");

static_assert(!is_function_type<void>::value, "");
static_assert(!is_function_type<int>::value, "");
static_assert(!is_function_type<MyClass>::value, "");
static_assert(!is_function_type<void*>::value, "");
static_assert(!is_function_type<const MyClass&>::value, "");
static_assert(!is_function_type<MyClass&&>::value, "");

static_assert(
    !is_function_type<void (*)()>::value,
    "function pointers aren't plain functions");
static_assert(
    !is_function_type<Functor>::value,
    "Functors aren't plain functions");
static_assert(
    !is_function_type<decltype(lambda)>::value,
    "Lambdas aren't plain functions");
} // namespace test_is_function_type

namespace test_is_instantiation_of {
class MyClass {};
template<class T> class Single {};
template<class T1, class T2> class Double {};
template<class... T> class Multiple {};
class MyClass {};
template <class T>
class Single {};
template <class T1, class T2>
class Double {};
template <class... T>
class Multiple {};

static_assert(is_instantiation_of<Single, Single<void>>::value, "");
static_assert(is_instantiation_of<Single, Single<MyClass>>::value, "");
static_assert(is_instantiation_of<Single, Single<int>>::value, "");
static_assert(is_instantiation_of<Single, Single<void*>>::value, "");
static_assert(is_instantiation_of<Single, Single<int*>>::value, "");
static_assert(is_instantiation_of<Single, Single<const MyClass&>>::value, "");
static_assert(is_instantiation_of<Single, Single<MyClass&&>>::value, "");
static_assert(is_instantiation_of<Double, Double<int, void>>::value, "");
static_assert(is_instantiation_of<Double, Double<const int&, MyClass*>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<int>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<MyClass&, int>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<MyClass&, int, MyClass>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<MyClass&, int, MyClass, void*>>::value, "");
static_assert(is_instantiation_of<Single, Single<void>>::value, "");
static_assert(is_instantiation_of<Single, Single<MyClass>>::value, "");
static_assert(is_instantiation_of<Single, Single<int>>::value, "");
static_assert(is_instantiation_of<Single, Single<void*>>::value, "");
static_assert(is_instantiation_of<Single, Single<int*>>::value, "");
static_assert(is_instantiation_of<Single, Single<const MyClass&>>::value, "");
static_assert(is_instantiation_of<Single, Single<MyClass&&>>::value, "");
static_assert(is_instantiation_of<Double, Double<int, void>>::value, "");
static_assert(
    is_instantiation_of<Double, Double<const int&, MyClass*>>::value,
    "");
static_assert(is_instantiation_of<Multiple, Multiple<>>::value, "");
static_assert(is_instantiation_of<Multiple, Multiple<int>>::value, "");
static_assert(
    is_instantiation_of<Multiple, Multiple<MyClass&, int>>::value,
    "");
static_assert(
    is_instantiation_of<Multiple, Multiple<MyClass&, int, MyClass>>::value,
    "");
static_assert(
    is_instantiation_of<Multiple, Multiple<MyClass&, int, MyClass, void*>>::
        value,
    "");

static_assert(!is_instantiation_of<Single, Double<int, int>>::value, "");
static_assert(!is_instantiation_of<Single, Double<int, void>>::value, "");
static_assert(!is_instantiation_of<Single, Multiple<int>>::value, "");
static_assert(!is_instantiation_of<Double, Single<int>>::value, "");
static_assert(!is_instantiation_of<Double, Multiple<int, int>>::value, "");
static_assert(!is_instantiation_of<Double, Multiple<>>::value, "");
static_assert(!is_instantiation_of<Multiple, Double<int, int>>::value, "");
static_assert(!is_instantiation_of<Multiple, Single<int>>::value, "");
}
static_assert(!is_instantiation_of<Single, Double<int, int>>::value, "");
static_assert(!is_instantiation_of<Single, Double<int, void>>::value, "");
static_assert(!is_instantiation_of<Single, Multiple<int>>::value, "");
static_assert(!is_instantiation_of<Double, Single<int>>::value, "");
static_assert(!is_instantiation_of<Double, Multiple<int, int>>::value, "");
static_assert(!is_instantiation_of<Double, Multiple<>>::value, "");
static_assert(!is_instantiation_of<Multiple, Double<int, int>>::value, "");
static_assert(!is_instantiation_of<Multiple, Single<int>>::value, "");
} // namespace test_is_instantiation_of

namespace test_is_type_condition {
template<class> class NotATypeCondition {};
static_assert(is_type_condition<std::is_reference>::value, "");
static_assert(!is_type_condition<NotATypeCondition>::value, "");
}
}
template <class>
class NotATypeCondition {};
static_assert(is_type_condition<std::is_reference>::value, "");
static_assert(!is_type_condition<NotATypeCondition>::value, "");
} // namespace test_is_type_condition
} // namespace

namespace test_lambda_is_stateless {
template<class Result, class... Args>
struct MyStatelessFunctor final {
  Result operator()(Args...) {}
};
template <class Result, class... Args>
struct MyStatelessFunctor final {
  Result operator()(Args...) {}
};

template<class Result, class... Args>
struct MyStatelessConstFunctor final {
  Result operator()(Args...) const {}
};
template <class Result, class... Args>
struct MyStatelessConstFunctor final {
  Result operator()(Args...) const {}
};

void func() {
  auto stateless_lambda = [] (int a) {return a;};
  static_assert(is_stateless_lambda<decltype(stateless_lambda)>::value, "");
void func() {
  auto stateless_lambda = [](int a) { return a; };
  static_assert(is_stateless_lambda<decltype(stateless_lambda)>::value, "");

  int b = 4;
  auto stateful_lambda_1 = [&] (int a) {return a + b;};
  static_assert(!is_stateless_lambda<decltype(stateful_lambda_1)>::value, "");
  int b = 4;
  auto stateful_lambda_1 = [&](int a) { return a + b; };
  static_assert(!is_stateless_lambda<decltype(stateful_lambda_1)>::value, "");

  auto stateful_lambda_2 = [=] (int a) {return a + b;};
  static_assert(!is_stateless_lambda<decltype(stateful_lambda_2)>::value, "");
  auto stateful_lambda_2 = [=](int a) { return a + b; };
  static_assert(!is_stateless_lambda<decltype(stateful_lambda_2)>::value, "");

  auto stateful_lambda_3 = [b] (int a) {return a + b;};
  static_assert(!is_stateless_lambda<decltype(stateful_lambda_3)>::value, "");
  auto stateful_lambda_3 = [b](int a) { return a + b; };
  static_assert(!is_stateless_lambda<decltype(stateful_lambda_3)>::value, "");

  static_assert(!is_stateless_lambda<MyStatelessFunctor<int, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
  static_assert(!is_stateless_lambda<MyStatelessFunctor<void, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
  static_assert(!is_stateless_lambda<MyStatelessConstFunctor<int, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
  static_assert(!is_stateless_lambda<MyStatelessConstFunctor<void, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
  static_assert(
      !is_stateless_lambda<MyStatelessFunctor<int, int>>::value,
      "even if stateless, a functor is not a lambda, so it's false");
  static_assert(
      !is_stateless_lambda<MyStatelessFunctor<void, int>>::value,
      "even if stateless, a functor is not a lambda, so it's false");
  static_assert(
      !is_stateless_lambda<MyStatelessConstFunctor<int, int>>::value,
      "even if stateless, a functor is not a lambda, so it's false");
  static_assert(
      !is_stateless_lambda<MyStatelessConstFunctor<void, int>>::value,
      "even if stateless, a functor is not a lambda, so it's false");

  class Dummy final {};
  static_assert(!is_stateless_lambda<Dummy>::value, "A non-functor type is also not a lambda");
  class Dummy final {};
  static_assert(
      !is_stateless_lambda<Dummy>::value,
      "A non-functor type is also not a lambda");

  static_assert(!is_stateless_lambda<int>::value, "An int is not a lambda");
  static_assert(!is_stateless_lambda<int>::value, "An int is not a lambda");

  using Func = int(int);
  static_assert(!is_stateless_lambda<Func>::value, "A function is not a lambda");
  static_assert(!is_stateless_lambda<Func*>::value, "A function pointer is not a lambda");
}
  using Func = int(int);
  static_assert(
      !is_stateless_lambda<Func>::value, "A function is not a lambda");
  static_assert(
      !is_stateless_lambda<Func*>::value, "A function pointer is not a lambda");
}
} // namespace test_lambda_is_stateless

@@ -11,70 +11,73 @@ using namespace ::testing;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(accumulate_test, vector_test) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::vector<int> ints = {1, 2, 3, 4, 5};
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::vector<int> ints = {1, 2, 3, 4, 5};

  EXPECT_EQ(c10::sum_integers(ints), 1+2+3+4+5);
  EXPECT_EQ(c10::multiply_integers(ints), 1*2*3*4*5);
  EXPECT_EQ(c10::sum_integers(ints), 1 + 2 + 3 + 4 + 5);
  EXPECT_EQ(c10::multiply_integers(ints), 1 * 2 * 3 * 4 * 5);

  EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1+2+3+4+5);
  EXPECT_EQ(c10::multiply_integers(ints.begin(), ints.end()), 1*2*3*4*5);
  EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1 + 2 + 3 + 4 + 5);
  EXPECT_EQ(
      c10::multiply_integers(ints.begin(), ints.end()), 1 * 2 * 3 * 4 * 5);

  EXPECT_EQ(c10::sum_integers(ints.begin()+1, ints.end()-1), 2+3+4);
  EXPECT_EQ(c10::multiply_integers(ints.begin()+1, ints.end()-1), 2*3*4);
  EXPECT_EQ(c10::sum_integers(ints.begin() + 1, ints.end() - 1), 2 + 3 + 4);
  EXPECT_EQ(
      c10::multiply_integers(ints.begin() + 1, ints.end() - 1), 2 * 3 * 4);

  EXPECT_EQ(c10::numelements_from_dim(2, ints), 3*4*5);
  EXPECT_EQ(c10::numelements_to_dim(3, ints), 1*2*3);
  EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3*4);
  EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3*4);
  EXPECT_EQ(c10::numelements_from_dim(2, ints), 3 * 4 * 5);
  EXPECT_EQ(c10::numelements_to_dim(3, ints), 1 * 2 * 3);
  EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3 * 4);
  EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3 * 4);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(accumulate_test, list_test) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::list<int> ints = {1, 2, 3, 4, 5};
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::list<int> ints = {1, 2, 3, 4, 5};

  EXPECT_EQ(c10::sum_integers(ints), 1+2+3+4+5);
  EXPECT_EQ(c10::multiply_integers(ints), 1*2*3*4*5);
  EXPECT_EQ(c10::sum_integers(ints), 1 + 2 + 3 + 4 + 5);
  EXPECT_EQ(c10::multiply_integers(ints), 1 * 2 * 3 * 4 * 5);

  EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1+2+3+4+5);
  EXPECT_EQ(c10::multiply_integers(ints.begin(), ints.end()), 1*2*3*4*5);
  EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1 + 2 + 3 + 4 + 5);
  EXPECT_EQ(
      c10::multiply_integers(ints.begin(), ints.end()), 1 * 2 * 3 * 4 * 5);

  EXPECT_EQ(c10::numelements_from_dim(2, ints), 3*4*5);
  EXPECT_EQ(c10::numelements_to_dim(3, ints), 1*2*3);
  EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3*4);
  EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3*4);
  EXPECT_EQ(c10::numelements_from_dim(2, ints), 3 * 4 * 5);
  EXPECT_EQ(c10::numelements_to_dim(3, ints), 1 * 2 * 3);
  EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3 * 4);
  EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3 * 4);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(accumulate_test, base_cases) {
  std::vector<int> ints = {};
  std::vector<int> ints = {};

  EXPECT_EQ(c10::sum_integers(ints), 0);
  EXPECT_EQ(c10::multiply_integers(ints), 1);
  EXPECT_EQ(c10::sum_integers(ints), 0);
  EXPECT_EQ(c10::multiply_integers(ints), 1);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(accumulate_test, errors) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::vector<int> ints = {1,2,3,4,5};
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::vector<int> ints = {1, 2, 3, 4, 5};

#ifndef NDEBUG
  EXPECT_THROW(c10::numelements_from_dim(-1, ints), c10::Error);
#endif
#ifndef NDEBUG
  EXPECT_THROW(c10::numelements_from_dim(-1, ints), c10::Error);
#endif

  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_to_dim(-1, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(-1, 10, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(10, -1, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_to_dim(-1, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(-1, 10, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(10, -1, ints), c10::Error);

  EXPECT_EQ(c10::numelements_from_dim(10, ints),1);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_to_dim(10, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(10, 4, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(4, 10, ints), c10::Error);
  EXPECT_EQ(c10::numelements_from_dim(10, ints), 1);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_to_dim(10, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(10, 4, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(4, 10, ints), c10::Error);
}

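Aside: the expectations above fully determine the semantics under test: sum_integers is a fold with +, multiply_integers a fold with * (so the empty product is 1), and numelements_from_dim(k, dims) is the product of the sizes from dimension k onward. A standalone sketch of that contract in terms of std::accumulate (illustrative only; not c10's code):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int64_t sum_integers_sketch(const std::vector<int>& v) {
  return std::accumulate(v.begin(), v.end(), int64_t{0});
}

int64_t multiply_integers_sketch(const std::vector<int>& v) {
  // Identity element 1 gives multiply_integers({}) == 1, as the
  // base_cases test above expects.
  return std::accumulate(v.begin(), v.end(), int64_t{1}, std::multiplies<>());
}

// Product of dims[k:], e.g. numelements_from_dim(2, {1,2,3,4,5}) == 3*4*5.
int64_t numelements_from_dim_sketch(int k, const std::vector<int>& dims) {
  // The errors test above expects k past the end to yield 1, so clamp.
  if (k >= static_cast<int>(dims.size())) {
    return 1;
  }
  return std::accumulate(
      dims.begin() + k, dims.end(), int64_t{1}, std::multiplies<>());
}
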
@@ -2,196 +2,194 @@
#include <gtest/gtest.h>

namespace {
float float_from_bytes(
    uint32_t sign,
    uint32_t exponent,
    uint32_t fraction
) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  uint32_t bytes;
  bytes = 0;
  bytes |= sign;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  bytes <<= 8;
  bytes |= exponent;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  bytes <<= 23;
  bytes |= fraction;
float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  uint32_t bytes;
  bytes = 0;
  bytes |= sign;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  bytes <<= 8;
  bytes |= exponent;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  bytes <<= 23;
  bytes |= fraction;

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float res;
  std::memcpy(&res, &bytes, sizeof(res));
  return res;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float res;
  std::memcpy(&res, &bytes, sizeof(res));
  return res;
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, FloatToBFloat16AndBack) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float in[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  for (int i = 0; i < 100; ++i) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
    in[i] = i + 1.25;
  }

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, FloatToBFloat16AndBack) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float in[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  for (int i = 0; i < 100; ++i) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
    in[i] = i + 1.25;
  }
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  c10::BFloat16 bfloats[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float out[100];

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  c10::BFloat16 bfloats[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float out[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  for (int i = 0; i < 100; ++i) {
    bfloats[i].x = c10::detail::bits_from_f32(in[i]);
    out[i] = c10::detail::f32_from_bits(bfloats[i].x);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  for (int i = 0; i < 100; ++i) {
    bfloats[i].x = c10::detail::bits_from_f32(in[i]);
    out[i] = c10::detail::f32_from_bits(bfloats[i].x);
    // The relative error should be less than 1/(2^7) since BFloat16
    // has 7 bits mantissa.
    EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
  }
}

    // The relative error should be less than 1/(2^7) since BFloat16
    // has 7 bits mantissa.
    EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
  }
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float in[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  for (int i = 0; i < 100; ++i) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
    in[i] = i + 1.25;
  }

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float in[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  for (int i = 0; i < 100; ++i) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
    in[i] = i + 1.25;
  }
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  c10::BFloat16 bfloats[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float out[100];

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  c10::BFloat16 bfloats[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float out[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  for (int i = 0; i < 100; ++i) {
    bfloats[i].x = c10::detail::round_to_nearest_even(in[i]);
    out[i] = c10::detail::f32_from_bits(bfloats[i].x);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  for (int i = 0; i < 100; ++i) {
    bfloats[i].x = c10::detail::round_to_nearest_even(in[i]);
    out[i] = c10::detail::f32_from_bits(bfloats[i].x);

    // The relative error should be less than 1/(2^7) since BFloat16
    // has 7 bits mantissa.
    EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
  }
    // The relative error should be less than 1/(2^7) since BFloat16
    // has 7 bits mantissa.
    EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, NaN) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF);
  EXPECT_TRUE(std::isnan(inNaN));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, NaN) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF);
  EXPECT_TRUE(std::isnan(inNaN));

  c10::BFloat16 a = c10::BFloat16(inNaN);
  float out = c10::detail::f32_from_bits(a.x);
  c10::BFloat16 a = c10::BFloat16(inNaN);
  float out = c10::detail::f32_from_bits(a.x);

  EXPECT_TRUE(std::isnan(out));
}
  EXPECT_TRUE(std::isnan(out));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, Inf) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float inInf = float_from_bytes(0, 0xFF, 0);
  EXPECT_TRUE(std::isinf(inInf));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, Inf) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float inInf = float_from_bytes(0, 0xFF, 0);
  EXPECT_TRUE(std::isinf(inInf));

  c10::BFloat16 a = c10::BFloat16(inInf);
  float out = c10::detail::f32_from_bits(a.x);
  c10::BFloat16 a = c10::BFloat16(inInf);
  float out = c10::detail::f32_from_bits(a.x);

  EXPECT_TRUE(std::isinf(out));
}
  EXPECT_TRUE(std::isinf(out));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, SmallestDenormal) {
  float in = std::numeric_limits<float>::denorm_min(); // The smallest non-zero subnormal number
  c10::BFloat16 a = c10::BFloat16(in);
  float out = c10::detail::f32_from_bits(a.x);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Conversion, SmallestDenormal) {
  float in = std::numeric_limits<float>::denorm_min(); // The smallest non-zero
                                                       // subnormal number
  c10::BFloat16 a = c10::BFloat16(in);
  float out = c10::detail::f32_from_bits(a.x);

  EXPECT_FLOAT_EQ(in, out);
}
  EXPECT_FLOAT_EQ(in, out);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, Addition) {
  // This test verifies that if only first 7 bits of float's mantissa are
  // changed after addition, we should have no loss in precision.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, Addition) {
  // This test verifies that if only first 7 bits of float's mantissa are
  // changed after addition, we should have no loss in precision.

  // input bits
  // S | Exponent | Mantissa
  // 0 | 10000000 | 10010000000000000000000 = 3.125
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float input = float_from_bytes(0, 0, 0x40480000);
  // input bits
  // S | Exponent | Mantissa
  // 0 | 10000000 | 10010000000000000000000 = 3.125
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float input = float_from_bytes(0, 0, 0x40480000);

  // expected bits
  // S | Exponent | Mantissa
  // 0 | 10000001 | 10010000000000000000000 = 6.25
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float expected = float_from_bytes(0, 0, 0x40c80000);
  // expected bits
  // S | Exponent | Mantissa
  // 0 | 10000001 | 10010000000000000000000 = 6.25
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float expected = float_from_bytes(0, 0, 0x40c80000);

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  c10::BFloat16 b;
  b.x = c10::detail::bits_from_f32(input);
  b = b + b;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  c10::BFloat16 b;
  b.x = c10::detail::bits_from_f32(input);
  b = b + b;

  float res = c10::detail::f32_from_bits(b.x);
  EXPECT_EQ(res, expected);
}
  float res = c10::detail::f32_from_bits(b.x);
  EXPECT_EQ(res, expected);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, Subtraction) {
  // This test verifies that if only first 7 bits of float's mantissa are
  // changed after subtraction, we should have no loss in precision.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, Subtraction) {
  // This test verifies that if only first 7 bits of float's mantissa are
  // changed after subtraction, we should have no loss in precision.

  // input bits
  // S | Exponent | Mantissa
  // 0 | 10000001 | 11101000000000000000000 = 7.625
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float input = float_from_bytes(0, 0, 0x40f40000);
  // input bits
  // S | Exponent | Mantissa
  // 0 | 10000001 | 11101000000000000000000 = 7.625
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float input = float_from_bytes(0, 0, 0x40f40000);

  // expected bits
  // S | Exponent | Mantissa
  // 0 | 10000000 | 01010000000000000000000 = 2.625
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float expected = float_from_bytes(0, 0, 0x40280000);
  // expected bits
  // S | Exponent | Mantissa
  // 0 | 10000000 | 01010000000000000000000 = 2.625
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  float expected = float_from_bytes(0, 0, 0x40280000);

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  c10::BFloat16 b;
  b.x = c10::detail::bits_from_f32(input);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  b = b - 5;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  c10::BFloat16 b;
  b.x = c10::detail::bits_from_f32(input);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  b = b - 5;

  float res = c10::detail::f32_from_bits(b.x);
  EXPECT_EQ(res, expected);
}
  float res = c10::detail::f32_from_bits(b.x);
  EXPECT_EQ(res, expected);
}

float BinaryToFloat(uint32_t bytes) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float res;
  std::memcpy(&res, &bytes, sizeof(res));
  return res;
}
float BinaryToFloat(uint32_t bytes) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float res;
  std::memcpy(&res, &bytes, sizeof(res));
  return res;
}

struct BFloat16TestParam {
  uint32_t input;
  uint16_t rne;
};
struct BFloat16TestParam {
  uint32_t input;
  uint16_t rne;
};

class BFloat16Test : public ::testing::Test,
                     public ::testing::WithParamInterface<BFloat16TestParam> {
};
class BFloat16Test : public ::testing::Test,
                     public ::testing::WithParamInterface<BFloat16TestParam> {};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_P(BFloat16Test, BFloat16RNETest) {
  float value = BinaryToFloat(GetParam().input);
  uint16_t rounded = c10::detail::round_to_nearest_even(value);
  EXPECT_EQ(GetParam().rne, rounded);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_P(BFloat16Test, BFloat16RNETest) {
  float value = BinaryToFloat(GetParam().input);
  uint16_t rounded = c10::detail::round_to_nearest_even(value);
  EXPECT_EQ(GetParam().rne, rounded);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
INSTANTIATE_TEST_CASE_P(
    BFloat16Test_Instantiation, BFloat16Test,
    ::testing::Values(BFloat16TestParam{0x3F848000, 0x3F84},
                      BFloat16TestParam{0x3F848010, 0x3F85},
                      BFloat16TestParam{0x3F850000, 0x3F85},
                      BFloat16TestParam{0x3F858000, 0x3F86},
                      BFloat16TestParam{0x3FFF8000, 0x4000}));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
INSTANTIATE_TEST_CASE_P(
    BFloat16Test_Instantiation,
    BFloat16Test,
    ::testing::Values(
        BFloat16TestParam{0x3F848000, 0x3F84},
        BFloat16TestParam{0x3F848010, 0x3F85},
        BFloat16TestParam{0x3F850000, 0x3F85},
        BFloat16TestParam{0x3F858000, 0x3F86},
        BFloat16TestParam{0x3FFF8000, 0x4000}));

} // namespace

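Aside: the conversion tests above lean on two float32-to-bfloat16 schemes: truncation (bits_from_f32 keeps only the top 16 bits of the float bit pattern, which is why the relative error stays under 1/2^7) and round-to-nearest-even. A standalone sketch of both, plus the widening direction (simplified: NaN handling is ignored; not a drop-in for c10::detail):

#include <cstdint>
#include <cstring>

uint16_t bf16_truncate_sketch(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16); // drop the low 16 mantissa bits
}

uint16_t bf16_round_nearest_even_sketch(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Add 0x7FFF plus the lowest kept bit, so an exact tie rounds to even.
  uint32_t rounding_bias = ((bits >> 16) & 1) + 0x7FFFu;
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

float bf16_to_f32_sketch(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

Checked against the parameterized cases above: 0x3F848000 is an exact tie and rounds down to the even 0x3F84, while 0x3F858000 ties up to 0x3F86.
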
@@ -1,4 +1,5 @@
// Warning: this file is included twice in aten/src/ATen/test/cuda_complex_math_test.cu
// Warning: this file is included twice in
// aten/src/ATen/test/cuda_complex_math_test.cu

#include <c10/util/complex.h>
#include <gtest/gtest.h>

@@ -16,152 +17,152 @@
C10_DEFINE_TEST(TestExponential, IPi) {
  // exp(i*pi) = -1
  {
    c10::complex<float> e_i_pi = std::exp(c10::complex<float>(0, float(PI)));
    C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
    C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
    c10::complex<float> e_i_pi = std::exp(c10::complex<float>(0, float(PI)));
    C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
    C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
  }
  {
    c10::complex<float> e_i_pi = ::exp(c10::complex<float>(0, float(PI)));
    C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
    C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
    c10::complex<float> e_i_pi = ::exp(c10::complex<float>(0, float(PI)));
    C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
    C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
  }
  {
    c10::complex<double> e_i_pi = std::exp(c10::complex<double>(0, PI));
    C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
    C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
    c10::complex<double> e_i_pi = std::exp(c10::complex<double>(0, PI));
    C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
    C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
  }
  {
    c10::complex<double> e_i_pi = ::exp(c10::complex<double>(0, PI));
    C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
    C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
    c10::complex<double> e_i_pi = ::exp(c10::complex<double>(0, PI));
    C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
    C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
  }
}

C10_DEFINE_TEST(TestExponential, EulerFormula) {
  // exp(ix) = cos(x) + i * sin(x)
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> e = std::exp(x);
    float expected_real = std::exp(x.real()) * std::cos(x.imag());
    float expected_imag = std::exp(x.real()) * std::sin(x.imag());
    C10_ASSERT_NEAR(e.real(), expected_real, tol);
    C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> e = std::exp(x);
    float expected_real = std::exp(x.real()) * std::cos(x.imag());
    float expected_imag = std::exp(x.real()) * std::sin(x.imag());
    C10_ASSERT_NEAR(e.real(), expected_real, tol);
    C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
  }
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> e = ::exp(x);
    float expected_real = ::exp(x.real()) * ::cos(x.imag());
    float expected_imag = ::exp(x.real()) * ::sin(x.imag());
    C10_ASSERT_NEAR(e.real(), expected_real, tol);
    C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> e = ::exp(x);
    float expected_real = ::exp(x.real()) * ::cos(x.imag());
    float expected_imag = ::exp(x.real()) * ::sin(x.imag());
    C10_ASSERT_NEAR(e.real(), expected_real, tol);
    C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> e = std::exp(x);
    float expected_real = std::exp(x.real()) * std::cos(x.imag());
    float expected_imag = std::exp(x.real()) * std::sin(x.imag());
    C10_ASSERT_NEAR(e.real(), expected_real, tol);
    C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> e = std::exp(x);
    float expected_real = std::exp(x.real()) * std::cos(x.imag());
    float expected_imag = std::exp(x.real()) * std::sin(x.imag());
    C10_ASSERT_NEAR(e.real(), expected_real, tol);
    C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> e = ::exp(x);
    float expected_real = ::exp(x.real()) * ::cos(x.imag());
    float expected_imag = ::exp(x.real()) * ::sin(x.imag());
    C10_ASSERT_NEAR(e.real(), expected_real, tol);
    C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> e = ::exp(x);
    float expected_real = ::exp(x.real()) * ::cos(x.imag());
    float expected_imag = ::exp(x.real()) * ::sin(x.imag());
    C10_ASSERT_NEAR(e.real(), expected_real, tol);
    C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
  }
}

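Aside: the exponential tests above all reduce to Euler's formula, exp(a + bi) = e^a * (cos b + i * sin b); the IPi case is just a = 0, b = pi. The same identity can be checked against std::complex alone (a sketch of the math being verified, independent of c10::complex and the test macros):

#include <cassert>
#include <cmath>
#include <complex>

int main() {
  std::complex<double> x(0.1, 1.2);
  std::complex<double> e = std::exp(x);
  // Euler's formula, written out with real-valued functions.
  double expected_real = std::exp(x.real()) * std::cos(x.imag());
  double expected_imag = std::exp(x.real()) * std::sin(x.imag());
  assert(std::abs(e.real() - expected_real) < 1e-12);
  assert(std::abs(e.imag() - expected_imag) < 1e-12);
  return 0;
}
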
C10_DEFINE_TEST(TestLog, Definition) {
  // log(x) = log(r) + i*theta
  {
    c10::complex<float> x(1.2, 3.4);
    c10::complex<float> l = std::log(x);
    float expected_real = std::log(std::abs(x));
    float expected_imag = std::arg(x);
    C10_ASSERT_NEAR(l.real(), expected_real, tol);
    C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
    c10::complex<float> x(1.2, 3.4);
    c10::complex<float> l = std::log(x);
    float expected_real = std::log(std::abs(x));
    float expected_imag = std::arg(x);
    C10_ASSERT_NEAR(l.real(), expected_real, tol);
    C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
  }
  {
    c10::complex<float> x(1.2, 3.4);
    c10::complex<float> l = ::log(x);
    float expected_real = ::log(std::abs(x));
    float expected_imag = std::arg(x);
    C10_ASSERT_NEAR(l.real(), expected_real, tol);
    C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
    c10::complex<float> x(1.2, 3.4);
    c10::complex<float> l = ::log(x);
    float expected_real = ::log(std::abs(x));
    float expected_imag = std::arg(x);
    C10_ASSERT_NEAR(l.real(), expected_real, tol);
    C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
  }
  {
    c10::complex<double> x(1.2, 3.4);
    c10::complex<double> l = std::log(x);
    float expected_real = std::log(std::abs(x));
    float expected_imag = std::arg(x);
    C10_ASSERT_NEAR(l.real(), expected_real, tol);
    C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
    c10::complex<double> x(1.2, 3.4);
    c10::complex<double> l = std::log(x);
    float expected_real = std::log(std::abs(x));
    float expected_imag = std::arg(x);
    C10_ASSERT_NEAR(l.real(), expected_real, tol);
    C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
  }
  {
    c10::complex<double> x(1.2, 3.4);
    c10::complex<double> l = ::log(x);
    float expected_real = ::log(std::abs(x));
    float expected_imag = std::arg(x);
    C10_ASSERT_NEAR(l.real(), expected_real, tol);
    C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
    c10::complex<double> x(1.2, 3.4);
    c10::complex<double> l = ::log(x);
    float expected_real = ::log(std::abs(x));
    float expected_imag = std::arg(x);
    C10_ASSERT_NEAR(l.real(), expected_real, tol);
    C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
  }
}

C10_DEFINE_TEST(TestLog10, Rev) {
  // log10(10^x) = x
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> l = std::log10(std::pow(float(10), x));
    C10_ASSERT_NEAR(l.real(), float(0.1), tol);
    C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> l = std::log10(std::pow(float(10), x));
    C10_ASSERT_NEAR(l.real(), float(0.1), tol);
    C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
  }
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> l = ::log10(::pow(float(10), x));
    C10_ASSERT_NEAR(l.real(), float(0.1), tol);
    C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> l = ::log10(::pow(float(10), x));
    C10_ASSERT_NEAR(l.real(), float(0.1), tol);
    C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> l = std::log10(std::pow(double(10), x));
    C10_ASSERT_NEAR(l.real(), double(0.1), tol);
    C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> l = std::log10(std::pow(double(10), x));
    C10_ASSERT_NEAR(l.real(), double(0.1), tol);
    C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> l = ::log10(::pow(double(10), x));
    C10_ASSERT_NEAR(l.real(), double(0.1), tol);
    C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> l = ::log10(::pow(double(10), x));
    C10_ASSERT_NEAR(l.real(), double(0.1), tol);
    C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
  }
}

C10_DEFINE_TEST(TestLog2, Rev) {
|
||||
// log2(2^x) = x
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> l = std::log2(std::pow(float(2), x));
|
||||
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> l = std::log2(std::pow(float(2), x));
|
||||
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> l = ::log2(std::pow(float(2), x));
|
||||
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
|
||||
c10::complex<float> x(0.1, 1.2);
|
||||
c10::complex<float> l = ::log2(std::pow(float(2), x));
|
||||
C10_ASSERT_NEAR(l.real(), float(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> l = std::log2(std::pow(double(2), x));
|
||||
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> l = std::log2(std::pow(double(2), x));
|
||||
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
|
||||
}
|
||||
{
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> l = ::log2(std::pow(double(2), x));
|
||||
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
|
||||
c10::complex<double> x(0.1, 1.2);
|
||||
c10::complex<double> l = ::log2(std::pow(double(2), x));
|
||||
C10_ASSERT_NEAR(l.real(), double(0.1), tol);
|
||||
C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -170,64 +171,64 @@ C10_DEFINE_TEST(TestLog2, Rev) {
C10_DEFINE_TEST(TestPowSqrt, Equal) {
  // x^0.5 = sqrt(x)
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = std::pow(x, float(0.5));
    c10::complex<float> z = std::sqrt(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = ::pow(x, float(0.5));
    c10::complex<float> z = ::sqrt(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = std::pow(x, double(0.5));
    c10::complex<double> z = std::sqrt(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = ::pow(x, double(0.5));
    c10::complex<double> z = ::sqrt(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
}

C10_DEFINE_TEST(TestPow, Square) {
  // x^2 = x * x
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = std::pow(x, float(2));
    c10::complex<float> z = x * x;
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = ::pow(x, float(2));
    c10::complex<float> z = x * x;
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = std::pow(x, double(2));
    c10::complex<double> z = x * x;
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = ::pow(x, double(2));
    c10::complex<double> z = x * x;
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
}
@ -237,132 +238,132 @@ C10_DEFINE_TEST(TestSinCosSinhCosh, Identity) {
  // sin(x + i * y) = sin(x) * cosh(y) + i * cos(x) * sinh(y)
  // cos(x + i * y) = cos(x) * cosh(y) - i * sin(x) * sinh(y)
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = std::sin(x);
    float expected_real = std::sin(x.real()) * std::cosh(x.imag());
    float expected_imag = std::cos(x.real()) * std::sinh(x.imag());
    C10_ASSERT_NEAR(y.real(), expected_real, tol);
    C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
  }
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = ::sin(x);
    float expected_real = ::sin(x.real()) * ::cosh(x.imag());
    float expected_imag = ::cos(x.real()) * ::sinh(x.imag());
    C10_ASSERT_NEAR(y.real(), expected_real, tol);
    C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
  }
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = std::cos(x);
    float expected_real = std::cos(x.real()) * std::cosh(x.imag());
    float expected_imag = - std::sin(x.real()) * std::sinh(x.imag());
    float expected_imag = -std::sin(x.real()) * std::sinh(x.imag());
    C10_ASSERT_NEAR(y.real(), expected_real, tol);
    C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
  }
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = ::cos(x);
    float expected_real = ::cos(x.real()) * ::cosh(x.imag());
    float expected_imag = - ::sin(x.real()) * ::sinh(x.imag());
    float expected_imag = -::sin(x.real()) * ::sinh(x.imag());
    C10_ASSERT_NEAR(y.real(), expected_real, tol);
    C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = std::sin(x);
    float expected_real = std::sin(x.real()) * std::cosh(x.imag());
    float expected_imag = std::cos(x.real()) * std::sinh(x.imag());
    C10_ASSERT_NEAR(y.real(), expected_real, tol);
    C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = ::sin(x);
    float expected_real = ::sin(x.real()) * ::cosh(x.imag());
    float expected_imag = ::cos(x.real()) * ::sinh(x.imag());
    C10_ASSERT_NEAR(y.real(), expected_real, tol);
    C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = std::cos(x);
    float expected_real = std::cos(x.real()) * std::cosh(x.imag());
    float expected_imag = - std::sin(x.real()) * std::sinh(x.imag());
    float expected_imag = -std::sin(x.real()) * std::sinh(x.imag());
    C10_ASSERT_NEAR(y.real(), expected_real, tol);
    C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = ::cos(x);
    float expected_real = ::cos(x.real()) * ::cosh(x.imag());
    float expected_imag = - ::sin(x.real()) * ::sinh(x.imag());
    float expected_imag = -::sin(x.real()) * ::sinh(x.imag());
    C10_ASSERT_NEAR(y.real(), expected_real, tol);
    C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
  }
}

C10_DEFINE_TEST(TestTan, Identity) {
  // tan(x) = sin(x) / cos(x)
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = std::tan(x);
    c10::complex<float> z = std::sin(x) / std::cos(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = ::tan(x);
    c10::complex<float> z = ::sin(x) / ::cos(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = std::tan(x);
    c10::complex<double> z = std::sin(x) / std::cos(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = ::tan(x);
    c10::complex<double> z = ::sin(x) / ::cos(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
}

C10_DEFINE_TEST(TestTanh, Identity) {
  // tanh(x) = sinh(x) / cosh(x)
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = std::tanh(x);
    c10::complex<float> z = std::sinh(x) / std::cosh(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<float> x(0.1, 1.2);
    c10::complex<float> y = ::tanh(x);
    c10::complex<float> z = ::sinh(x) / ::cosh(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = std::tanh(x);
    c10::complex<double> z = std::sinh(x) / std::cosh(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
  {
    c10::complex<double> x(0.1, 1.2);
    c10::complex<double> y = ::tanh(x);
    c10::complex<double> z = ::sinh(x) / ::cosh(x);
    C10_ASSERT_NEAR(y.real(), z.real(), tol);
    C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
  }
}
@ -373,64 +374,64 @@ C10_DEFINE_TEST(TestRevTrigonometric, Rev) {
  // acos(cos(x)) = x
  // atan(tan(x)) = x
  {
    c10::complex<float> x(0.5, 0.6);
    c10::complex<float> s = std::sin(x);
    c10::complex<float> ss = std::asin(s);
    c10::complex<float> c = std::cos(x);
    c10::complex<float> cc = std::acos(c);
    c10::complex<float> t = std::tan(x);
    c10::complex<float> tt = std::atan(t);
    C10_ASSERT_NEAR(x.real(), ss.real(), tol);
    C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
    C10_ASSERT_NEAR(x.real(), cc.real(), tol);
    C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
    C10_ASSERT_NEAR(x.real(), tt.real(), tol);
    C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
  }
  {
    c10::complex<float> x(0.5, 0.6);
    c10::complex<float> s = ::sin(x);
    c10::complex<float> ss = ::asin(s);
    c10::complex<float> c = ::cos(x);
    c10::complex<float> cc = ::acos(c);
    c10::complex<float> t = ::tan(x);
    c10::complex<float> tt = ::atan(t);
    C10_ASSERT_NEAR(x.real(), ss.real(), tol);
    C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
    C10_ASSERT_NEAR(x.real(), cc.real(), tol);
    C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
    C10_ASSERT_NEAR(x.real(), tt.real(), tol);
    C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
  }
  {
    c10::complex<double> x(0.5, 0.6);
    c10::complex<double> s = std::sin(x);
    c10::complex<double> ss = std::asin(s);
    c10::complex<double> c = std::cos(x);
    c10::complex<double> cc = std::acos(c);
    c10::complex<double> t = std::tan(x);
    c10::complex<double> tt = std::atan(t);
    C10_ASSERT_NEAR(x.real(), ss.real(), tol);
    C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
    C10_ASSERT_NEAR(x.real(), cc.real(), tol);
    C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
    C10_ASSERT_NEAR(x.real(), tt.real(), tol);
    C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
  }
  {
    c10::complex<double> x(0.5, 0.6);
    c10::complex<double> s = ::sin(x);
    c10::complex<double> ss = ::asin(s);
    c10::complex<double> c = ::cos(x);
    c10::complex<double> cc = ::acos(c);
    c10::complex<double> t = ::tan(x);
    c10::complex<double> tt = ::atan(t);
    C10_ASSERT_NEAR(x.real(), ss.real(), tol);
    C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
    C10_ASSERT_NEAR(x.real(), cc.real(), tol);
    C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
    C10_ASSERT_NEAR(x.real(), tt.real(), tol);
    C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
  }
}
@ -441,63 +442,63 @@ C10_DEFINE_TEST(TestRevHyperbolic, Rev) {
  // acosh(cosh(x)) = x
  // atanh(tanh(x)) = x
  {
    c10::complex<float> x(0.5, 0.6);
    c10::complex<float> s = std::sinh(x);
    c10::complex<float> ss = std::asinh(s);
    c10::complex<float> c = std::cosh(x);
    c10::complex<float> cc = std::acosh(c);
    c10::complex<float> t = std::tanh(x);
    c10::complex<float> tt = std::atanh(t);
    C10_ASSERT_NEAR(x.real(), ss.real(), tol);
    C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
    C10_ASSERT_NEAR(x.real(), cc.real(), tol);
    C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
    C10_ASSERT_NEAR(x.real(), tt.real(), tol);
    C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
  }
  {
    c10::complex<float> x(0.5, 0.6);
    c10::complex<float> s = ::sinh(x);
    c10::complex<float> ss = ::asinh(s);
    c10::complex<float> c = ::cosh(x);
    c10::complex<float> cc = ::acosh(c);
    c10::complex<float> t = ::tanh(x);
    c10::complex<float> tt = ::atanh(t);
    C10_ASSERT_NEAR(x.real(), ss.real(), tol);
    C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
    C10_ASSERT_NEAR(x.real(), cc.real(), tol);
    C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
    C10_ASSERT_NEAR(x.real(), tt.real(), tol);
    C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
  }
  {
    c10::complex<double> x(0.5, 0.6);
    c10::complex<double> s = std::sinh(x);
    c10::complex<double> ss = std::asinh(s);
    c10::complex<double> c = std::cosh(x);
    c10::complex<double> cc = std::acosh(c);
    c10::complex<double> t = std::tanh(x);
    c10::complex<double> tt = std::atanh(t);
    C10_ASSERT_NEAR(x.real(), ss.real(), tol);
    C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
    C10_ASSERT_NEAR(x.real(), cc.real(), tol);
    C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
    C10_ASSERT_NEAR(x.real(), tt.real(), tol);
    C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
  }
  {
    c10::complex<double> x(0.5, 0.6);
    c10::complex<double> s = ::sinh(x);
    c10::complex<double> ss = ::asinh(s);
    c10::complex<double> c = ::cosh(x);
    c10::complex<double> cc = ::acosh(c);
    c10::complex<double> t = ::tanh(x);
    c10::complex<double> tt = ::atanh(t);
    C10_ASSERT_NEAR(x.real(), ss.real(), tol);
    C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
    C10_ASSERT_NEAR(x.real(), cc.real(), tol);
    C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
    C10_ASSERT_NEAR(x.real(), tt.real(), tol);
    C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
  }
}

@ -1,10 +1,10 @@
#include <type_traits>
#include <tuple>
#include <sstream>
#include <c10/util/complex.h>
#include <c10/macros/Macros.h>
#include <c10/util/complex.h>
#include <c10/util/hash.h>
#include <gtest/gtest.h>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <unordered_map>

#if (defined(__CUDACC__) || defined(__HIPCC__))
@ -34,71 +34,72 @@ MAYBE_GLOBAL void test_pod() {
TEST(TestMemory, ReinterpretCast) {
  {
    std::complex<float> z(1, 2);
    c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
    ASSERT_EQ(zz.real(), float(1));
    ASSERT_EQ(zz.imag(), float(2));
  }

  {
    c10::complex<float> z(3, 4);
    std::complex<float> zz = *reinterpret_cast<std::complex<float>*>(&z);
    ASSERT_EQ(zz.real(), float(3));
    ASSERT_EQ(zz.imag(), float(4));
  }

  {
    std::complex<double> z(1, 2);
    c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z);
    ASSERT_EQ(zz.real(), double(1));
    ASSERT_EQ(zz.imag(), double(2));
  }

  {
    c10::complex<double> z(3, 4);
    std::complex<double> zz = *reinterpret_cast<std::complex<double>*>(&z);
    ASSERT_EQ(zz.real(), double(3));
    ASSERT_EQ(zz.imag(), double(4));
  }
}

#if defined(__CUDACC__) || defined(__HIPCC__)
TEST(TestMemory, ThrustReinterpretCast) {
  {
    thrust::complex<float> z(1, 2);
    c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
    ASSERT_EQ(zz.real(), float(1));
    ASSERT_EQ(zz.imag(), float(2));
  }

  {
    c10::complex<float> z(3, 4);
    thrust::complex<float> zz = *reinterpret_cast<thrust::complex<float>*>(&z);
    ASSERT_EQ(zz.real(), float(3));
    ASSERT_EQ(zz.imag(), float(4));
  }

  {
    thrust::complex<double> z(1, 2);
    c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z);
    ASSERT_EQ(zz.real(), double(1));
    ASSERT_EQ(zz.imag(), double(2));
  }

  {
    c10::complex<double> z(3, 4);
    thrust::complex<double> zz = *reinterpret_cast<thrust::complex<double>*>(&z);
    thrust::complex<double> zz =
        *reinterpret_cast<thrust::complex<double>*>(&z);
    ASSERT_EQ(zz.real(), double(3));
    ASSERT_EQ(zz.imag(), double(4));
  }
}
#endif

} // memory
} // namespace memory

namespace constructors {

template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_construct_from_scalar() {
  constexpr scalar_t num1 = scalar_t(1.23);
  constexpr scalar_t num2 = scalar_t(4.56);
@ -111,29 +112,48 @@ C10_HOST_DEVICE void test_construct_from_scalar() {
  static_assert(c10::complex<scalar_t>().imag() == zero, "");
}

template<typename scalar_t, typename other_t>
template <typename scalar_t, typename other_t>
C10_HOST_DEVICE void test_construct_from_other() {
  constexpr other_t num1 = other_t(1.23);
  constexpr other_t num2 = other_t(4.56);
  constexpr scalar_t num3 = scalar_t(num1);
  constexpr scalar_t num4 = scalar_t(num2);
  static_assert(c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).real() == num3, "");
  static_assert(c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).imag() == num4, "");
  static_assert(
      c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).real() == num3,
      "");
  static_assert(
      c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).imag() == num4,
      "");
}

MAYBE_GLOBAL void test_convert_constructors() {
  test_construct_from_scalar<float>();
  test_construct_from_scalar<double>();

  static_assert(std::is_convertible<c10::complex<float>, c10::complex<float>>::value, "");
  static_assert(!std::is_convertible<c10::complex<double>, c10::complex<float>>::value, "");
  static_assert(std::is_convertible<c10::complex<float>, c10::complex<double>>::value, "");
  static_assert(std::is_convertible<c10::complex<double>, c10::complex<double>>::value, "");
  static_assert(
      std::is_convertible<c10::complex<float>, c10::complex<float>>::value, "");
  static_assert(
      !std::is_convertible<c10::complex<double>, c10::complex<float>>::value,
      "");
  static_assert(
      std::is_convertible<c10::complex<float>, c10::complex<double>>::value,
      "");
  static_assert(
      std::is_convertible<c10::complex<double>, c10::complex<double>>::value,
      "");

  static_assert(std::is_constructible<c10::complex<float>, c10::complex<float>>::value, "");
  static_assert(std::is_constructible<c10::complex<double>, c10::complex<float>>::value, "");
  static_assert(std::is_constructible<c10::complex<float>, c10::complex<double>>::value, "");
  static_assert(std::is_constructible<c10::complex<double>, c10::complex<double>>::value, "");
  static_assert(
      std::is_constructible<c10::complex<float>, c10::complex<float>>::value,
      "");
  static_assert(
      std::is_constructible<c10::complex<double>, c10::complex<float>>::value,
      "");
  static_assert(
      std::is_constructible<c10::complex<float>, c10::complex<double>>::value,
      "");
  static_assert(
      std::is_constructible<c10::complex<double>, c10::complex<double>>::value,
      "");

  test_construct_from_other<float, float>();
  test_construct_from_other<float, double>();
@ -141,12 +161,16 @@ MAYBE_GLOBAL void test_convert_constructors() {
  test_construct_from_other<double, double>();
}

template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_construct_from_std() {
  constexpr scalar_t num1 = scalar_t(1.23);
  constexpr scalar_t num2 = scalar_t(4.56);
  static_assert(c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).real() == num1, "");
  static_assert(c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).imag() == num2, "");
  static_assert(
      c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).real() == num1,
      "");
  static_assert(
      c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).imag() == num2,
      "");
}

MAYBE_GLOBAL void test_std_conversion() {
@ -155,12 +179,16 @@ MAYBE_GLOBAL void test_std_conversion() {
}

#if defined(__CUDACC__) || defined(__HIPCC__)
template<typename scalar_t>
template <typename scalar_t>
void test_construct_from_thrust() {
  constexpr scalar_t num1 = scalar_t(1.23);
  constexpr scalar_t num2 = scalar_t(4.56);
  ASSERT_EQ(c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).real(), num1);
  ASSERT_EQ(c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).imag(), num2);
  ASSERT_EQ(
      c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).real(),
      num1);
  ASSERT_EQ(
      c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).imag(),
      num2);
}

TEST(TestConstructors, FromThrust) {
@ -170,7 +198,11 @@ TEST(TestConstructors, FromThrust) {
#endif

TEST(TestConstructors, UnorderedMap) {
  std::unordered_map<c10::complex<double>, c10::complex<double>, c10::hash<c10::complex<double>>> m;
  std::unordered_map<
      c10::complex<double>,
      c10::complex<double>,
      c10::hash<c10::complex<double>>>
      m;
  auto key1 = c10::complex<double>(2.5, 3);
  auto key2 = c10::complex<double>(2, 0);
  auto val1 = c10::complex<double>(2, -3.2);
@ -181,11 +213,11 @@ TEST(TestConstructors, UnorderedMap) {
  ASSERT_EQ(m[key2], val2);
}

} // constructors
} // namespace constructors

namespace assignment {

template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> one() {
  c10::complex<scalar_t> result(3, 4);
  result = scalar_t(1);
@ -232,7 +264,8 @@ MAYBE_GLOBAL void test_assign_std() {
}

#if defined(__CUDACC__) || defined(__HIPCC__)
C10_HOST_DEVICE std::tuple<c10::complex<double>, c10::complex<float>> one_two_thrust() {
C10_HOST_DEVICE std::tuple<c10::complex<double>, c10::complex<float>>
one_two_thrust() {
  thrust::complex<float> src(1, 2);
  c10::complex<double> ret0;
  c10::complex<float> ret1;
@ -258,7 +291,8 @@ MAYBE_GLOBAL void test_complex_literals() {
  static_assert(std::is_same<decltype(0.5_if), c10::complex<float>>::value, "");
  static_assert((0.5_if).real() == float(), "");
  static_assert((0.5_if).imag() == float(0.5), "");
  static_assert(std::is_same<decltype(0.5_id), c10::complex<double>>::value, "");
  static_assert(
      std::is_same<decltype(0.5_id), c10::complex<double>>::value, "");
  static_assert((0.5_id).real() == float(), "");
  static_assert((0.5_id).imag() == float(0.5), "");

@ -274,14 +308,14 @@ MAYBE_GLOBAL void test_complex_literals() {

namespace real_imag {

template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> zero_one() {
  c10::complex<scalar_t> result;
  result.imag(scalar_t(1));
  return result;
}

template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> one_zero() {
  c10::complex<scalar_t> result;
  result.real(scalar_t(1));
@ -304,35 +338,35 @@ MAYBE_GLOBAL void test_real_imag_modify() {

namespace arithmetic_assign {

template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> p(scalar_t value) {
  c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
  result += value;
  return result;
}

template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> m(scalar_t value) {
  c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
  result -= value;
  return result;
}

template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> t(scalar_t value) {
  c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
  result *= value;
  return result;
}

template<typename scalar_t>
template <typename scalar_t>
constexpr c10::complex<scalar_t> d(scalar_t value) {
  c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
  result /= value;
  return result;
}

template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_arithmetic_assign_scalar() {
  constexpr c10::complex<scalar_t> x = p(scalar_t(1));
  static_assert(x.real() == scalar_t(3), "");
@ -348,35 +382,47 @@ C10_HOST_DEVICE void test_arithmetic_assign_scalar() {
  static_assert(t.imag() == scalar_t(1), "");
}

template<typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> p(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
template <typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> p(
    scalar_t real,
    scalar_t imag,
    c10::complex<rhs_t> rhs) {
  c10::complex<scalar_t> result(real, imag);
  result += rhs;
  return result;
}

template<typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> m(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
template <typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> m(
    scalar_t real,
    scalar_t imag,
    c10::complex<rhs_t> rhs) {
  c10::complex<scalar_t> result(real, imag);
  result -= rhs;
  return result;
}

template<typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> t(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
template <typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> t(
    scalar_t real,
    scalar_t imag,
    c10::complex<rhs_t> rhs) {
  c10::complex<scalar_t> result(real, imag);
  result *= rhs;
  return result;
}

template<typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> d(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
template <typename scalar_t, typename rhs_t>
constexpr c10::complex<scalar_t> d(
    scalar_t real,
    scalar_t imag,
    c10::complex<rhs_t> rhs) {
  c10::complex<scalar_t> result(real, imag);
  result /= rhs;
  return result;
}

template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_arithmetic_assign_complex() {
  using namespace c10::complex_literals;
  constexpr c10::complex<scalar_t> x2 = p(scalar_t(2), scalar_t(2), 1.0_if);
@ -429,26 +475,64 @@ MAYBE_GLOBAL void test_arithmetic_assign() {

namespace arithmetic {

template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_arithmetic_() {
  static_assert(c10::complex<scalar_t>(1, 2) == +c10::complex<scalar_t>(1, 2), "");
  static_assert(c10::complex<scalar_t>(-1, -2) == -c10::complex<scalar_t>(1, 2), "");
  static_assert(
      c10::complex<scalar_t>(1, 2) == +c10::complex<scalar_t>(1, 2), "");
  static_assert(
      c10::complex<scalar_t>(-1, -2) == -c10::complex<scalar_t>(1, 2), "");

  static_assert(c10::complex<scalar_t>(1, 2) + c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(4, 6), "");
  static_assert(c10::complex<scalar_t>(1, 2) + scalar_t(3) == c10::complex<scalar_t>(4, 2), "");
  static_assert(scalar_t(3) + c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(4, 2), "");
  static_assert(
      c10::complex<scalar_t>(1, 2) + c10::complex<scalar_t>(3, 4) ==
          c10::complex<scalar_t>(4, 6),
      "");
  static_assert(
      c10::complex<scalar_t>(1, 2) + scalar_t(3) ==
          c10::complex<scalar_t>(4, 2),
      "");
  static_assert(
      scalar_t(3) + c10::complex<scalar_t>(1, 2) ==
          c10::complex<scalar_t>(4, 2),
      "");

  static_assert(c10::complex<scalar_t>(1, 2) - c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(-2, -2), "");
  static_assert(c10::complex<scalar_t>(1, 2) - scalar_t(3) == c10::complex<scalar_t>(-2, 2), "");
  static_assert(scalar_t(3) - c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(2, -2), "");
  static_assert(
      c10::complex<scalar_t>(1, 2) - c10::complex<scalar_t>(3, 4) ==
          c10::complex<scalar_t>(-2, -2),
      "");
  static_assert(
      c10::complex<scalar_t>(1, 2) - scalar_t(3) ==
          c10::complex<scalar_t>(-2, 2),
      "");
  static_assert(
      scalar_t(3) - c10::complex<scalar_t>(1, 2) ==
          c10::complex<scalar_t>(2, -2),
      "");

  static_assert(c10::complex<scalar_t>(1, 2) * c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(-5, 10), "");
  static_assert(c10::complex<scalar_t>(1, 2) * scalar_t(3) == c10::complex<scalar_t>(3, 6), "");
  static_assert(scalar_t(3) * c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(3, 6), "");
  static_assert(
      c10::complex<scalar_t>(1, 2) * c10::complex<scalar_t>(3, 4) ==
          c10::complex<scalar_t>(-5, 10),
      "");
  static_assert(
      c10::complex<scalar_t>(1, 2) * scalar_t(3) ==
          c10::complex<scalar_t>(3, 6),
      "");
  static_assert(
      scalar_t(3) * c10::complex<scalar_t>(1, 2) ==
          c10::complex<scalar_t>(3, 6),
      "");

  static_assert(c10::complex<scalar_t>(-5, 10) / c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(1, 2), "");
  static_assert(c10::complex<scalar_t>(5, 10) / scalar_t(5) == c10::complex<scalar_t>(1, 2), "");
  static_assert(scalar_t(25) / c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(3, -4), "");
  static_assert(
      c10::complex<scalar_t>(-5, 10) / c10::complex<scalar_t>(3, 4) ==
          c10::complex<scalar_t>(1, 2),
      "");
  static_assert(
      c10::complex<scalar_t>(5, 10) / scalar_t(5) ==
          c10::complex<scalar_t>(1, 2),
      "");
  static_assert(
      scalar_t(25) / c10::complex<scalar_t>(3, 4) ==
          c10::complex<scalar_t>(3, -4),
      "");
}

MAYBE_GLOBAL void test_arithmetic() {
@ -456,7 +540,7 @@ MAYBE_GLOBAL void test_arithmetic() {
  test_arithmetic_<double>();
}

template<typename T, typename int_t>
template <typename T, typename int_t>
void test_binary_ops_for_int_type_(T real, T img, int_t num) {
  c10::complex<T> c(real, img);
  ASSERT_EQ(c + num, c10::complex<T>(real + num, img));
@ -466,10 +550,12 @@ void test_binary_ops_for_int_type_(T real, T img, int_t num) {
  ASSERT_EQ(c * num, c10::complex<T>(real * num, img * num));
  ASSERT_EQ(num * c, c10::complex<T>(num * real, num * img));
  ASSERT_EQ(c / num, c10::complex<T>(real / num, img / num));
  ASSERT_EQ(num / c, c10::complex<T>(num * real / std::norm(c), -num * img / std::norm(c)));
  ASSERT_EQ(
      num / c,
      c10::complex<T>(num * real / std::norm(c), -num * img / std::norm(c)));
}

template<typename T>
template <typename T>
void test_binary_ops_for_all_int_types_(T real, T img, int8_t i) {
  test_binary_ops_for_int_type_<T, int8_t>(real, img, i);
  test_binary_ops_for_int_type_<T, int16_t>(real, img, i);
@ -486,12 +572,14 @@ TEST(TestArithmeticIntScalar, All) {

namespace equality {

template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_equality_() {
  static_assert(c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(1, 2), "");
  static_assert(
      c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(1, 2), "");
  static_assert(c10::complex<scalar_t>(1, 0) == scalar_t(1), "");
  static_assert(scalar_t(1) == c10::complex<scalar_t>(1, 0), "");
  static_assert(c10::complex<scalar_t>(1, 2) != c10::complex<scalar_t>(3, 4), "");
  static_assert(
      c10::complex<scalar_t>(1, 2) != c10::complex<scalar_t>(3, 4), "");
  static_assert(c10::complex<scalar_t>(1, 2) != scalar_t(1), "");
  static_assert(scalar_t(1) != c10::complex<scalar_t>(1, 2), "");
}
@ -505,7 +593,7 @@ MAYBE_GLOBAL void test_equality() {

namespace io {

template<typename scalar_t>
template <typename scalar_t>
void test_io_() {
  std::stringstream ss;
  c10::complex<scalar_t> a(1, 2);
@ -525,14 +613,16 @@ TEST(TestIO, All) {

namespace test_std {

template<typename scalar_t>
template <typename scalar_t>
C10_HOST_DEVICE void test_callable_() {
  static_assert(std::real(c10::complex<scalar_t>(1, 2)) == scalar_t(1), "");
  static_assert(std::imag(c10::complex<scalar_t>(1, 2)) == scalar_t(2), "");
  std::abs(c10::complex<scalar_t>(1, 2));
  std::arg(c10::complex<scalar_t>(1, 2));
  static_assert(std::norm(c10::complex<scalar_t>(3, 4)) == scalar_t(25), "");
  static_assert(std::conj(c10::complex<scalar_t>(3, 4)) == c10::complex<scalar_t>(3, -4), "");
  static_assert(
      std::conj(c10::complex<scalar_t>(3, 4)) == c10::complex<scalar_t>(3, -4),
      "");
  c10::polar(float(1), float(PI / 2));
  c10::polar(double(1), double(PI / 2));
}
@ -542,11 +632,15 @@ MAYBE_GLOBAL void test_callable() {
  test_callable_<double>();
}

template<typename scalar_t>
template <typename scalar_t>
void test_values_() {
  ASSERT_EQ(std::abs(c10::complex<scalar_t>(3, 4)), scalar_t(5));
  ASSERT_LT(std::abs(std::arg(c10::complex<scalar_t>(0, 1)) - PI / 2), 1e-6);
  ASSERT_LT(std::abs(c10::polar(scalar_t(1), scalar_t(PI / 2)) - c10::complex<scalar_t>(0, 1)), 1e-6);
  ASSERT_LT(
      std::abs(
          c10::polar(scalar_t(1), scalar_t(PI / 2)) -
          c10::complex<scalar_t>(0, 1)),
      1e-6);
}

TEST(TestStd, BasicFunctions) {
@ -554,8 +648,11 @@ TEST(TestStd, BasicFunctions) {
  test_values_<double>();
  // CSQRT edge cases: checks for overflows which are likely to occur
  // if square root is computed using polar form
  ASSERT_LT(std::abs(std::sqrt(c10::complex<float>(-1e20, -4988429.2)).real()), 3e-4);
  ASSERT_LT(std::abs(std::sqrt(c10::complex<double>(-1e60, -4988429.2)).real()), 3e-4);
  ASSERT_LT(
      std::abs(std::sqrt(c10::complex<float>(-1e20, -4988429.2)).real()), 3e-4);
  ASSERT_LT(
      std::abs(std::sqrt(c10::complex<double>(-1e60, -4988429.2)).real()),
      3e-4);
}

} // namespace test_std

File diff suppressed because it is too large
@ -9,7 +9,7 @@ bool throw_func() {
  throw std::runtime_error("I'm throwing...");
}

template<class Functor>
template <class Functor>
inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
  try {
    std::forward<Functor>(functor)();
@ -18,7 +18,7 @@ inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
    return;
  }
  ADD_FAILURE() << "Expected to throw exception with message \""
                << expectedMessage << "\" but didn't throw";
}
} // namespace

@ -43,30 +43,32 @@ TEST(WarningTest, JustPrintWarning) {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(ExceptionTest, ErrorFormatting) {
  expectThrowsEq([]() {
    TORCH_CHECK(false, "This is invalid");
  }, "This is invalid");
  expectThrowsEq(
      []() { TORCH_CHECK(false, "This is invalid"); }, "This is invalid");

  expectThrowsEq([]() {
    try {
      TORCH_CHECK(false, "This is invalid");
    } catch (Error& e) {
      TORCH_RETHROW(e, "While checking X");
    }
  }, "This is invalid (While checking X)");
  expectThrowsEq(
      []() {
        try {
          TORCH_CHECK(false, "This is invalid");
        } catch (Error& e) {
          TORCH_RETHROW(e, "While checking X");
        }
      },
      "This is invalid (While checking X)");

  expectThrowsEq([]() {
    try {
      try {
        TORCH_CHECK(false, "This is invalid");
      } catch (Error& e) {
        TORCH_RETHROW(e, "While checking X");
      }
    } catch (Error& e) {
      TORCH_RETHROW(e, "While checking Y");
    }
  },
  R"msg(This is invalid
  expectThrowsEq(
      []() {
        try {
          try {
            TORCH_CHECK(false, "This is invalid");
          } catch (Error& e) {
            TORCH_RETHROW(e, "While checking X");
          }
        } catch (Error& e) {
          TORCH_RETHROW(e, "While checking Y");
        }
      },
      R"msg(This is invalid
While checking X
While checking Y)msg");
}
@ -94,5 +96,6 @@ TEST(ExceptionTest, DontCallArgumentFunctionsTwiceOnFailure) {

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_ANY_THROW(failInternalAssert());
  EXPECT_EQ(assertionArgumentCounter, 2) << "TORCH_INTERNAL_ASSERT called argument twice";
  EXPECT_EQ(assertionArgumentCounter, 2)
      << "TORCH_INTERNAL_ASSERT called argument twice";
}

@ -67,7 +67,8 @@ class ChildDestructableMock final : public DestructableMock {
class NullType1 final {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
  static SomeClass singleton_;
 public:

 public:
  static constexpr SomeClass* singleton() {
    return &singleton_;
  }
@ -77,7 +78,8 @@ SomeClass NullType1::singleton_;
class NullType2 final {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
  static SomeClass singleton_;
 public:

 public:
  static constexpr SomeClass* singleton() {
    return &singleton_;
  }
@ -934,7 +936,8 @@ TEST(IntrusivePtrTest, givenPtr_whenDestructed_thenDestructsObject) {
  bool resourcesReleased = false;
  bool wasDestructed = false;
  {
    auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
    auto obj =
        make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
    EXPECT_FALSE(resourcesReleased);
    EXPECT_FALSE(wasDestructed);
  }
@ -948,7 +951,8 @@ TEST(
    givenPtr_whenMoveConstructed_thenDestructsObjectAfterSecondDestructed) {
  bool resourcesReleased = false;
  bool wasDestructed = false;
  auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
  auto obj =
      make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
  {
    auto obj2 = std::move(obj);
    EXPECT_FALSE(resourcesReleased);
@ -964,7 +968,8 @@ TEST(
    givenPtr_whenMoveConstructedToBaseClass_thenDestructsObjectAfterSecondDestructed) {
  bool resourcesReleased = false;
  bool wasDestructed = false;
  auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
  auto obj =
      make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
  {
    intrusive_ptr<DestructableMock> obj2 = std::move(obj);
    EXPECT_FALSE(resourcesReleased);
@ -981,7 +986,8 @@ TEST(IntrusivePtrTest, givenPtr_whenMoveAssigned_thenDestructsOldObject) {
  bool wasDestructed = false;
  auto obj = make_intrusive<DestructableMock>(&dummy, &dummy);
  {
    auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
    auto obj2 =
        make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
    EXPECT_FALSE(resourcesReleased);
    EXPECT_FALSE(wasDestructed);
    obj2 = std::move(obj);
@ -999,7 +1005,8 @@ TEST(
  bool wasDestructed = false;
  auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
  {
    auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
    auto obj2 =
        make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
    EXPECT_FALSE(resourcesReleased);
    EXPECT_FALSE(wasDestructed);
    obj2 = std::move(obj);
@ -1017,7 +1024,8 @@ TEST(
  bool wasDestructed = false;
  auto obj = make_intrusive<DestructableMock>(&dummy, &dummy);
  {
    auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
    auto obj2 =
        make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
    {
      auto copy = obj2;
      EXPECT_FALSE(resourcesReleased);
@ -1040,8 +1048,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 =
|
||||
make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 = make_intrusive<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
intrusive_ptr<DestructableMock> copy = obj2;
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
|
|
@ -1064,7 +1072,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
intrusive_ptr<DestructableMock> copy = obj2;
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
|
|
@ -1085,7 +1094,8 @@ TEST(
|
|||
bool dummy = false;
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto obj2 = make_intrusive<DestructableMock>(&dummy, &dummy);
|
||||
obj2 = std::move(obj);
|
||||
|
|
@ -1103,7 +1113,8 @@ TEST(
|
|||
bool dummy = false;
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto obj2 = make_intrusive<DestructableMock>(&dummy, &dummy);
|
||||
obj2 = std::move(obj);
|
||||
|
|
@ -1121,7 +1132,8 @@ TEST(
|
|||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
|
||||
intrusive_ptr<DestructableMock> copy = obj;
|
||||
|
|
@ -1142,7 +1154,8 @@ TEST(
|
|||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj = make_intrusive<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
intrusive_ptr<DestructableMock> copy = obj;
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
|
|
@ -1162,7 +1175,8 @@ TEST(
|
|||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
intrusive_ptr<DestructableMock> copy = obj;
|
||||
obj.reset();
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
|
|
@ -1179,7 +1193,8 @@ TEST(
|
|||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj = make_intrusive<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
intrusive_ptr<DestructableMock> copy = obj;
|
||||
obj.reset();
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
|
|
@ -1197,7 +1212,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
bool dummy = false;
|
||||
{
|
||||
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
intrusive_ptr<DestructableMock> copy =
|
||||
make_intrusive<DestructableMock>(&dummy, &dummy);
|
||||
|
|
@ -1220,7 +1236,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
bool dummy = false;
|
||||
{
|
||||
auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj = make_intrusive<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
intrusive_ptr<DestructableMock> copy =
|
||||
make_intrusive<DestructableMock>(&dummy, &dummy);
|
||||
|
|
@ -1245,7 +1262,8 @@ TEST(
|
|||
{
|
||||
auto copy = make_intrusive<DestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
copy = obj;
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
|
|
@ -1267,8 +1285,8 @@ TEST(
|
|||
{
|
||||
auto copy = make_intrusive<DestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj =
|
||||
make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj = make_intrusive<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
copy = obj;
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
|
|
@ -1287,7 +1305,8 @@ TEST(IntrusivePtrTest, givenPtr_whenCopyAssigned_thenDestructsOldObject) {
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<DestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
obj2 = obj;
|
||||
|
|
@ -1305,7 +1324,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
obj2 = obj;
|
||||
|
|
@ -1323,7 +1343,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<DestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto copy = obj2;
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
|
|
@ -1346,8 +1367,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 =
|
||||
make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 = make_intrusive<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
intrusive_ptr<DestructableMock> copy = obj2;
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
|
|
@ -1370,7 +1391,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
intrusive_ptr<DestructableMock> copy = obj2;
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
|
|
@ -1388,7 +1410,8 @@ TEST(
|
|||
TEST(IntrusivePtrTest, givenPtr_whenCallingReset_thenDestructs) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
obj.reset();
|
||||
|
|
@ -1402,7 +1425,8 @@ TEST(
|
|||
givenPtrWithCopy_whenCallingReset_thenDestructsAfterCopyDestructed) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto copy = obj;
|
||||
obj.reset();
|
||||
|
|
@ -1420,7 +1444,8 @@ TEST(
|
|||
givenPtrWithCopy_whenCallingResetOnCopy_thenDestructsAfterOriginalDestructed) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto copy = obj;
|
||||
copy.reset();
|
||||
|
|
@ -1438,7 +1463,8 @@ TEST(
|
|||
givenPtrWithMoved_whenCallingReset_thenDestructsAfterMovedDestructed) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto moved = std::move(obj);
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move)
|
||||
|
|
@ -1457,7 +1483,8 @@ TEST(
|
|||
givenPtrWithMoved_whenCallingResetOnMoved_thenDestructsImmediately) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto moved = std::move(obj);
|
||||
moved.reset();
|
||||
|
|
@ -1800,9 +1827,7 @@ TEST(IntrusivePtrTest, givenPtr_whenNonOwningReclaimed_thenDoesntCrash) {
|
|||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(
|
||||
IntrusivePtrTest,
|
||||
givenPtr_whenNonOwningReclaimed_thenIsDestructedAtEnd) {
|
||||
TEST(IntrusivePtrTest, givenPtr_whenNonOwningReclaimed_thenIsDestructedAtEnd) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
|
|
@ -1840,7 +1865,9 @@ weak_intrusive_ptr<T> make_weak_only(Args&&... args) {
|
|||
auto intrusive = make_intrusive<T>(std::forward<Args>(args)...);
|
||||
return weak_intrusive_ptr<T>(intrusive);
|
||||
}
|
||||
template <class T, class NullType = c10::detail::intrusive_target_default_null_type<T>>
|
||||
template <
|
||||
class T,
|
||||
class NullType = c10::detail::intrusive_target_default_null_type<T>>
|
||||
weak_intrusive_ptr<T, NullType> make_invalid_weak() {
|
||||
return weak_intrusive_ptr<T, NullType>(intrusive_ptr<T, NullType>());
|
||||
}
|
||||
|
|
@ -1903,9 +1930,7 @@ TEST(
|
|||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(
|
||||
WeakIntrusivePtrTest,
|
||||
vector_insert_weak_intrusive) {
|
||||
TEST(WeakIntrusivePtrTest, vector_insert_weak_intrusive) {
|
||||
std::vector<weak_intrusive_ptr<SomeClass>> priorWorks;
|
||||
std::vector<intrusive_ptr<SomeClass>> wips;
|
||||
wips.push_back(make_intrusive<SomeClass>());
|
||||
|
|
@ -2139,8 +2164,10 @@ TEST(
|
|||
TEST(
|
||||
WeakIntrusivePtrTest,
|
||||
givenNullPtr_whenMoveAssigningToDifferentNullptr_thenHasNewNullptr) {
|
||||
weak_intrusive_ptr<SomeClass, NullType1> obj1 = make_invalid_weak<SomeClass, NullType1>();
|
||||
weak_intrusive_ptr<SomeClass, NullType2> obj2 = make_invalid_weak<SomeClass, NullType2>();
|
||||
weak_intrusive_ptr<SomeClass, NullType1> obj1 =
|
||||
make_invalid_weak<SomeClass, NullType1>();
|
||||
weak_intrusive_ptr<SomeClass, NullType2> obj2 =
|
||||
make_invalid_weak<SomeClass, NullType2>();
|
||||
obj2 = std::move(obj1);
|
||||
EXPECT_NE(NullType1::singleton(), NullType2::singleton());
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move)
|
||||
|
|
@ -2362,8 +2389,10 @@ TEST(
|
|||
TEST(
|
||||
WeakIntrusivePtrTest,
|
||||
givenNullPtr_whenCopyAssigningToDifferentNullptr_thenHasNewNullptr) {
|
||||
weak_intrusive_ptr<SomeClass, NullType1> obj1 = make_invalid_weak<SomeClass, NullType1>();
|
||||
weak_intrusive_ptr<SomeClass, NullType2> obj2 = make_invalid_weak<SomeClass, NullType2>();
|
||||
weak_intrusive_ptr<SomeClass, NullType1> obj1 =
|
||||
make_invalid_weak<SomeClass, NullType1>();
|
||||
weak_intrusive_ptr<SomeClass, NullType2> obj2 =
|
||||
make_invalid_weak<SomeClass, NullType2>();
|
||||
obj2 = obj1;
|
||||
EXPECT_NE(NullType1::singleton(), NullType2::singleton());
|
||||
EXPECT_TRUE(obj1.expired());
|
||||
|
|
@ -2468,7 +2497,8 @@ TEST(
|
|||
TEST(
|
||||
WeakIntrusivePtrTest,
|
||||
givenNullPtr_whenMoveConstructingToDifferentNullptr_thenHasNewNullptr) {
|
||||
weak_intrusive_ptr<SomeClass, NullType1> obj1 = make_invalid_weak<SomeClass, NullType1>();
|
||||
weak_intrusive_ptr<SomeClass, NullType1> obj1 =
|
||||
make_invalid_weak<SomeClass, NullType1>();
|
||||
weak_intrusive_ptr<SomeClass, NullType2> obj2 = std::move(obj1);
|
||||
EXPECT_NE(NullType1::singleton(), NullType2::singleton());
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move)
|
||||
|
|
@ -2575,7 +2605,8 @@ TEST(
|
|||
TEST(
|
||||
WeakIntrusivePtrTest,
|
||||
givenNullPtr_whenCopyConstructingToDifferentNullptr_thenHasNewNullptr) {
|
||||
weak_intrusive_ptr<SomeClass, NullType1> obj1 = make_invalid_weak<SomeClass, NullType1>();
|
||||
weak_intrusive_ptr<SomeClass, NullType1> obj1 =
|
||||
make_invalid_weak<SomeClass, NullType1>();
|
||||
weak_intrusive_ptr<SomeClass, NullType2> obj2 = obj1;
|
||||
EXPECT_NE(NullType1::singleton(), NullType2::singleton());
|
||||
EXPECT_TRUE(obj1.expired());
|
||||
|
|
@ -3175,7 +3206,8 @@ TEST(
|
|||
givenPtr_whenLastStrongPointerResets_thenReleasesResources) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
obj.ptr.reset();
|
||||
|
|
@ -3192,7 +3224,8 @@ TEST(
|
|||
givenPtr_whenDestructedButStillHasStrongPointers_thenDoesntReleaseResources) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_FALSE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
obj.weak.reset();
|
||||
|
|
@ -3208,7 +3241,8 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenDestructed_thenDestructsObject) {
|
|||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
}
|
||||
|
|
@ -3222,7 +3256,8 @@ TEST(
|
|||
givenPtr_whenMoveConstructed_thenDestructsObjectAfterSecondDestructed) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto obj2 = std::move(obj);
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3238,7 +3273,8 @@ TEST(
|
|||
givenPtr_whenMoveConstructedToBaseClass_thenDestructsObjectAfterSecondDestructed) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
weak_intrusive_ptr<DestructableMock> obj2 = std::move(obj);
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3255,7 +3291,8 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenMoveAssigned_thenDestructsOldObject) {
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
obj2 = std::move(obj);
|
||||
|
|
@ -3273,7 +3310,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
obj2 = std::move(obj);
|
||||
|
|
@ -3291,7 +3329,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto copy = obj2;
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3314,8 +3353,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 =
|
||||
make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 = make_weak_only<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
weak_intrusive_ptr<DestructableMock> copy = obj2;
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3338,7 +3377,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
weak_intrusive_ptr<DestructableMock> copy = obj2;
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3359,7 +3399,8 @@ TEST(
|
|||
bool dummy = false;
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto obj2 = make_weak_only<DestructableMock>(&dummy, &dummy);
|
||||
obj2 = std::move(obj);
|
||||
|
|
@ -3377,7 +3418,8 @@ TEST(
|
|||
bool dummy = false;
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto obj2 = make_weak_only<DestructableMock>(&dummy, &dummy);
|
||||
obj2 = std::move(obj);
|
||||
|
|
@ -3395,7 +3437,8 @@ TEST(
|
|||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
|
||||
weak_intrusive_ptr<DestructableMock> copy = obj;
|
||||
|
|
@ -3416,7 +3459,8 @@ TEST(
|
|||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj = make_weak_only<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
weak_intrusive_ptr<DestructableMock> copy = obj;
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3436,7 +3480,8 @@ TEST(
|
|||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
weak_intrusive_ptr<DestructableMock> copy = obj;
|
||||
obj.reset();
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3453,7 +3498,8 @@ TEST(
|
|||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj = make_weak_only<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
weak_intrusive_ptr<DestructableMock> copy = obj;
|
||||
obj.reset();
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3471,7 +3517,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
bool dummy = false;
|
||||
{
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
weak_intrusive_ptr<DestructableMock> copy =
|
||||
make_weak_only<DestructableMock>(&dummy, &dummy);
|
||||
|
|
@ -3494,7 +3541,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
bool dummy = false;
|
||||
{
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj = make_weak_only<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
weak_intrusive_ptr<DestructableMock> copy =
|
||||
make_weak_only<DestructableMock>(&dummy, &dummy);
|
||||
|
|
@ -3519,7 +3567,8 @@ TEST(
|
|||
{
|
||||
auto copy = make_weak_only<DestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
copy = obj;
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
|
|
@ -3541,8 +3590,8 @@ TEST(
|
|||
{
|
||||
auto copy = make_weak_only<DestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj =
|
||||
make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj = make_weak_only<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
copy = obj;
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
|
|
@ -3561,7 +3610,8 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenCopyAssigned_thenDestructsOldObject) {
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
obj2 = obj;
|
||||
|
|
@ -3579,7 +3629,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
obj2 = obj;
|
||||
|
|
@ -3597,7 +3648,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto copy = obj2;
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3620,8 +3672,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 =
|
||||
make_weak_only<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 = make_weak_only<ChildDestructableMock>(
|
||||
&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
weak_intrusive_ptr<DestructableMock> copy = obj2;
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3644,7 +3696,8 @@ TEST(
|
|||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<ChildDestructableMock>(&dummy, &dummy);
|
||||
{
|
||||
auto obj2 = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj2 =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
weak_intrusive_ptr<DestructableMock> copy = obj2;
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
|
|
@ -3662,7 +3715,8 @@ TEST(
|
|||
TEST(WeakIntrusivePtrTest, givenPtr_whenCallingReset_thenDestructs) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
EXPECT_TRUE(resourcesReleased);
|
||||
EXPECT_FALSE(wasDestructed);
|
||||
obj.reset();
|
||||
|
|
@ -3676,7 +3730,8 @@ TEST(
|
|||
givenPtrWithCopy_whenCallingReset_thenDestructsAfterCopyDestructed) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto copy = obj;
|
||||
obj.reset();
|
||||
|
|
@ -3694,7 +3749,8 @@ TEST(
|
|||
givenPtrWithCopy_whenCallingResetOnCopy_thenDestructsAfterOriginalDestructed) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto copy = obj;
|
||||
copy.reset();
|
||||
|
|
@ -3712,7 +3768,8 @@ TEST(
|
|||
givenPtrWithMoved_whenCallingReset_thenDestructsAfterMovedDestructed) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto moved = std::move(obj);
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move)
|
||||
|
|
@ -3731,7 +3788,8 @@ TEST(
|
|||
givenPtrWithMoved_whenCallingResetOnMoved_thenDestructsImmediately) {
|
||||
bool resourcesReleased = false;
|
||||
bool wasDestructed = false;
|
||||
auto obj = make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
auto obj =
|
||||
make_weak_only<DestructableMock>(&resourcesReleased, &wasDestructed);
|
||||
{
|
||||
auto moved = std::move(obj);
|
||||
moved.reset();
|
||||
|
|
|
|||
|
|
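The intrusive_ptr tests above all follow one pattern: construct with make_intrusive, copy/move/reset, and use the two booleans to assert when resources are released and when the object is destructed. For readers skimming the diff, here is a minimal usage sketch of the c10 API these tests exercise (only calls that appear in this diff are used; the Payload type is a hypothetical stand-in for DestructableMock):

#include <c10/util/intrusive_ptr.h>

struct Payload : c10::intrusive_ptr_target {
  explicit Payload(int v) : value(v) {}
  int value;
};

void intrusive_ptr_sketch() {
  // The refcount lives inside the object; make_intrusive starts it at 1.
  c10::intrusive_ptr<Payload> a = c10::make_intrusive<Payload>(42);
  c10::intrusive_ptr<Payload> b = a; // copy: refcount 2
  a.reset();                         // refcount 1, Payload still alive
  c10::weak_intrusive_ptr<Payload> w(b); // weak ref, does not keep alive
  b.reset();                         // last strong ref gone: destructed here
  // w is now expired, the same property the expired() checks above assert.
}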
@ -8,58 +8,58 @@ using namespace ::testing;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(irange_test, range_test) {
  std::vector<int> test_vec;
  for(const auto i : c10::irange(4, 11)){
    test_vec.push_back(i);
  }
  const std::vector<int> correct = {{4,5,6,7,8,9,10}};
  ASSERT_EQ(test_vec, correct);
  std::vector<int> test_vec;
  for (const auto i : c10::irange(4, 11)) {
    test_vec.push_back(i);
  }
  const std::vector<int> correct = {{4, 5, 6, 7, 8, 9, 10}};
  ASSERT_EQ(test_vec, correct);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(irange_test, end_test) {
  std::vector<int> test_vec;
  for(const auto i : c10::irange(5)){
    test_vec.push_back(i);
  }
  const std::vector<int> correct = {{0, 1, 2, 3, 4}};
  ASSERT_EQ(test_vec, correct);
  std::vector<int> test_vec;
  for (const auto i : c10::irange(5)) {
    test_vec.push_back(i);
  }
  const std::vector<int> correct = {{0, 1, 2, 3, 4}};
  ASSERT_EQ(test_vec, correct);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(irange_test, neg_range_test) {
  std::vector<int> test_vec;
  for(const auto i : c10::irange(-2, 3)){
    test_vec.push_back(i);
  }
  const std::vector<int> correct = {{-2,-1,0,1,2}};
  ASSERT_EQ(test_vec, correct);
  std::vector<int> test_vec;
  for (const auto i : c10::irange(-2, 3)) {
    test_vec.push_back(i);
  }
  const std::vector<int> correct = {{-2, -1, 0, 1, 2}};
  ASSERT_EQ(test_vec, correct);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(irange, empty_reverse_range_two_inputs){
  std::vector<int> test_vec;
  for(const auto i : c10::irange(3, -3)){
    test_vec.push_back(i);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    if(i>20){ //Cap the number of elements we add if something goes wrong
      break;
    }
TEST(irange, empty_reverse_range_two_inputs) {
  std::vector<int> test_vec;
  for (const auto i : c10::irange(3, -3)) {
    test_vec.push_back(i);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    if (i > 20) { // Cap the number of elements we add if something goes wrong
      break;
    }
  const std::vector<int> correct = {};
  ASSERT_EQ(test_vec, correct);
  }
  const std::vector<int> correct = {};
  ASSERT_EQ(test_vec, correct);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(irange, empty_reverse_range_one_input){
  std::vector<int> test_vec;
  for(const auto i : c10::irange(-3)){
    test_vec.push_back(i);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    if(i>20){ //Cap the number of elements we add if something goes wrong
      break;
    }
TEST(irange, empty_reverse_range_one_input) {
  std::vector<int> test_vec;
  for (const auto i : c10::irange(-3)) {
    test_vec.push_back(i);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    if (i > 20) { // Cap the number of elements we add if something goes wrong
      break;
    }
  const std::vector<int> correct = {};
  ASSERT_EQ(test_vec, correct);
  }
  const std::vector<int> correct = {};
  ASSERT_EQ(test_vec, correct);
}
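For reference, c10::irange(begin, end) yields the half-open integer range [begin, end); the one-argument form counts from zero, and a reversed or negative range is simply empty, which is exactly the behavior the empty_reverse_range tests pin down. A minimal sketch:

#include <c10/util/irange.h>
#include <vector>

std::vector<int> irange_sketch() {
  std::vector<int> out;
  for (const auto i : c10::irange(4, 11)) { // 4, 5, ..., 10
    out.push_back(i);
  }
  for (const auto i : c10::irange(3)) { // 0, 1, 2
    out.push_back(i);
  }
  for (const auto i : c10::irange(3, -3)) { // end < begin: no iterations
    out.push_back(i);
  }
  return out; // {4, 5, 6, 7, 8, 9, 10, 0, 1, 2}
}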
@ -1,8 +1,8 @@
#include <algorithm>

#include <gtest/gtest.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Logging.h>
#include <gtest/gtest.h>

namespace c10_test {

@ -54,15 +54,15 @@ TEST(LoggingTest, TestEnforceEquals) {

namespace {
struct EnforceEqWithCaller {
  void test(const char *x) {
  void test(const char* x) {
    CAFFE_ENFORCE_EQ_WITH_CALLER(1, 1, "variable: ", x, " is a variable");
  }
};
}
} // namespace

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LoggingTest, TestEnforceMessageVariables) {
  const char *const x = "hello";
  const char* const x = "hello";
  CAFFE_ENFORCE_EQ(1, 1, "variable: ", x, " is a variable");

  EnforceEqWithCaller e;

@ -70,7 +70,9 @@ TEST(LoggingTest, TestEnforceMessageVariables) {
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LoggingTest, EnforceEqualsObjectWithReferenceToTemporaryWithoutUseOutOfScope) {
TEST(
    LoggingTest,
    EnforceEqualsObjectWithReferenceToTemporaryWithoutUseOutOfScope) {
  std::vector<int> x = {1, 2, 3, 4};
  // This case is a little tricky. We have a temporary
  // std::initializer_list to which our temporary ArrayRef

@ -102,7 +104,7 @@ std::ostream& operator<<(std::ostream& out, const Noncopyable& nc) {
  out << "Noncopyable(" << nc.x << ")";
  return out;
}
}
} // namespace

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LoggingTest, DoesntCopyComparedObjects) {

@ -132,7 +134,8 @@ TEST(LoggingTest, EnforceShowcase) {
  WRAP_AND_PRINT(CAFFE_ENFORCE_EQ(
      one * two + three, three * two, "It's a pretty complicated expression"));

  WRAP_AND_PRINT(CAFFE_ENFORCE_THAT(std::equal_to<void>(), ==, one * two + three, three * two));
  WRAP_AND_PRINT(CAFFE_ENFORCE_THAT(
      std::equal_to<void>(), ==, one * two + three, three * two));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
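The logging tests revolve around the CAFFE_ENFORCE family: unlike an assert, a failed enforce throws an exception carrying the concatenated message arguments, which is why these tests can check message formatting. A minimal sketch of the call shapes used above (the function name is hypothetical):

#include <c10/util/Logging.h>

void enforce_sketch(int got) {
  // Throws a c10 enforce-failure exception if got != 1; the trailing
  // arguments are concatenated into the error message.
  CAFFE_ENFORCE_EQ(got, 1, "expected 1, got ", got);
  // Plain condition form, same message behavior.
  CAFFE_ENFORCE(got > 0, "value must be positive");
}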
@ -17,32 +17,29 @@ class OptionalTest : public ::testing::Test {
template <typename T>
T getSampleValue();

template<>
template <>
bool getSampleValue() {
  return true;
}

template<>
template <>
uint64_t getSampleValue() {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  return 42;
}

template<>
template <>
std::string getSampleValue() {
  return "hello";
}


using OptionalTypes = ::testing::Types<
    // 32-bit scalar optimization.
    bool,
    // Trivially destructible but not 32-bit scalar.
    uint64_t,
    // Non-trivial destructor.
    std::string
>;

    // 32-bit scalar optimization.
    bool,
    // Trivially destructible but not 32-bit scalar.
    uint64_t,
    // Non-trivial destructor.
    std::string>;

TYPED_TEST_CASE(OptionalTest, OptionalTypes);


@ -71,7 +68,8 @@ TYPED_TEST(OptionalTest, Initialized) {
  moveAssign = std::move(moveFrom2);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  std::array<typename TestFixture::optional *, 5> opts = {&opt, &copy, &copyAssign, &move, &moveAssign};
  std::array<typename TestFixture::optional*, 5> opts = {
      &opt, &copy, &copyAssign, &move, &moveAssign};
  for (auto* popt : opts) {
    auto& opt = *popt;
    EXPECT_TRUE((bool)opt);
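The optional tests are typed tests: TYPED_TEST_CASE instantiates the fixture once per entry in OptionalTypes, so the same assertions run against each storage strategy (32-bit scalar optimization, trivially destructible, non-trivial destructor). A minimal gtest sketch of the same pattern, with hypothetical fixture names:

#include <cstdint>
#include <string>
#include <gtest/gtest.h>

template <typename T>
class PerTypeTest : public ::testing::Test {};

using PerTypeList = ::testing::Types<bool, uint64_t, std::string>;
TYPED_TEST_CASE(PerTypeTest, PerTypeList);

TYPED_TEST(PerTypeTest, DefaultConstructs) {
  TypeParam value{}; // runs once each for bool, uint64_t, and std::string
  (void)value;
}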
@ -1,6 +1,6 @@
|
|||
#include <vector>
|
||||
#include <unordered_set>
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
|
@ -11,7 +11,8 @@ namespace {
|
|||
|
||||
#define ASSERT_EQUAL_PRIM(t1, t2) ASSERT_TRUE(t1 == t2);
|
||||
|
||||
using dict_int_int = ska_ordered::order_preserving_flat_hash_map<int64_t, int64_t>;
|
||||
using dict_int_int =
|
||||
ska_ordered::order_preserving_flat_hash_map<int64_t, int64_t>;
|
||||
|
||||
dict_int_int test_dict(dict_int_int& dict) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
|
|
@ -20,7 +21,7 @@ dict_int_int test_dict(dict_int_int& dict) {
|
|||
}
|
||||
|
||||
int64_t i = 0;
|
||||
for (auto entry: dict) {
|
||||
for (auto entry : dict) {
|
||||
TORCH_INTERNAL_ASSERT(entry.first == i && entry.second == i + 1);
|
||||
++i;
|
||||
}
|
||||
|
|
@ -28,7 +29,7 @@ dict_int_int test_dict(dict_int_int& dict) {
|
|||
// erase a few entries by themselves
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
std::unordered_set<int64_t> erase_set = {0, 2, 9, 71};
|
||||
for (auto erase: erase_set) {
|
||||
for (auto erase : erase_set) {
|
||||
dict.erase(erase);
|
||||
}
|
||||
|
||||
|
|
@ -55,7 +56,7 @@ dict_int_int test_dict(dict_int_int& dict) {
|
|||
}
|
||||
|
||||
i = 0;
|
||||
for (auto entry: dict) {
|
||||
for (auto entry : dict) {
|
||||
TORCH_INTERNAL_ASSERT(order[i] == entry.first);
|
||||
TORCH_INTERNAL_ASSERT(dict[order[i]] == entry.second);
|
||||
TORCH_INTERNAL_ASSERT(entry.second == order[i] + 1);
|
||||
|
|
@ -89,11 +90,11 @@ TEST(OrderedPreservingDictTest, InsertExistingDoesntAffectOrder) {
|
|||
TORCH_INTERNAL_ASSERT(dict.begin()->first == 1);
|
||||
}
|
||||
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, testRefType) {
|
||||
std::shared_ptr<int64_t> t;
|
||||
using dict_references = ska_ordered::order_preserving_flat_hash_map<int64_t, std::shared_ptr<int64_t>>;
|
||||
using dict_references = ska_ordered::
|
||||
order_preserving_flat_hash_map<int64_t, std::shared_ptr<int64_t>>;
|
||||
|
||||
dict_references dict;
|
||||
|
||||
|
|
@ -108,7 +109,6 @@ TEST(OrderedPreservingDictTest, testRefType) {
|
|||
TORCH_INTERNAL_ASSERT(ptr.use_count() == 1);
|
||||
}
|
||||
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, DictCollisions) {
|
||||
struct BadHash {
|
||||
|
|
@ -171,254 +171,262 @@ TEST(OrderedPreservingDictTest, DictCollisions) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Tests taken from https://github.com/Tessil/ordered-map/blob/master/tests/ordered_map_tests.cpp
|
||||
// Tests taken from
|
||||
// https://github.com/Tessil/ordered-map/blob/master/tests/ordered_map_tests.cpp
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_range_insert) {
|
||||
// insert x values in vector, range insert x-15 values from vector to map, check values
|
||||
const int nb_values = 1000;
|
||||
std::vector<std::pair<int, int>> values;
|
||||
for(int i = 0; i < nb_values; i++) {
|
||||
// NOLINTNEXTLINE(modernize-use-emplace,performance-inefficient-vector-operation)
|
||||
values.push_back(std::make_pair(i, i+1));
|
||||
}
|
||||
// insert x values in vector, range insert x-15 values from vector to map,
|
||||
// check values
|
||||
const int nb_values = 1000;
|
||||
std::vector<std::pair<int, int>> values;
|
||||
for (int i = 0; i < nb_values; i++) {
|
||||
// NOLINTNEXTLINE(modernize-use-emplace,performance-inefficient-vector-operation)
|
||||
values.push_back(std::make_pair(i, i + 1));
|
||||
}
|
||||
|
||||
dict_int_int map = {{-1, 0}, {-2, 0}};
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
map.insert(values.begin() + 10, values.end() - 5);
|
||||
dict_int_int map = {{-1, 0}, {-2, 0}};
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
map.insert(values.begin() + 10, values.end() - 5);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(map.size(), 987);
|
||||
TORCH_INTERNAL_ASSERT(map.size(), 987);
|
||||
|
||||
ASSERT_EQUAL_PRIM(map.at(-1), 0);
|
||||
ASSERT_EQUAL_PRIM(map.at(-1), 0);
|
||||
|
||||
ASSERT_EQUAL_PRIM(map.at(-2), 0);
|
||||
ASSERT_EQUAL_PRIM(map.at(-2), 0);
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
for(int i = 10, j = 2; i < nb_values - 5; i++, j++) {
|
||||
ASSERT_EQUAL_PRIM(map.at(i), i+1);
|
||||
}
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
for (int i = 10, j = 2; i < nb_values - 5; i++, j++) {
|
||||
ASSERT_EQUAL_PRIM(map.at(i), i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_range_erase_all) {
|
||||
// insert x values, delete all
|
||||
const std::size_t nb_values = 1000;
|
||||
dict_int_int map;
|
||||
for (size_t i = 0; i < nb_values; ++i) {
|
||||
map[i] = i + 1;
|
||||
}
|
||||
auto it = map.erase(map.begin(), map.end());
|
||||
ASSERT_TRUE(it == map.end());
|
||||
ASSERT_TRUE(map.empty());
|
||||
// insert x values, delete all
|
||||
const std::size_t nb_values = 1000;
|
||||
dict_int_int map;
|
||||
for (size_t i = 0; i < nb_values; ++i) {
|
||||
map[i] = i + 1;
|
||||
}
|
||||
auto it = map.erase(map.begin(), map.end());
|
||||
ASSERT_TRUE(it == map.end());
|
||||
ASSERT_TRUE(map.empty());
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_range_erase) {
|
||||
// insert x values, delete all with iterators except 10 first and 780 last values
|
||||
using HMap = ska_ordered::order_preserving_flat_hash_map<std::string, std::int64_t>;
|
||||
// insert x values, delete all with iterators except 10 first and 780 last
|
||||
// values
|
||||
using HMap =
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, std::int64_t>;
|
||||
|
||||
const std::size_t nb_values = 1000;
|
||||
HMap map;
|
||||
for (size_t i = 0; i < nb_values; ++i) {
|
||||
map[c10::guts::to_string(i)] = i;
|
||||
auto begin = map.begin();
|
||||
for (size_t j = 0; j <= i; ++j, begin++) {
|
||||
TORCH_INTERNAL_ASSERT(begin->second == j);
|
||||
}
|
||||
const std::size_t nb_values = 1000;
|
||||
HMap map;
|
||||
for (size_t i = 0; i < nb_values; ++i) {
|
||||
map[c10::guts::to_string(i)] = i;
|
||||
auto begin = map.begin();
|
||||
for (size_t j = 0; j <= i; ++j, begin++) {
|
||||
TORCH_INTERNAL_ASSERT(begin->second == j);
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
auto it_first = std::next(map.begin(), 10);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
auto it_last = std::next(map.begin(), 220);
|
||||
|
||||
auto it = map.erase(it_first, it_last);
|
||||
ASSERT_EQUAL_PRIM(std::distance(it, map.end()), 780);
|
||||
ASSERT_EQUAL_PRIM(map.size(), 790);
|
||||
ASSERT_EQUAL_PRIM(std::distance(map.begin(), map.end()), 790);
|
||||
|
||||
for (auto& val : map) {
|
||||
ASSERT_EQUAL_PRIM(map.count(val.first), 1);
|
||||
}
|
||||
|
||||
// Check order
|
||||
it = map.begin();
|
||||
for (std::size_t i = 0; i < nb_values; i++) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
auto it_first = std::next(map.begin(), 10);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
auto it_last = std::next(map.begin(), 220);
|
||||
|
||||
auto it = map.erase(it_first, it_last);
|
||||
ASSERT_EQUAL_PRIM(std::distance(it, map.end()), 780);
|
||||
ASSERT_EQUAL_PRIM(map.size(), 790);
|
||||
ASSERT_EQUAL_PRIM(std::distance(map.begin(), map.end()), 790);
|
||||
|
||||
for(auto& val: map) {
|
||||
ASSERT_EQUAL_PRIM(map.count(val.first), 1);
|
||||
}
|
||||
|
||||
// Check order
|
||||
it = map.begin();
|
||||
for(std::size_t i = 0; i < nb_values; i++) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
if(i >= 10 && i < 220) {
|
||||
continue;
|
||||
}
|
||||
auto exp_it = std::pair<std::string, std::int64_t>(c10::guts::to_string(i), i);
|
||||
TORCH_INTERNAL_ASSERT(*it == exp_it);
|
||||
++it;
|
||||
if (i >= 10 && i < 220) {
|
||||
continue;
|
||||
}
|
||||
auto exp_it =
|
||||
std::pair<std::string, std::int64_t>(c10::guts::to_string(i), i);
|
||||
TORCH_INTERNAL_ASSERT(*it == exp_it);
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_move_constructor_empty) {
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0);
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move(std::move(map));
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0);
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move(
|
||||
std::move(map));
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move)
|
||||
TORCH_INTERNAL_ASSERT(map.empty());
|
||||
TORCH_INTERNAL_ASSERT(map_move.empty());
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move)
|
||||
TORCH_INTERNAL_ASSERT(map.empty());
|
||||
TORCH_INTERNAL_ASSERT(map_move.empty());
|
||||
|
||||
// NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
|
||||
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
|
||||
TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end());
|
||||
// NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
|
||||
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
|
||||
TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end());
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_move_operator_empty) {
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0);
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move;
|
||||
map_move = (std::move(map));
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0);
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move;
|
||||
map_move = (std::move(map));
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move)
|
||||
TORCH_INTERNAL_ASSERT(map.empty());
|
||||
TORCH_INTERNAL_ASSERT(map_move.empty());
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move)
|
||||
TORCH_INTERNAL_ASSERT(map.empty());
|
||||
TORCH_INTERNAL_ASSERT(map_move.empty());
|
||||
|
||||
// NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
|
||||
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
|
||||
TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end());
|
||||
// NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
|
||||
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
|
||||
TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end());
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_constructor) {
|
||||
using HMap = ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
|
||||
using HMap =
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
|
||||
|
||||
HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}};
|
||||
HMap map_move(std::move(map));
|
||||
HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}};
|
||||
HMap map_move(std::move(map));
|
||||
|
||||
ASSERT_EQUAL_PRIM(map_move.size(), 3);
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
|
||||
ASSERT_EQUAL_PRIM(map.size(), 0);
|
||||
ASSERT_EQUAL_PRIM(map_move.size(), 3);
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
|
||||
ASSERT_EQUAL_PRIM(map.size(), 0);
|
||||
|
||||
map = {{"Key4", "Value4"}, {"Key5", "Value5"}};
|
||||
TORCH_INTERNAL_ASSERT(map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}})));
|
||||
map = {{"Key4", "Value4"}, {"Key5", "Value5"}};
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}})));
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_operator) {
|
||||
using HMap = ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
|
||||
using HMap =
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
|
||||
|
||||
HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}};
|
||||
HMap map_move = std::move(map);
|
||||
HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}};
|
||||
HMap map_move = std::move(map);
|
||||
|
||||
ASSERT_EQUAL_PRIM(map_move.size(), 3);
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
|
||||
ASSERT_EQUAL_PRIM(map.size(), 0);
|
||||
ASSERT_EQUAL_PRIM(map_move.size(), 3);
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
|
||||
ASSERT_EQUAL_PRIM(map.size(), 0);
|
||||
|
||||
map = {{"Key4", "Value4"}, {"Key5", "Value5"}};
|
||||
TORCH_INTERNAL_ASSERT(map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}})));
|
||||
map = {{"Key4", "Value4"}, {"Key5", "Value5"}};
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}})));
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_copy_constructor_and_operator) {
|
||||
using HMap = ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
|
||||
using HMap =
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
|
||||
|
||||
const std::size_t nb_values = 100;
|
||||
HMap map;
|
||||
for (size_t i = 0; i < nb_values; ++i) {
|
||||
map[c10::guts::to_string(i)] = c10::guts::to_string(i);
|
||||
}
|
||||
|
||||
const std::size_t nb_values = 100;
|
||||
HMap map;
|
||||
for (size_t i = 0; i < nb_values; ++i) {
|
||||
map[c10::guts::to_string(i)] = c10::guts::to_string(i);
|
||||
}
|
||||
HMap map_copy = map;
|
||||
HMap map_copy2(map);
|
||||
HMap map_copy3;
|
||||
map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0);
|
||||
|
||||
map_copy3 = map;
|
||||
|
||||
HMap map_copy = map;
|
||||
HMap map_copy2(map);
|
||||
HMap map_copy3;
|
||||
map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0);
|
||||
TORCH_INTERNAL_ASSERT(map == map_copy);
|
||||
map.clear();
|
||||
|
||||
map_copy3 = map;
|
||||
|
||||
TORCH_INTERNAL_ASSERT(map == map_copy);
|
||||
map.clear();
|
||||
|
||||
TORCH_INTERNAL_ASSERT(map_copy == map_copy2);
|
||||
TORCH_INTERNAL_ASSERT(map_copy == map_copy3);
|
||||
TORCH_INTERNAL_ASSERT(map_copy == map_copy2);
|
||||
TORCH_INTERNAL_ASSERT(map_copy == map_copy3);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_copy_constructor_empty) {
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0);
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(map);
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0);
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(map);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(map.empty());
|
||||
TORCH_INTERNAL_ASSERT(map_copy.empty());
|
||||
TORCH_INTERNAL_ASSERT(map.empty());
|
||||
TORCH_INTERNAL_ASSERT(map_copy.empty());
|
||||
|
||||
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
|
||||
TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end());
|
||||
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
|
||||
TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end());
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_copy_operator_empty) {
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(16);
|
||||
map_copy = map;
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(16);
|
||||
map_copy = map;
|
||||
|
||||
TORCH_INTERNAL_ASSERT(map.empty());
|
||||
TORCH_INTERNAL_ASSERT(map_copy.empty());
|
||||
TORCH_INTERNAL_ASSERT(map.empty());
|
||||
TORCH_INTERNAL_ASSERT(map_copy.empty());
|
||||
|
||||
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
|
||||
TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end());
|
||||
TORCH_INTERNAL_ASSERT(map.find("") == map.end());
|
||||
TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* at
|
||||
*/
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_at) {
|
||||
// insert x values, use at for known and unknown values.
|
||||
const ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = {{0, 10}, {-2, 20}};
|
||||
// insert x values, use at for known and unknown values.
|
||||
const ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>
|
||||
map = {{0, 10}, {-2, 20}};
|
||||
|
||||
ASSERT_EQUAL_PRIM(map.at(0), 10);
|
||||
ASSERT_EQUAL_PRIM(map.at(-2), 20);
|
||||
bool thrown = false;
|
||||
try {
|
||||
map.at(1);
|
||||
} catch (...) {
|
||||
thrown = true;
|
||||
}
|
||||
ASSERT_TRUE(thrown);
|
||||
ASSERT_EQUAL_PRIM(map.at(0), 10);
|
||||
ASSERT_EQUAL_PRIM(map.at(-2), 20);
|
||||
bool thrown = false;
|
||||
try {
|
||||
map.at(1);
|
||||
} catch (...) {
|
||||
thrown = true;
|
||||
}
|
||||
ASSERT_TRUE(thrown);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* equal_range
|
||||
*/
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_equal_range) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = {{0, 10}, {-2, 20}};
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
|
||||
{{0, 10}, {-2, 20}};
|
||||
|
||||
auto it_pair = map.equal_range(0);
|
||||
ASSERT_EQUAL_PRIM(std::distance(it_pair.first, it_pair.second), 1);
|
||||
ASSERT_EQUAL_PRIM(it_pair.first->second, 10);
|
||||
auto it_pair = map.equal_range(0);
|
||||
ASSERT_EQUAL_PRIM(std::distance(it_pair.first, it_pair.second), 1);
|
||||
ASSERT_EQUAL_PRIM(it_pair.first->second, 10);
|
||||
|
||||
it_pair = map.equal_range(1);
|
||||
TORCH_INTERNAL_ASSERT(it_pair.first == it_pair.second);
|
||||
TORCH_INTERNAL_ASSERT(it_pair.first == map.end());
|
||||
it_pair = map.equal_range(1);
|
||||
TORCH_INTERNAL_ASSERT(it_pair.first == it_pair.second);
|
||||
TORCH_INTERNAL_ASSERT(it_pair.first == map.end());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* operator[]
|
||||
*/
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
TEST(OrderedPreservingDictTest, test_access_operator) {
|
||||
// insert x values, use at for known and unknown values.
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = {{0, 10}, {-2, 20}};
|
||||
// insert x values, use at for known and unknown values.
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||
ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
|
||||
{{0, 10}, {-2, 20}};
|
||||
|
||||
ASSERT_EQUAL_PRIM(map[0], 10);
|
||||
ASSERT_EQUAL_PRIM(map[-2], 20);
|
||||
ASSERT_EQUAL_PRIM(map[2], std::int64_t());
|
||||
ASSERT_EQUAL_PRIM(map[0], 10);
|
||||
ASSERT_EQUAL_PRIM(map[-2], 20);
|
||||
ASSERT_EQUAL_PRIM(map[2], std::int64_t());
|
||||
|
||||
ASSERT_EQUAL_PRIM(map.size(), 3);
|
||||
ASSERT_EQUAL_PRIM(map.size(), 3);
|
||||
}

/**
@ -426,45 +434,72 @@ TEST(OrderedPreservingDictTest, test_access_operator) {
 */
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_swap) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = {{1, 10}, {8, 80}, {3, 30}};
  ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
      {{1, 10}, {8, 80}, {3, 30}};
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2 = {{4, 40}, {5, 50}};
  ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2 =
      {{4, 40}, {5, 50}};

  using std::swap;
  swap(map, map2);

  TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{4, 40}, {5, 50}}));
  TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{1, 10}, {8, 80}, {3, 30}}));
  TORCH_INTERNAL_ASSERT(
      map ==
      (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
          {4, 40}, {5, 50}}));
  TORCH_INTERNAL_ASSERT(
      map2 ==
      (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
          {1, 10}, {8, 80}, {3, 30}}));

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  map.insert({6, 60});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  map2.insert({4, 40});

  TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{4, 40}, {5, 50}, {6, 60}}));
  TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{1, 10}, {8, 80}, {3, 30}, {4, 40}}));
  TORCH_INTERNAL_ASSERT(
      map ==
      (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
          {4, 40}, {5, 50}, {6, 60}}));
  TORCH_INTERNAL_ASSERT(
      map2 ==
      (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
          {1, 10}, {8, 80}, {3, 30}, {4, 40}}));
}
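The `using std::swap; swap(map, map2);` pair in the test is the standard ADL two-step: the using-declaration makes std::swap the fallback, while the unqualified call still lets argument-dependent lookup prefer a swap overload in the argument's own namespace. A self-contained sketch of the idiom:

#include <utility>

namespace mylib {
struct Widget {
  int v = 0;
};
// Found by argument-dependent lookup in preference to std::swap.
void swap(Widget& a, Widget& b) noexcept {
  std::swap(a.v, b.v);
}
} // namespace mylib

template <class T>
void generic_swap(T& a, T& b) {
  using std::swap; // fallback if T provides no swap of its own
  swap(a, b); // unqualified call: ADL picks mylib::swap for Widget
}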

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(OrderedPreservingDictTest, test_swap_empty) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = {{1, 10}, {8, 80}, {3, 30}};
  ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
      {{1, 10}, {8, 80}, {3, 30}};
  ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2;

  using std::swap;
  swap(map, map2);

  TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{}));
  TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{1, 10}, {8, 80}, {3, 30}}));
  TORCH_INTERNAL_ASSERT(
      map ==
      (ska_ordered::
           order_preserving_flat_hash_map<std::int64_t, std::int64_t>{}));
  TORCH_INTERNAL_ASSERT(
      map2 ==
      (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
          {1, 10}, {8, 80}, {3, 30}}));

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  map.insert({6, 60});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  map2.insert({4, 40});

  TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{6, 60}}));
  TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{{1, 10}, {8, 80}, {3, 30}, {4, 40}}));
  TORCH_INTERNAL_ASSERT(
      map ==
      (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
          {6, 60}}));
  TORCH_INTERNAL_ASSERT(
      map2 ==
      (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
          {1, 10}, {8, 80}, {3, 30}, {4, 40}}));
}

}
} // namespace

@ -7,9 +7,9 @@ using c10::string_view;
namespace {
namespace testutils {
constexpr bool string_equal(const char* lhs, const char* rhs, size_t size) {
  return (size == 0)
      ? true
      : (*lhs != *rhs) ? false : string_equal(lhs + 1, rhs + 1, size - 1);
  return (size == 0) ? true
      : (*lhs != *rhs) ? false
                       : string_equal(lhs + 1, rhs + 1, size - 1);
}
static_assert(string_equal("hi", "hi", 2), "");
static_assert(string_equal("", "", 0), "");

@ -124,7 +124,8 @@ TEST(StringViewTest, testCopyAssignment) {
  static_assert(5 == (string_view() = "hello").size(), "");
  static_assert(
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      string_equal("hello", (string_view() = "hello").data(), 5), "");
      string_equal("hello", (string_view() = "hello").data(), 5),
      "");
}
#endif
const string_view hello = assign("hello");

@ -233,7 +234,7 @@ static_assert('o' == string_view("hello").back(), "");
namespace test_data {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
static_assert(string_equal("hello", string_view("hello").data(), 5), "");
}
} // namespace test_data

namespace test_size_length {
static_assert(0 == string_view("").size(), "");

@ -1,7 +1,7 @@
#include <c10/util/tempfile.h>
#include <gtest/gtest.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/types.h>

#if !defined(_WIN32)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@ -8,7 +8,7 @@ namespace {

class TypeMetaTestFoo {};
class TypeMetaTestBar {};
}
} // namespace

CAFFE_KNOWN_TYPE(TypeMetaTestFoo);
CAFFE_KNOWN_TYPE(TypeMetaTestBar);

@ -73,7 +73,6 @@ TEST(TypeMetaTest, TypeMeta) {
  EXPECT_NE(bar_meta.name().find("TypeMetaTestBar"), c10::string_view::npos);
}

class ClassAllowAssignment {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)

@ -92,7 +91,7 @@ class ClassNoAssignment {
  ClassNoAssignment& operator=(const ClassNoAssignment& src) = delete;
  int x;
};
}
} // namespace

CAFFE_KNOWN_TYPE(ClassAllowAssignment);
CAFFE_KNOWN_TYPE(ClassNoAssignment);

@ -135,5 +134,5 @@ TEST(TypeMetaTest, Float16IsNotUint16) {
  EXPECT_NE(TypeMeta::Id<uint16_t>(), TypeMeta::Id<at::Half>());
}

} // namespace
} // namespace caffe2
364 c10/util/Array.h

@ -1,18 +1,20 @@
/**
 * This file is based on the std::array implementation of libstdc++ at
 * https://gcc.gnu.org/onlinedocs/gcc-7.1.0/libstdc++/api/a01056_source.html
 *
 * Changes:
 * - isolate, i.e. remove dependencies on internal libstdc++ stuff
 * - use c++17 behavior even in c++11 or c++14
 * - remove std::swappable special case because that doesn't work with MSVC
 * - constexpr more things
 * - add some features like prepend/tail
 *
 * If using std::array at runtime, feel free to either keep using std::array or use this one - it doesn't really matter.
 * For compile time computations, this one here is preferred because std::array in C++11
 * misses some constexpr specifiers, forcing these methods to be called at runtime instead of compile time.
 */
/**
 * This file is based on the std::array implementation of libstdc++ at
 * https://gcc.gnu.org/onlinedocs/gcc-7.1.0/libstdc++/api/a01056_source.html
 *
 * Changes:
 * - isolate, i.e. remove dependencies on internal libstdc++ stuff
 * - use c++17 behavior even in c++11 or c++14
 * - remove std::swappable special case because that doesn't work with MSVC
 * - constexpr more things
 * - add some features like prepend/tail
 *
 * If using std::array at runtime, feel free to either keep using std::array or
 * use this one - it doesn't really matter. For compile time computations, this
 * one here is preferred because std::array in C++11 misses some constexpr
 * specifiers, forcing these methods to be called at runtime instead of compile
 * time.
 */
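A small sketch of the compile-time gap the comment describes: under C++11, std::array's element access is not constexpr (the const operator[] only became constexpr in C++14), while this implementation marks it constexpr, so the assertions below hold even at -std=c++11. The include path is an assumption based on the c10 tree:

#include <c10/util/Array.h> // assumed include path for this header

constexpr c10::guts::array<int, 3> arr = {{1, 2, 3}};
static_assert(arr.size() == 3, "");
static_assert(arr[1] == 2, "constant expression even in C++11");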

// Copyright (C) 2007-2017 Free Software Foundation, Inc.
//

@ -38,16 +40,17 @@

#pragma once

#include <algorithm>
#include <c10/util/C++17.h>
#include <algorithm>
#include <stdexcept>
#include <string>
#include <utility>

namespace c10 { namespace guts {
namespace c10 {
namespace guts {

namespace detail {
template<typename _Tp, std::size_t _Nm>
template <typename _Tp, std::size_t _Nm>
struct __array_traits final {
  using _Type = _Tp[_Nm];

@ -60,7 +63,7 @@ struct __array_traits final {
  }
};

template<typename _Tp>
template <typename _Tp>
struct __array_traits<_Tp, 0> final {
  struct _Type final {};

@ -76,11 +79,11 @@ struct __array_traits<_Tp, 0> final {
[[noreturn]] inline void __throw_out_of_range(std::string msg) {
  throw std::out_of_range(std::move(msg));
}
}
} // namespace detail

template<typename _Tp, std::size_t _Nm>
template <typename _Tp, std::size_t _Nm>
class array final {
public:
 public:
  using value_type = _Tp;
  using pointer = value_type*;
  using const_pointer = const value_type*;
@ -93,76 +96,99 @@ public:
  using reverse_iterator = std::reverse_iterator<iterator>;
  using const_reverse_iterator = std::reverse_iterator<const_iterator>;

private:
 private:
  using _AT_Type = detail::__array_traits<_Tp, _Nm>;
public: // needs to be public member for aggregate initialization

 public: // needs to be public member for aggregate initialization
  typename _AT_Type::_Type _M_elems;

public:
 public:
  // No explicit construct/copy/destroy for aggregate type.

  // DR 776.
  constexpr void fill(const value_type& __u)
  { std::fill_n(begin(), size(), __u); }
  constexpr void fill(const value_type& __u) {
    std::fill_n(begin(), size(), __u);
  }

  constexpr void swap(array& __other)
  { std::swap_ranges(begin(), end(), __other.begin()); }
  constexpr void swap(array& __other) {
    std::swap_ranges(begin(), end(), __other.begin());
  }

  // Iterators.
  constexpr iterator begin() noexcept
  { return iterator(data()); }
  constexpr iterator begin() noexcept {
    return iterator(data());
  }

  constexpr const_iterator begin() const noexcept
  { return const_iterator(data()); }
  constexpr const_iterator begin() const noexcept {
    return const_iterator(data());
  }

  constexpr iterator end() noexcept
  { return iterator(data() + _Nm); }
  constexpr iterator end() noexcept {
    return iterator(data() + _Nm);
  }

  constexpr const_iterator end() const noexcept
  { return const_iterator(data() + _Nm); }
  constexpr const_iterator end() const noexcept {
    return const_iterator(data() + _Nm);
  }

  constexpr reverse_iterator rbegin() noexcept
  { return reverse_iterator(end()); }
  constexpr reverse_iterator rbegin() noexcept {
    return reverse_iterator(end());
  }

  constexpr const_reverse_iterator rbegin() const noexcept
  { return const_reverse_iterator(end()); }
  constexpr const_reverse_iterator rbegin() const noexcept {
    return const_reverse_iterator(end());
  }

  constexpr reverse_iterator rend() noexcept
  { return reverse_iterator(begin()); }
  constexpr reverse_iterator rend() noexcept {
    return reverse_iterator(begin());
  }

  constexpr const_reverse_iterator rend() const noexcept
  { return const_reverse_iterator(begin()); }
  constexpr const_reverse_iterator rend() const noexcept {
    return const_reverse_iterator(begin());
  }

  constexpr const_iterator cbegin() const noexcept
  { return const_iterator(data()); }
  constexpr const_iterator cbegin() const noexcept {
    return const_iterator(data());
  }

  constexpr const_iterator cend() const noexcept
  { return const_iterator(data() + _Nm); }
  constexpr const_iterator cend() const noexcept {
    return const_iterator(data() + _Nm);
  }

  constexpr const_reverse_iterator crbegin() const noexcept
  { return const_reverse_iterator(end()); }
  constexpr const_reverse_iterator crbegin() const noexcept {
    return const_reverse_iterator(end());
  }

  constexpr const_reverse_iterator crend() const noexcept
  { return const_reverse_iterator(begin()); }
  constexpr const_reverse_iterator crend() const noexcept {
    return const_reverse_iterator(begin());
  }

  // Capacity.
  constexpr size_type size() const noexcept { return _Nm; }
  constexpr size_type size() const noexcept {
    return _Nm;
  }

  constexpr size_type max_size() const noexcept { return _Nm; }
  constexpr size_type max_size() const noexcept {
    return _Nm;
  }

  constexpr bool empty() const noexcept { return size() == 0; }
  constexpr bool empty() const noexcept {
    return size() == 0;
  }

  // Element access.
  constexpr reference operator[](size_type __n) noexcept
  { return _AT_Type::_S_ref(_M_elems, __n); }
  constexpr reference operator[](size_type __n) noexcept {
    return _AT_Type::_S_ref(_M_elems, __n);
  }

  constexpr const_reference operator[](size_type __n) const noexcept
  { return _AT_Type::_S_ref(_M_elems, __n); }
  constexpr const_reference operator[](size_type __n) const noexcept {
    return _AT_Type::_S_ref(_M_elems, __n);
  }

  constexpr reference at(size_type __n) {
    if (__n >= _Nm) {
      detail::__throw_out_of_range(std::string() +
          "array::at: __n (which is " + to_string(__n) + ") " +
      detail::__throw_out_of_range(
          std::string() + "array::at: __n (which is " + to_string(__n) + ") " +
          ">= _Nm (which is " + to_string(_Nm) + ")");
    }
    return _AT_Type::_S_ref(_M_elems, __n);
@ -171,101 +197,133 @@ public:
  constexpr const_reference at(size_type __n) const {
    // Result of conditional expression must be an lvalue so use
    // boolean ? lvalue : (throw-expr, lvalue)
    return __n < _Nm ? _AT_Type::_S_ref(_M_elems, __n)
        : (detail::__throw_out_of_range(std::string() +
            "array::at: __n (which is " + to_string(__n) + ") " +
            ">= _Nm (which is " + to_string(_Nm) + ")"),
            _AT_Type::_S_ref(_M_elems, 0));
    return __n < _Nm
        ? _AT_Type::_S_ref(_M_elems, __n)
        : (detail::__throw_out_of_range(
               std::string() + "array::at: __n (which is " + to_string(__n) +
               ") " + ">= _Nm (which is " + to_string(_Nm) + ")"),
           _AT_Type::_S_ref(_M_elems, 0));
  }

  constexpr reference front() noexcept
  { return *begin(); }
  constexpr reference front() noexcept {
    return *begin();
  }

  constexpr const_reference front() const noexcept
  { return _AT_Type::_S_ref(_M_elems, 0); }
  constexpr const_reference front() const noexcept {
    return _AT_Type::_S_ref(_M_elems, 0);
  }

  constexpr reference back() noexcept
  { return _Nm ? *(end() - 1) : *end(); }
  constexpr reference back() noexcept {
    return _Nm ? *(end() - 1) : *end();
  }

  constexpr const_reference back() const noexcept
  {
    return _Nm ? _AT_Type::_S_ref(_M_elems, _Nm - 1)
               : _AT_Type::_S_ref(_M_elems, 0);
  }
  constexpr const_reference back() const noexcept {
    return _Nm ? _AT_Type::_S_ref(_M_elems, _Nm - 1)
               : _AT_Type::_S_ref(_M_elems, 0);
  }

  constexpr pointer data() noexcept
  { return _AT_Type::_S_ptr(_M_elems); }
  constexpr pointer data() noexcept {
    return _AT_Type::_S_ptr(_M_elems);
  }

  constexpr const_pointer data() const noexcept
  { return _AT_Type::_S_ptr(_M_elems); }
  constexpr const_pointer data() const noexcept {
    return _AT_Type::_S_ptr(_M_elems);
  }
};
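The const at() above uses the "boolean ? lvalue : (throw-expr, lvalue)" trick so a single-statement C++11 constexpr function can still report a bad index: a throw-expression is permitted as long as it sits in the branch not taken, and the comma operator makes that branch an lvalue of the right type. A standalone sketch of the same idiom:

#include <cstddef>
#include <stdexcept>

constexpr const int& checked_at(const int (&a)[3], std::size_t i) {
  // Both branches are lvalues, so the conditional can bind to the returned
  // reference; the throw only fires when i is out of range.
  return i < 3 ? a[i] : (throw std::out_of_range("checked_at"), a[0]);
}

constexpr int arr3[3] = {1, 2, 3};
static_assert(checked_at(arr3, 2) == 3, "");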

#if defined(__cpp_deduction_guides) && __cpp_deduction_guides >= 201606
template<typename _Tp, typename... _Up>
array(_Tp, _Up...) ->
    array<std::enable_if_t<(std::is_same<_Tp, _Up>::value && ...), _Tp>, 1 + sizeof...(_Up)>;
template <typename _Tp, typename... _Up>
array(_Tp, _Up...) -> array<
    std::enable_if_t<(std::is_same<_Tp, _Up>::value && ...), _Tp>,
    1 + sizeof...(_Up)>;
#endif
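When the compiler reports __cpp_deduction_guides (C++17), the guide above lets element type and count be deduced from the initializer, with the fold over is_same rejecting mixed element types. A usage sketch, assuming the same include path as before:

#include <c10/util/Array.h> // assumed include path
#include <type_traits>

void deduction_demo() {
  c10::guts::array a{1, 2, 3}; // guide deduces element type and count
  static_assert(
      std::is_same<decltype(a), c10::guts::array<int, 3>>::value, "");
  // c10::guts::array bad{1, 2.0}; // ill-formed: elements must share a type
  (void)a;
}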

// Array comparisons.
namespace detail {
template<class T, size_t N>
constexpr inline bool array_equals_(const array<T, N>& lhs, const array<T, N>& rhs, size_t current_index) {
template <class T, size_t N>
constexpr inline bool array_equals_(
    const array<T, N>& lhs,
    const array<T, N>& rhs,
    size_t current_index) {
  return (current_index == N)
      ? true
      : (lhs.at(current_index) == rhs.at(current_index) && array_equals_(lhs, rhs, current_index + 1));
      ? true
      : (lhs.at(current_index) == rhs.at(current_index) &&
         array_equals_(lhs, rhs, current_index + 1));
}
template<class T, size_t N>
constexpr inline bool array_less_(const array<T, N>& lhs, const array<T, N>& rhs, size_t current_index) {
template <class T, size_t N>
constexpr inline bool array_less_(
    const array<T, N>& lhs,
    const array<T, N>& rhs,
    size_t current_index) {
  return (current_index == N)
      ? false
      : (lhs.at(current_index) < rhs.at(current_index) || array_less_(lhs, rhs, current_index + 1));
      ? false
      : (lhs.at(current_index) < rhs.at(current_index) ||
         array_less_(lhs, rhs, current_index + 1));
}
} // namespace detail

template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator==(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two)
{ return detail::array_equals_(__one, __two, 0); }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator==(
    const array<_Tp, _Nm>& __one,
    const array<_Tp, _Nm>& __two) {
  return detail::array_equals_(__one, __two, 0);
}

template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator!=(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two)
{ return !(__one == __two); }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator!=(
    const array<_Tp, _Nm>& __one,
    const array<_Tp, _Nm>& __two) {
  return !(__one == __two);
}

template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator<(const array<_Tp, _Nm>& __a, const array<_Tp, _Nm>& __b)
{ return detail::array_less_(__a, __b, 0); }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator<(
    const array<_Tp, _Nm>& __a,
    const array<_Tp, _Nm>& __b) {
  return detail::array_less_(__a, __b, 0);
}

template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator>(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two)
{ return __two < __one; }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator>(
    const array<_Tp, _Nm>& __one,
    const array<_Tp, _Nm>& __two) {
  return __two < __one;
}

template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator<=(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two)
{ return !(__one > __two); }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator<=(
    const array<_Tp, _Nm>& __one,
    const array<_Tp, _Nm>& __two) {
  return !(__one > __two);
}

template<typename _Tp, std::size_t _Nm>
constexpr inline bool operator>=(const array<_Tp, _Nm>& __one, const array<_Tp, _Nm>& __two)
{ return !(__one < __two); }
template <typename _Tp, std::size_t _Nm>
constexpr inline bool operator>=(
    const array<_Tp, _Nm>& __one,
    const array<_Tp, _Nm>& __two) {
  return !(__one < __two);
}
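Because the helpers above are recursive single-return functions, the comparison operators are fully constexpr and usable in static_asserts. A brief sketch under the same include-path assumption (note that array_less_ recurses element-wise rather than implementing strict lexicographic order):

#include <c10/util/Array.h> // assumed include path

constexpr c10::guts::array<int, 2> cmp_x = {{1, 2}};
constexpr c10::guts::array<int, 2> cmp_y = {{1, 3}};
static_assert(cmp_x == cmp_x, "");
static_assert(cmp_x != cmp_y, "");
static_assert(cmp_x < cmp_y, "");
static_assert(cmp_y > cmp_x, "");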

// Specialized algorithms.
template<typename _Tp, std::size_t _Nm>
inline void swap(array<_Tp, _Nm>& __one, array<_Tp, _Nm>& __two) noexcept(noexcept(__one.swap(__two)))
{ __one.swap(__two); }
template <typename _Tp, std::size_t _Nm>
inline void swap(array<_Tp, _Nm>& __one, array<_Tp, _Nm>& __two) noexcept(
    noexcept(__one.swap(__two))) {
  __one.swap(__two);
}

template<std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr _Tp& get(array<_Tp, _Nm>& __arr) noexcept {
template <std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr _Tp& get(array<_Tp, _Nm>& __arr) noexcept {
  static_assert(_Int < _Nm, "array index is within bounds");
  return detail::__array_traits<_Tp, _Nm>::_S_ref(__arr._M_elems, _Int);
}

template<std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr _Tp&& get(array<_Tp, _Nm>&& __arr) noexcept
{
template <std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr _Tp&& get(array<_Tp, _Nm>&& __arr) noexcept {
  static_assert(_Int < _Nm, "array index is within bounds");
  return std::move(get<_Int>(__arr));
}

template<std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr const _Tp& get(const array<_Tp, _Nm>& __arr) noexcept
{
template <std::size_t _Int, typename _Tp, std::size_t _Nm>
constexpr const _Tp& get(const array<_Tp, _Nm>& __arr) noexcept {
  static_assert(_Int < _Nm, "array index is within bounds");
  return detail::__array_traits<_Tp, _Nm>::_S_ref(__arr._M_elems, _Int);
}

@ -278,27 +336,34 @@ constexpr const _Tp& get(const array<_Tp, _Nm>& __arr) noexcept
 * prepend(2, {3, 4}) == {2, 3, 4}
 */
namespace detail {
template<class T, size_t N, size_t... INDEX>
constexpr inline array<T, N-1> tail_(const array<T, N>& arg, std::index_sequence<INDEX...>) {
  static_assert(sizeof...(INDEX) == N-1, "invariant");
  return {{get<INDEX+1>(arg)...}};
}
template <class T, size_t N, size_t... INDEX>
constexpr inline array<T, N - 1> tail_(
    const array<T, N>& arg,
    std::index_sequence<INDEX...>) {
  static_assert(sizeof...(INDEX) == N - 1, "invariant");
  return {{get<INDEX + 1>(arg)...}};
}
}
} // namespace detail
template<class T, size_t N>
constexpr inline array<T, N-1> tail(const array<T, N>& arg) {
  static_assert(N > 0, "Can only call tail() on an array with at least one element");
  return detail::tail_(arg, std::make_index_sequence<N-1>());
}
template <class T, size_t N>
constexpr inline array<T, N - 1> tail(const array<T, N>& arg) {
  static_assert(
      N > 0, "Can only call tail() on an array with at least one element");
  return detail::tail_(arg, std::make_index_sequence<N - 1>());
}

namespace detail {
template<class T, size_t N, size_t... INDEX>
constexpr inline array<T, N+1> prepend_(T&& head, const array<T, N>& tail, std::index_sequence<INDEX...>) {
  return {{std::forward<T>(head), get<INDEX>(tail)...}};
}
template <class T, size_t N, size_t... INDEX>
constexpr inline array<T, N + 1> prepend_(
    T&& head,
    const array<T, N>& tail,
    std::index_sequence<INDEX...>) {
  return {{std::forward<T>(head), get<INDEX>(tail)...}};
}
}
} // namespace detail
template<class T, size_t N>
constexpr inline array<T, N+1> prepend(T&& head, const array<T, N>& tail) {
  return detail::prepend_(std::forward<T>(head), tail, std::make_index_sequence<N>());
}
template <class T, size_t N>
constexpr inline array<T, N + 1> prepend(T&& head, const array<T, N>& tail) {
  return detail::prepend_(
      std::forward<T>(head), tail, std::make_index_sequence<N>());
}
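A usage sketch for the documented behavior ("prepend(2, {3, 4}) == {2, 3, 4}"), evaluated entirely at compile time and again assuming the c10/util/Array.h include path:

#include <c10/util/Array.h> // assumed include path

constexpr c10::guts::array<int, 3> abc = {{2, 3, 4}};
static_assert(
    c10::guts::tail(abc) == c10::guts::array<int, 2>{{3, 4}}, "");
static_assert(
    c10::guts::prepend(1, c10::guts::array<int, 2>{{3, 4}}) ==
        c10::guts::array<int, 3>{{1, 3, 4}},
    "");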

/**
@ -309,15 +374,18 @@ constexpr inline array<T, N+1> prepend(T&& head, const array<T, N>& tail) {
 */

namespace detail {
template<class T, size_t N, size_t... INDEX>
constexpr array<T, N> to_array_(const T (&arr)[N], std::index_sequence<INDEX...>) {
template <class T, size_t N, size_t... INDEX>
constexpr array<T, N> to_array_(
    const T (&arr)[N],
    std::index_sequence<INDEX...>) {
  return {{arr[INDEX]...}};
}
}
} // namespace detail

template<class T, size_t N>
template <class T, size_t N>
constexpr array<T, N> to_array(const T (&arr)[N]) {
  return detail::to_array_(arr, std::make_index_sequence<N>());
}

}}
} // namespace guts
} // namespace c10
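And a short sketch for to_array(), which lifts a built-in array into the wrapper so its size is carried in the type, all at compile time:

#include <c10/util/Array.h> // assumed include path

constexpr int raw[3] = {1, 2, 3};
constexpr auto lifted = c10::guts::to_array(raw);
static_assert(lifted.size() == 3, "");
static_assert(lifted[2] == 3, "");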
|
||||
|
|
|
|||

@ -15,10 +15,10 @@

#pragma once

#include <c10/util/SmallVector.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>

#include <array>
#include <iterator>
@ -81,12 +81,15 @@ class ArrayRef final {
      : Data(Vec.data()), Length(Vec.size()) {}

  /// Construct an ArrayRef from a std::vector.
  // The enable_if stuff here makes sure that this isn't used for std::vector<bool>,
  // because ArrayRef can't work on a std::vector<bool> bitfield.
  // The enable_if stuff here makes sure that this isn't used for
  // std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
  // bitfield.
  template <typename A>
  /* implicit */ ArrayRef(const std::vector<T, A>& Vec)
      : Data(Vec.data()), Length(Vec.size()) {
    static_assert(!std::is_same<T, bool>::value, "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
    static_assert(
        !std::is_same<T, bool>::value,
        "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
  }
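A sketch of why that static_assert exists: std::vector<bool> is a packed bitfield specialization with no contiguous bool storage for a view to point at, while any ordinary vector works (include path again assumed):

#include <c10/util/ArrayRef.h> // assumed include path
#include <vector>

void arrayref_demo() {
  std::vector<int> v = {1, 2, 3};
  c10::ArrayRef<int> ref(v); // fine: vector<int> exposes contiguous data()
  (void)ref;
  // std::vector<bool> bits = {true, false};
  // c10::ArrayRef<bool> bad(bits); // rejected by the static_assert above
}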

  /// Construct an ArrayRef from a std::array

@ -100,7 +103,9 @@ class ArrayRef final {

  /// Construct an ArrayRef from a std::initializer_list.
  /* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
      : Data(std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr) : std::begin(Vec)),
      : Data(
            std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
                                             : std::begin(Vec)),
        Length(Vec.size()) {}

  /// @}
@ -146,7 +151,8 @@ class ArrayRef final {

  /// front - Get the first element.
  C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const {
    TORCH_CHECK(!empty(), "ArrayRef: attempted to access front() of empty list");
    TORCH_CHECK(
        !empty(), "ArrayRef: attempted to access front() of empty list");
    return Data[0];
  }
@ -162,7 +168,8 @@ class ArrayRef final {
  }

  /// slice(n, m) - Take M elements of the array starting at element N
  C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N, size_t M) const {
  C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N, size_t M)
      const {
    TORCH_CHECK(
        N + M <= size(),
        "ArrayRef: invalid slice, N = ",

@ -224,10 +231,10 @@ class ArrayRef final {
};

template <typename T>
std::ostream& operator<<(std::ostream & out, ArrayRef<T> list) {
std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
  int i = 0;
  out << "[";
  for(auto e : list) {
  for (auto e : list) {
    if (i++ > 0)
      out << ", ";
    out << e;
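A usage sketch for the stream operator in this last hunk, which prints a bracketed, comma-separated list (include path assumed as before):

#include <c10/util/ArrayRef.h> // assumed include path
#include <iostream>

int main() {
  const int data[] = {1, 2, 3};
  c10::ArrayRef<int> ref(data, 3);
  std::cout << ref << "\n"; // prints: [1, 2, 3]
  return 0;
}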
Some files were not shown because too many files have changed in this diff.